Gradient Checkpointing and Memory Optimization
Memory: the bane of existence for humans and deep learning researchers alike. Don't work at Google with access to 20 V100s to run GPT-3? Well, this post is for you. Like your average grad student, I don't own a billion dollar company and my school doesn't have hundreds of thousands (or millions) to blow on hardware that would make my life easier. So what do we do? We trade time and computation for memory. Finally, us little guys can compete with the big guys! But really, more realistically, we have an excuse to get some sleep (who are we kidding?).
We still won't be able to train GPT-3 on the measly 11 GB of VRAM in our fancy 2080 Ti, but this does let us do more than we could before. The trick here is really simple. If you're familiar with HPC or scientific computing (maybe you aren't?), you know that we often set "checkpoints" in our code, so that if something breaks (it always does) we don't have to run the whole thing over again. If you're not a scientific computing person, you may have saved your PyTorch model after a certain number of epochs. Same idea. Well, we can do the same thing during backpropagation. Yay!!!
In this post we're not going to get into the nitty gritty, but you can read this paper to learn all about it! The short version: by recomputing activations during the backward pass instead of storing all of them, we get sublinear memory growth (roughly O(sqrt(n)) in the number of layers) at the cost of an extra forward pass. Best of all, PyTorch has a built-in utility to help us out. We're just going to be learning how to use that today.
We're going to look at the two different checkpointing methods that PyTorch provides and how they differ. One caveat is that measuring GPU usage is kind of a pain with PyTorch, so we're going to set things up systematically so we always get comparable numbers. Before each variant we clear the CUDA cache and reset the max-memory-allocated counter. We'll also record what nvidia-smi reports. So we have two chunks of code:
# Clear cuda memory
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
# Report memory statistics
import subprocess as sp
print(f"Allocated: {torch.cuda.memory_allocated()//2**20} MB")
print(f"Max Allocated: {torch.cuda.max_memory_allocated()//2**20} MB")
smi = int(sp.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'], encoding='utf-8').split('\n')[0])
print(f"SMI: {smi} MB")
Additionally, we will just be making a very deep network that won't fit in our GPU memory on its own, so we can be sure we are actually saving memory. Really, I just kept adding layers until I blew up my memory; there was no sane and logical methodology to this. If you're going to repeat the experiment, just do that. It doesn't really matter what the network looks like; we're just using linear layers. I'm using a 2080 Super here, which has 8 GB of VRAM.
I want to add another side note: SMI reports that my card has 7951 MiB, but the runtime error I get when oversubscribing tells me that there is 7.77 GiB of total capacity.
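If you want a third opinion, you can ask PyTorch directly what it thinks the card's capacity is. A quick sanity check (device index 0 assumed) looks like this:
# Total device memory as PyTorch sees it (bytes -> MiB)
props = torch.cuda.get_device_properties(0)
print(f"Total capacity: {props.total_memory // 2**20} MiB")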
Sequential Gradient Checkpointing (simple)
So the simplest thing to do is to use PyTorch's checkpoint_sequential. This has three arguments: the model, the number of segments, and the input (the input can be a tuple). Let's just create a dummy network and make it big. All we need to do to add checkpointing is change the forward:
from torch.utils.checkpoint import checkpoint_sequential
...
def forward(self, x):
    x = x.to(self.device)
    # Checkpointing needs an input that requires grad
    x.requires_grad = True
    # Run the model as self.num_segments checkpointed segments
    o = checkpoint_sequential(self.my_model, self.num_segments, x)
    return o
So there are a few things to note here that might look odd; we'll go line by line (inside forward). First, we send the input to the device. You don't have to do that in forward, but it is nice. Second, we make sure the input has requires_grad set; checkpointing needs this to work. Third, we use our magic line, which assumes that in __init__ you have specified your model and the number of segments you want to break it into. The best way to compose your model here is with nn.Sequential(). You can also get tripped up by trying to break your model into too many segments.
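To make that concrete, here's a rough sketch of what such a module might look like. The class name, layer sizes, and depth are made up for illustration; the parts that matter are composing the layers with nn.Sequential and calling checkpoint_sequential in forward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class DeepNet(nn.Module):
    def __init__(self, device=torch.device('cuda'), num_segments=8):
        super().__init__()
        self.device = device
        self.num_segments = num_segments
        # A big stack of Linear/ReLU pairs composed with nn.Sequential
        layers = []
        for _ in range(8):
            layers += [nn.Linear(4096, 4096), nn.ReLU()]
        self.my_model = nn.Sequential(*layers).to(device)

    def forward(self, x):
        x = x.to(self.device)
        x.requires_grad = True
        # Only the activations at segment boundaries are kept; everything
        # inside a segment is recomputed during the backward pass
        return checkpoint_sequential(self.my_model, self.num_segments, x)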
So let's look at how the number of splits affects each metric. For the run that doesn't fit, we report how much memory was requested but couldn't be allocated (the oversubscription). For the runs that do fit, we report memory usage as a percentage of the first configuration that fits (the reference).
Side Note: We are only looking at memory here. This network is small and the time differences over the two epochs we train for are negligible, so we don't report them. Do expect the computation time to increase, though.
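For reference, the measurement loop is roughly the two chunks from earlier wrapped around each run. Something like the following, where DeepNet is the hypothetical module sketched above and run_two_epochs is a placeholder for your training loop:
import subprocess as sp
import torch

for num_segments in [2, 4, 8, 10, 16]:
    # Start each run from a clean slate
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    net = DeepNet(num_segments=num_segments)   # hypothetical model from above
    run_two_epochs(net)                        # placeholder training loop

    print(f"Allocated: {torch.cuda.memory_allocated()//2**20} MB")
    print(f"Max Allocated: {torch.cuda.max_memory_allocated()//2**20} MB")
    smi = int(sp.check_output(
        ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
        encoding='utf-8').split('\n')[0])
    print(f"SMI: {smi} MB")

    del net   # free the model before the next run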
Num Splits | Alloc (MB) | Max Alloc (MB) | SMI (MB) | Mem Oversub (MiB) | Using (%) |
---|---|---|---|---|---|
0 | NA | NA | NA | 382 | NA |
2 | 148 | 6033 | 7212 | NA | reference (100) |
4 | 145 | 4350 | 6000 | NA | 83.2 |
8 | 145 | 3777 | 5428 | NA | 75.3 |
10 | 145 | 4159 | 5810 | NA | 80.6 |
16 | 145 | 4999 | 6190 | NA | 85.8 |
So we're learning some interesting things here, the biggest being that there is an optimal number of splits. This was rather surprising to me. With 8 splits we get about a 25% reduction in memory usage, but we don't just keep getting better and better as we add splits. My guess is that this comes down to the usual trade-off: more segments means more boundary activations stored, while fewer segments means larger chunks recomputed at once, so there's a sweet spot in the middle (plus some implementation overhead). Either way, we have a very useful tool here.
Gradient Checkpointing (simple)
Next we're going to look at the standard gradient checkpointing method: torch.utils.checkpoint.checkpoint. This function is also pretty easy to use; you just need the model (or sub-model) to run and its input (tuples can be used). It's really the same thing we did before, except we don't get to specify the number of segments. So why use the other one? Because here you have to manually split your network into sub-networks. Our code might look like:
from torch.utils.checkpoint import checkpoint
...
def forward(self, x):
    x = x.to(self.device)
    x.requires_grad = True
    # Checkpoint each manually defined sub-network in turn
    o = checkpoint(self.model_0, x)
    o = checkpoint(self.model_1, o)
    return o
In this case we checkpoint the two sub-networks. Let's (sorta) replicate the previous experiment with this. We will actually use the same model. I say "sorta" because we don't have real control over how checkpoint_sequential splits things up. Instead, I do my best and create four variants: 1) we checkpoint every nn.Linear and nn.ReLU individually, 2) we checkpoint every Linear/ReLU pair, 3) we break the model into 2 equal-ish sub-networks, and 4) we break it into 3 equal-ish sub-networks.
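As an illustration, here's a rough sketch of how the "Every Other" variant (checkpointing each Linear/ReLU pair) could be wired up. The class name and layer sizes are made up; the point is that each pair becomes its own little module that gets passed to checkpoint.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class PairCheckpointNet(nn.Module):
    def __init__(self, device=torch.device('cuda'), depth=8, width=4096):
        super().__init__()
        self.device = device
        # Each Linear/ReLU pair is its own sub-network
        self.pairs = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.ReLU())
            for _ in range(depth)
        ).to(device)

    def forward(self, x):
        x = x.to(self.device)
        x.requires_grad = True
        for pair in self.pairs:
            # Each call drops the pair's internal activations and
            # recomputes them during backward
            x = checkpoint(pair, x)
        return x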
Name | Alloc (MB) | Max Alloc (MB) | SMI (MB) | Using (%) |
---|---|---|---|---|
2 Subs | 148 | 6033 | 7212 | reference |
3 Subs | 145 | 6642 | 7908 | 110 |
Every Other | 83 | 3715 | 5408 | 75.0 |
Every | 38 | 4892 | 6170 | 85.6 |
Okay, so as a sanity check, our 2 sub-network version DOES match the 2-split run from checkpoint_sequential. What's super interesting is that with 3 sub-networks we actually use 10% MORE memory! For reference, the "Every" network breaks things into 16 sub-networks and "Every Other" breaks it into 8, so we get pretty close to the corresponding sequential splits.
What We’ve Learned
We've learned some interesting things. checkpoint_sequential and checkpoint don't give the same numbers, although they're close, for what we would expect to be the same networks/divisions. Back to the GPT-3 talk from the beginning: we needed 11 V100s (32 GB each), but now we only need 9! Well, 8.2, but we can't buy a fraction of a V100 :(. So we could really save around $20k here just on hardware, although we'll use more electricity. Sometimes that's easier to get money for though (especially if you're a grad student and using someone else's ;)
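To spell out the back-of-the-envelope math (this only tracks the headline numbers from this post: 175B parameters in FP16 and the ~75% peak-memory figure from the 8-split run, ignoring activations and optimizer state):
# Rough napkin math, not a real memory model
params = 175e9          # GPT-3 parameter count
bytes_per_param = 2     # FP16
gpu_gb = 32             # V100 memory

full_gb = params * bytes_per_param / 1e9     # ~350 GB
print(full_gb / gpu_gb)                      # ~10.9 -> 11 V100s
print(0.753 * full_gb / gpu_gb)              # ~8.2 -> 9 V100s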
Okay, but what about FP16?
So let's do the last set of runs with FP16. We only need a few changes in our main function:
...
# Cast the model's weights to half precision
net = model(device=torch.device('cuda')).half()
... training loop ...
# Cast the inputs to half precision as well
out = net(x.half())
...
(Again, expect computation to take longer to reach the same convergence.)
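For context, here's a minimal sketch of a training loop with those changes in place. The optimizer, loss, and loader are placeholders, and PairCheckpointNet is the hypothetical model from earlier:
import torch
import torch.nn as nn

net = PairCheckpointNet().half()             # hypothetical model, cast to FP16
opt = torch.optim.SGD(net.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for epoch in range(2):
    for x, y in loader:                      # placeholder DataLoader
        x, y = x.half().cuda(), y.half().cuda()
        opt.zero_grad()
        out = net(x)
        loss = loss_fn(out, y)
        loss.backward()
        opt.step()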
Name | Alloc (MB) | Max Alloc (MB) | SMI (MB) | Using (%) | Using vs FP32 Reference (%) |
---|---|---|---|---|---|
2 Subs | 75 | 3015 | 4054 | reference | 56.2 |
3 Subs | 74 | 3318 | 4414 | 108.9 | 61.2 |
Every Other | 42 | 1856 | 2954 | 72.9 | 41.0 |
Every | 18 | 2443 | 3530 | 87.1 | 48.9 |
WOW, we got some big savings. Not only did the original network not fit on our graphics card at all, but the "Every Other" FP16 version uses only 41% of the memory of the FP32 configuration that COULD fit! Compared to what the original, uncheckpointed network would have needed, we're probably at slightly less than 40%. This can be a huge savings. Unfortunately we won't get to reduce to 6 V100s, because we were already assuming FP16 above, but we needed to know this to save all that money!
All the code for this tutorial can be found on GitHub.