Accelerating MMPreTrain models with JAX#
Accelerate your MMPreTrain models by converting them to JAX for faster inference.
Installations
Make sure you run this demo with GPU enabled!
[ ]:
!pip install -U -q openmim && mim install -q "mmpretrain>=1.0.0rc8"
!pip install -q ivy
!pip install -q dm-haiku
Let’s now import Ivy and the libraries we’ll use in this example:
[2]:
import jax
jax.devices()
import ivy
ivy.set_default_device("gpu:0")
import torch
import requests
import numpy as np
from PIL import Image
import time
import torchvision
from mmpretrain import get_model, list_models
from mmengine import ConfigDict
Sanity check to make sure the checkpoint name matches an entry in mmpretrain's model zoo:
[3]:
checkpoint_name = "convnext-tiny_32xb128-noema_in1k"
list_models(checkpoint_name)
[3]:
['convnext-tiny_32xb128-noema_in1k']
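If you'd like to experiment with a different backbone, list_models also accepts a wildcard pattern, so you can browse a whole family of checkpoints. A small optional sketch (the exact entries returned depend on your installed mmpretrain version):
[ ]:
# List every checkpoint whose name contains "convnext"
# (assumes list_models' wildcard pattern matching).
list_models("*convnext*")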
Now we can load the ConvNeXt model from OpenMMLab's mmpretrain library:
[ ]:
jax.config.update("jax_enable_x64", True)
model = get_model(checkpoint_name, pretrained=True, device='cuda')
We will also need a sample image to pass during tracing, so let’s use the appropriate transforms to get the corresponding torch tensors.
[5]:
def get_scale(cfg):
    # Recursively walk the pipeline config until we find a transform entry
    # that defines both a 'type' and a 'scale' (the input image size).
    if isinstance(cfg, ConfigDict):
        if cfg.get('type', False) and cfg.get('scale', False):
            return cfg['scale']
        else:
            for k in cfg.keys():
                input_shape = get_scale(cfg[k])
                if input_shape:
                    return input_shape
    elif isinstance(cfg, list):
        for block in cfg:
            input_shape = get_scale(block)
            if input_shape:
                return input_shape
    else:
        return None
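To make the traversal above concrete, here is a toy, purely illustrative pipeline config of the kind get_scale walks over: entries without a scale are skipped, and the first scale found is returned.
[ ]:
# Hypothetical config, for illustration only: the first entry has no 'scale',
# so get_scale recurses past it and returns 224 from the second entry.
toy_pipeline = [
    ConfigDict(type='LoadImageFromFile'),
    ConfigDict(type='RandomResizedCrop', scale=224),
]
print(get_scale(toy_pipeline))  # 224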
[6]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
input_shape = get_scale(model._config.train_pipeline)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_shape, input_shape)),
    torchvision.transforms.ToTensor()
])
tensor_image = transform(image).unsqueeze(0).to("cuda")
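Before transpiling, it's worth running the original PyTorch model once to check that the preprocessing produces sensible outputs. A minimal sketch, assuming mmpretrain's default 'tensor' forward mode returns the raw classification logits:
[ ]:
# Single forward pass with the PyTorch model; for an ImageNet-1k classifier
# we expect logits of shape (1, 1000).
with torch.no_grad():
    torch_logits = model(tensor_image)
print(torch_logits.shape)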
And finally, let’s transpile the model to Haiku!
[ ]:
transpiled_graph = ivy.transpile(model, to="haiku", args=(tensor_image,))
After transpiling our model, we can see what the improvement in runtime efficiency looks like. For a fair comparison, let’s also compile the original PyTorch model using torch.compile:
[ ]:
# ref : https://github.com/pytorch/pytorch/issues/107960
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia
[8]:
tensor_image = transform(image).unsqueeze(0).to("cuda")
def _f(args):
    return model(args)

comp_model = torch.compile(_f)
_ = comp_model(tensor_image)  # first call triggers compilation (warm-up)
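Note that the first call above includes the one-off compilation overhead. If you want to time individual calls by hand (rather than with %timeit below), it also helps to synchronize CUDA so you measure the full GPU execution rather than just the asynchronous kernel launch. An optional sketch:
[ ]:
# Time one warm call of the compiled model, synchronizing CUDA before and
# after so the measurement covers the full GPU execution.
torch.cuda.synchronize()
st = time.perf_counter()
_ = comp_model(tensor_image)
torch.cuda.synchronize()
print(f"Compiled torch call took: {(time.perf_counter() - st) * 1000:.2f}ms")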
Let’s now do the equivalent transformation with our new Haiku model by using JAX’s just-in-time (JIT) compilation:
[9]:
tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])
import haiku as hk

def _forward(args):
    module = transpiled_graph()
    return module(args)

rng_key = jax.random.PRNGKey(42)
jax_mlp_forward = hk.transform(_forward)
params = jax_mlp_forward.init(rng=rng_key, args=jax_image)
apply = jax.jit(jax_mlp_forward.apply)
_ = apply(params, None, jax_image)  # first call triggers JIT compilation (warm-up)
Now that we have both models optimized, let’s see how their runtime speeds compare to each other!
[13]:
%timeit comp_model(tensor_image)
8.06 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
[14]:
%timeit apply(params, None, jax_image).block_until_ready()
6.08 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
As expected, we have made the model significantly faster with just one line of code! Latency gets even better on a V100 GPU, where we can get up to a 2-3x increase in execution speed! 🚀
Finally, as a sanity check, let’s load a different image and make sure that the results are the same for both models.
[12]:
url = "http://images.cocodataset.org/train2017/000000283921.jpg"
image = Image.open(requests.get(url, stream=True).raw)
tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
jax_image = jax.device_put(jax.numpy.asarray(np_image), device=jax.devices()[0])
st = time.perf_counter()
out_torch = comp_model(tensor_image)
et = time.perf_counter()
print(f'Torch call took: {(et - st) * 1000:.2f}ms')
st = time.perf_counter()
out_jax = apply(params, None, jax_image)
et = time.perf_counter()
print(f'Jax call took: {(et - st) * 1000:.2f}ms')
print(np.allclose(out_torch.detach().cpu().numpy(), out_jax, atol=1e-4))
Torch call took: 6.66ms
Jax call took: 2.53ms
True
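If you prefer a more interpretable check than np.allclose, you can also compare the top-1 predicted class index of both outputs (a small sketch; the variable names are just illustrative):
[ ]:
# Compare the argmax (top-1 ImageNet class index) of both model outputs.
torch_top1 = int(out_torch.detach().cpu().numpy().argmax(axis=-1)[0])
jax_top1 = int(np.asarray(out_jax).argmax(axis=-1)[0])
print(torch_top1, jax_top1, torch_top1 == jax_top1)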
That’s pretty much it! The results from both models match, and we have achieved a solid speed-up by using Ivy’s transpiler to convert the model to JAX!