Out of memory

Memory: the bane of existence for humans and deep learning researchers alike. Don't work at Google with access to 20 V100s to run GPT-3? Well, this post is for you. Like your average grad student, I don't own a billion-dollar company, and my school doesn't have hundreds of thousands (or millions) of dollars to blow on hardware that would make my life easier. So what do we do? We trade time and computation for memory. Finally, us little guys can compete with the big guys! But more realistically, we now have an excuse to get some sleep (who are we kidding?)

[Image: "compiling"]

We still won't be able to train GPT-3 on the measly 11 GB of VRAM on our fancy 2080 Ti, but this does let us do more than we could before. The trick here is simple. If you are familiar with HPC or scientific computing (maybe you aren't?), you know that we often set "checkpoints" in our code so that if something breaks (it always does) we don't have to run the whole thing over again. If you're not a scientific computing person, you may have saved your PyTorch model after a certain number of epochs. Same thing. Well, we can do the same thing during our back-propagation. Yay!!!

In this post we're not going to get into the nitty gritty, but you can read this paper to learn all about it! The short version is that by recomputing activations between checkpoints during the backward pass we get sublinear memory growth (roughly O(√n) memory for an n-layer network if you checkpoint about every √n layers). Best of all, PyTorch has a built-in utility to help us out, and that's what we're learning to use today.

We're going to look at the two different checkpointing methods that PyTorch provides and how they differ. One caveat is that measuring GPU usage is kind of a pain with PyTorch, so we're going to set up a systematic way to do things so we always get comparable numbers. With each variant we clear the CUDA cache and reset the max-memory-allocated counter, and we also measure what nvidia-smi reports. That gives us two chunks of code:

# Clear cuda memory
import torch
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

# Report memory statistics
import subprocess as sp
print(f"Allocated: {torch.cuda.memory_allocated()//2**20} MB")
print(f"Max Allocated: {torch.cuda.max_memory_allocated()//2**20} MB")
# Ask nvidia-smi how much memory is in use on the first GPU
smi = int(sp.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'], encoding='utf-8').split('\n')[0])
print(f"SMI: {smi} MB")

Additionally, we'll use a very deep network that won't fit in our GPU memory on its own, so we can verify that we are actually saving memory. Honestly, I just kept adding layers until I blew up my memory; there was no sane and logical methodology to this. If you're going to repeat the experiment, just do the same. It doesn't really matter what the network looks like, we're just using linear layers. I'm using a 2080 Super here, which has 8 GB of VRAM.
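
For concreteness, here's a minimal sketch of what such a throwaway network could look like (the width and depth below are made-up placeholders, not my actual values; crank them up until your own card runs out of memory):

import torch.nn as nn

def make_layers(width=4000, depth=30):
    # A deep stack of Linear + ReLU layers; increase width/depth until it no longer fits
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), nn.ReLU()]
    return nn.Sequential(*layers)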

I want to add another side note: nvidia-smi reports that my card has 7951 MiB, but the runtime error I get when oversubscribing tells me there is 7.77 GiB of total capacity (7951 MiB ≈ 7.76 GiB, so the two are essentially the same quantity in different units).

Sequential Gradient Checkpointing (simple)

The simplest thing to do is to use PyTorch's checkpoint_sequential. It takes three arguments: the model, the number of segments, and the input (the input can be a tuple). Let's just create a dummy network and make it big. All we need to do to add checkpointing is change the forward:

from torch.utils.checkpoint import checkpoint_sequential
...
    def forward(self, x):
        x = x.to(self.device)
        x.requires_grad = True
        # Split self.my_model into self.num_segments segments and checkpoint each one
        o = checkpoint_sequential(self.my_model, self.num_segments, x)
        return o

There are a few things to note here that might look odd, so let's go line by line (inside forward). First we send the variable to the device; you don't have to do that in forward, but it is convenient. Second, we make sure the input has requires_grad set, which checkpointing needs in order to compute gradients through the checkpointed segments. Third is our magic line, which assumes that in __init__ you have defined your model and the number of segments you want to break it into. The best way to compose your model here is with nn.Sequential(). You might get thrown off by trying to break your model into too many segments.
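
Putting it together, the surrounding module might look something like this. It's only a sketch with my own naming (CheckpointedNet and the num_segments default are assumptions, not the post's exact code), and it expects an nn.Sequential like the make_layers sketch above:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class CheckpointedNet(nn.Module):
    def __init__(self, model, device, num_segments=8):
        super().__init__()
        # model should be an nn.Sequential so checkpoint_sequential can slice it
        self.my_model = model.to(device)
        self.device = device
        self.num_segments = num_segments

    def forward(self, x):
        x = x.to(self.device)
        x.requires_grad = True
        return checkpoint_sequential(self.my_model, self.num_segments, x)

# e.g. net = CheckpointedNet(make_layers(), torch.device('cuda'), num_segments=8)

Instantiating it with different num_segments values is all it takes to produce the rows of the table below.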

So let's look at how many times we're splitting and what each metric reports. For runs that don't fit, we report how much memory is requested but can't be allocated, i.e. how far we oversubscribe; for runs that do fit, we report usage as a percentage of the first model that fits (the reference).

Side note: we are only looking at memory here. This network is small and the time differences over the two epochs we train for are small, so we don't report them. Expect computation time to increase, though.

| Num Splits | Alloc (MB) | Max Alloc (MB) | SMI (MB) | Mem Oversub (MiB) | Using (%)       |
| ---------- | ---------- | -------------- | -------- | ----------------- | --------------- |
| 0          | NA         | NA             | NA       | 382               | NA              |
| 2          | 148        | 6033           | 7212     | NA                | reference (100) |
| 4          | 145        | 4350           | 6000     | NA                | 83.2            |
| 8          | 145        | 3777           | 5428     | NA                | 75.3            |
| 10         | 145        | 4159           | 5810     | NA                | 80.6            |
| 16         | 145        | 4999           | 6190     | NA                | 85.8            |

So we're learning some interesting things here, the biggest being that there is an optimal split. This was rather surprising to me. With 8 splits we get about a 25% reduction in memory usage, but we don't just keep getting better and better as we split more. I'm guessing this is because of overhead in the implementation. Either way, we have a very useful tool here.

Gradient Checkpointing (simple)

Next we're going to look at the standard gradient checkpointing method: torch.utils.checkpoint.checkpoint. This function is also pretty easy to use: you just need the model and the input (tuples can be used). It's really the same thing we did before, except we don't give a number of segments. So why use the other one? Because here you have to manually split your network into sub-networks. Our code might look like:

from torch.utils.checkpoint import checkpoint
...
    def forward(self, x):
        x = x.to(self.device)
        x.requires_grad = True
        # Checkpoint each manually defined sub-network in turn
        o = checkpoint(self.model_0, x)
        o = checkpoint(self.model_1, o)
        return o

In this case we checkpoint the two sub-networks. Let's (sorta) replicate the earlier experiment with this; we'll actually use the same model. I say "sorta" because we don't have real control over how checkpoint_sequential splits things up. Instead I do my best and create four variants: 1) we checkpoint every nn.Linear and nn.ReLU, 2) we checkpoint every pair, 3) we break into 3 equal-ish sub-networks, and 4) we break into 4 equal-ish sub-networks.
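
One way to build all four variants without writing four separate classes is to slice the nn.Sequential into chunks and checkpoint each chunk in forward. Again, this is my own sketch (split_into and ManuallyCheckpointed are hypothetical names, not from the post's code):

import math
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def split_into(model, n_chunks):
    # Slice an nn.Sequential into n_chunks roughly equal nn.Sequential pieces
    layers = list(model.children())
    size = math.ceil(len(layers) / n_chunks)
    return [nn.Sequential(*layers[i:i + size]) for i in range(0, len(layers), size)]

class ManuallyCheckpointed(nn.Module):
    def __init__(self, model, n_chunks):
        super().__init__()
        # Wrap the chunks in a ModuleList so their parameters stay registered
        self.chunks = nn.ModuleList(split_into(model, n_chunks))

    def forward(self, x):
        x.requires_grad = True  # checkpoint needs a grad-requiring input
        for chunk in self.chunks:
            x = checkpoint(chunk, x)
        return x

Setting n_chunks to the number of modules reproduces the "Every" variant, half that gives "Every Other", and 3 or 4 give the equal-ish splits.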

| Name        | Alloc (MB) | Max Alloc (MB) | SMI (MB) | Using (%) |
| ----------- | ---------- | -------------- | -------- | --------- |
| 2 Subs      | 148        | 6033           | 7212     | reference |
| 3 Subs      | 145        | 6642           | 7908     | 110       |
| Every Other | 83         | 3715           | 5408     | 75.0      |
| Every       | 38         | 4892           | 6170     | 85.6      |

Okay, sanity check: our 2-sub-network model DOES match the 2-split run from checkpoint_sequential. What's super interesting is that with 3 sub-networks we actually use 10% MORE memory! For reference, the "Every" network breaks things into 16 sub-networks and "Every Other" breaks it into 8, so we actually get pretty close to the corresponding sequential splits.

What We’ve Learned

We've learned some interesting things. checkpoint_sequential and checkpoint don't give the same numbers, although they're close, for what we would expect to be the same networks/divisions. Earlier we were talking about GPT-3: we needed 11 V100s (32 GB each), but now we only need 9! Well, 8.3 (11 × 0.753), but we can't have 0.3 of a V100 :(. So we could really save about $20k here just from hardware, although we'll use more electricity. Sometimes that's easier to get money for though (especially if you're a grad student and using someone else's ;)

Okay, but what about FP16?

So let's do the last version with FP16. We make a few changes to our code; in our main function we do the following:

...
net = model(device=torch.device('cuda')).half()  # cast the model's weights to FP16
... training loop ...
    out = net(x.half())  # cast the inputs to FP16 as well
...
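
For context, here's a rough sketch of how the FP16 measurement run could fit together, reusing the hypothetical make_layers and ManuallyCheckpointed helpers sketched earlier (the batch size, learning rate, and loss are arbitrary placeholders):

import torch

device = torch.device('cuda')
# Cast the checkpointed model's weights to FP16
net = ManuallyCheckpointed(make_layers(), n_chunks=2).to(device).half()
opt = torch.optim.SGD(net.parameters(), lr=1e-3)

for _ in range(2):  # a couple of short epochs, as in the experiments above
    x = torch.randn(64, 4000, device=device)
    out = net(x.half())        # cast the inputs to FP16 as well
    loss = out.float().mean()  # placeholder loss, reduced in FP32
    opt.zero_grad()
    loss.backward()
    opt.step()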

(Again, expect computation to take longer to reach the same convergence.)

| Name        | Alloc (MB) | Max Alloc (MB) | SMI (MB) | Using (%) | Using From Prev (%) |
| ----------- | ---------- | -------------- | -------- | --------- | ------------------- |
| 2 Subs      | 75         | 3015           | 4054     | reference | 56.2                |
| 3 Subs      | 74         | 3318           | 4414     | 108.9     | 61.2                |
| Every Other | 42         | 1856           | 2954     | 72.9      | 41.0                |
| Every       | 18         | 2443           | 3530     | 87.1      | 48.9                |

WOW, we got some big savings. Not only could the original network not actually fit on our graphics card, BUT we only used 41% of the memory of the smallest configuration THAT COULD! Relative to the original, un-checkpointed network, we're probably using a bit less than 40%. This can be huge savings. Unfortunately we won't get to reduce to 6 V100s, because we were already assuming FP16, but we needed to know this to save all that money!

All the code for this tutorial can be found on GitHub.