Published on April 29, 2024 by Jérémy Scanvic
Writing slow code might very well be the easiest part of doing research. Fortunately, slow code often turns out to be harmless, either because it rarely needs to run or because, as slow as it is compared to how fast it could be, it is still fast compared to the surrounding code. But sometimes slow code is a big deal, and figuring out which piece of code is the problem can be as painful as debugging a messy code base (if not worse). Luckily, profilers can make our lives easier when used properly.
In this post, I'm going to walk you through how I used a profiler to speed up a Python program considerably, by about 55 seconds per iteration, explaining my thought process along the way.
Before we delve into the details, let's give a very simple definition of what a profiler is. A profiler is a tool which collects statistics about a program as it's running, including how many times each function is called and how long each call takes to complete. Profilers are valuable because they let us find which functions impact the overall performance the most, so we can spend our time optimizing those instead of wasting it making the others faster.
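To make this concrete, here is a minimal sketch of what profiling looks like on a toy function (slow_sum is just a made-up stand-in for real work): running it prints, among other things, how many times each function was called and how long those calls took.
import cProfile

def slow_sum(n):
    # a deliberately naive stand-in for real work
    total = 0
    for i in range(n):
        total += i
    return total

# profile the call and print the collected statistics
cProfile.run("slow_sum(10**6)")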
The needle in a haystack
The first thing to do is to determine what we actually want to speed up. In my case, it’s a program written in Python for training deep neural networks.
Using the standard Python profiler cProfile, I’m able to run the program and obtain detailed statistics about its execution.
python -m cProfile -o stats.bin train.py
The statistics are written into a binary file (stats.bin) and cannot be read directly. In order to display them in a human-readable format, we can use a simple Python snippet using the standard module pstats which, as its name suggests, is made for dealing with profiling statistics.
import pstats
from pstats import SortKey
p = pstats.Stats('stats.bin')
p.sort_stats(SortKey.CUMULATIVE)
p.print_stats()
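In practice, the full table can be overwhelming. A small variation on the snippet above strips the long directory prefixes from the file names and only prints the first rows; both are standard pstats features.
import pstats
from pstats import SortKey

p = pstats.Stats('stats.bin')
p.strip_dirs()                    # remove the long directory prefixes from file names
p.sort_stats(SortKey.CUMULATIVE)  # most (cumulatively) expensive functions first
p.print_stats(20)                 # only print the first 20 rows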
After running the snippet, the first thing I do is look for the total execution time. In this case, it's about 325 seconds.
9344245 function calls (9094709 primitive calls) in 325.736 seconds
Now, this is where it's important not to get distracted. Right below this line is an extremely long table listing every function that was called, how many times it was called, and how long it took in total. Crucially, the next thing I do is not to read the table directly, but to have a quick glance at the source code first, looking for the functions that should be expensive to call.
for epoch in range(epochs):
    for x, y in dataloader:
        optimizer.zero_grad()
        x_hat = model(y)
        loss = loss_fn(x_hat, x)
        loss.backward()
        optimizer.step()
    scheduler.step()
In my case, the script mostly consists of a training loop which is normally set to repeat for 500 iterations, but for the sake of profiling I've set it to 5 iterations. Inside the loop, I expect most of the time to be spent computing the forward pass (model(y)), the backward pass (loss.backward()) and loading the training pairs.
ncalls tottime percall cumtime percall filename:lineno(function)
500 0.011 0.000 3.020 0.006 model.py:39(forward)
500 0.004 0.000 2.828 0.006 torch/_tensor.py:463(backward)
Weirdly enough, only about 6 seconds are spent computing the forward pass and the backward pass in total. That's only 2% of the total execution time! It's definitely not what I was expecting. This is where the profiler comes in handy: if I hadn't used it, I might have spent considerable time optimizing the deep neural network, and it wouldn't have made the program any faster.
ncalls tottime percall cumtime percall filename:lineno(function)
4000 6.721 0.002 266.710 0.067 dataset.py:26(__getitem__)
As it turns out, a disproportionate amount of time is spent loading training pairs: a whopping 265 seconds in total, or about 80% of the total execution time. And this is only for 5 epochs; for a normal run with 500 epochs, it'd take more than 7 hours! We've found the bottleneck, and it's a very tight one.
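Before trying to fix anything, it can also be worth checking where the expensive function is being called from. pstats can restrict its output with a regular expression and list the callers of any matching function, along these lines:
import pstats
from pstats import SortKey

p = pstats.Stats('stats.bin')
p.sort_stats(SortKey.CUMULATIVE)
# list every function that calls __getitem__, with per-caller statistics
p.print_callers('__getitem__')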
A low-hanging fruit
Arguably, loading the training pairs once shouldn't take 55 seconds in the first place, and I could try to optimize the initial loading procedure. But for now, I'm going to take the pragmatic approach and grab the low-hanging fruit. Without going too much into the details, I implemented memoization on top of the loading procedure. What this means is that loading the training pairs takes as long as before the first time, but they are then stored in memory, so the next time they need to be loaded, it's near instantaneous.
from functools import wraps

class Dataset:
    @staticmethod
    def _memoize(f):
        # cache shared by every call to the decorated function
        cache = {}

        @wraps(f)
        def wrapper(*args, **kwargs):
            key = (args, frozenset(kwargs.items()))
            if key not in cache:
                # only call the slow function on a cache miss
                cache[key] = f(*args, **kwargs)
            return cache[key]

        return wrapper

    # applying a staticmethod like this inside the class body requires Python 3.10+
    @_memoize
    def __getitem__(self, index):
        # a very slow piece of code
        ...
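As a side note, since the only argument here is an integer index, a very similar effect can likely be achieved with the standard library's functools.lru_cache instead of a hand-rolled decorator. Here's a minimal sketch, with the caveat that the cache then also keeps a reference to each dataset instance alive.
from functools import lru_cache

class Dataset:
    @lru_cache(maxsize=None)  # cache keyed on (self, index); grows without bound
    def __getitem__(self, index):
        # the same very slow piece of code as before
        ...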
The most important step
The final and perhaps most important step is to make sure the program is actually running faster now that we’ve optimized it. To do so, we can run the profiler once more.
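Concretely, this just means re-running the same command as before. Writing the new statistics to a separate file (the name below is arbitrary) makes it easy to keep the old ones around for comparison.
python -m cProfile -o stats_memoized.bin train.py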
9025141 function calls (8775585 primitive calls) in 98.230 seconds
ncalls tottime percall cumtime percall filename:lineno(function)
800 1.935 0.002 55.233 0.069 dataset.py:26(__getitem__)
As expected, loading the training pairs is about 5 times faster this time, taking only about 55 seconds in total. This makes the whole program complete in less than 100 seconds, roughly 3 times quicker than it used to be.
ncalls tottime percall cumtime percall filename:lineno(function)
500 0.009 0.000 2.794 0.006 model.py:39(forward)
500 0.004 0.000 3.453 0.007 torch/_tensor.py:463(backward)
And as a sanity check, we can verify that the forward pass and the backward pass take about as long as they used to (around 6 seconds in total).
Conclusion
We’ve seen how a profiler can be used to speed up a training script by about 55 seconds per iteration, adding only a few lines of code and spending very little time investigating the code. This shows that profilers can be extremely powerful tools in the hands of pragmatic minds.