How slow is the tracing interpreter of PyPy's meta-tracing JIT?
I wanted to investigate the warmup behavior of the PyPy interpreter, so I wrote a somewhat arbitrary microbenchmark:
```
import sys

all_results = set()
num = int(sys.argv[1])

class A(object):
    pass

def main():
    res = 0
    for i in range(num):
        a = A()
        a.x = i
        d = {"a": a.x}
        l = [0, 1, d["a"]]
        res += l[-1]
    all_results.add(res)
```
This function is the ideal case for PyPy's JIT compiler, as it has a tight loop with many object allocations that all have predictable lifetimes. There is no control flow inside the loop. The JIT compiler can optimize this loop body extremely well, to just an integer addition.
However, here I'm interested in the warmup behavior of the JIT compiler. PyPy's tracing JIT will trace this loop at its 1041st iteration. Tracing a loop is rather expensive in PyPy, due to its meta-tracing architecture: while the loop is being traced, there is actually a stack of two interpreters, with the meta-tracing interpreter executing the Python interpreter, which in turn executes one iteration of the Python loop. The meta-tracing interpreter doesn't just execute the loop, it also keeps a log of all the operations that were executed by PyPy's Python interpreter. This log – the trace – is then optimized and compiled to machine code. From then on, that chunk of machine code is executed instead of interpreting the loop body.
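As an aside, the iteration count at which tracing kicks in is just a tunable JIT parameter. If you want to experiment with it yourself, here is a rough sketch (the script name `benchmark.py` is a placeholder; the available parameter names can be listed with `--jit help`):

```
# list PyPy's tunable JIT parameters (threshold, trace_limit, ...)
pypy --jit help

# make the JIT trace loops after far fewer iterations than the default
pypy --jit threshold=100 benchmark.py 1000
```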
What really interests me here is just how costly the double-interpretation overhead of the meta-tracing interpreter is. I had no intuition about the size of this slowdown factor: is it 10x, 100x, 1,000x, 10,000x, 100,000x? My goal is to understand its order of magnitude.
All the numbers in this blog post were measured with PyPy 2.7.18 (because I had a nicely instrumented version of PyPy2 around); I expect the results to be similar (i.e. at least the same order of magnitude) on PyPy 3.11. Timing results are from my AMD Ryzen 7 PRO 7840U, running Ubuntu Linux 24.04.2.
Estimating the slowdown factor of the meta-tracing interpreter
To estimate the slowdown factor we can run this benchmark with `num` set to exactly 1041. Since the JIT will trace the 1041st iteration, this is the maximally bad case for the JIT: the loop is traced, and the meta-tracing interpreter executes the Python interpreter for one iteration of the loop, which is exactly what we want to measure. Since we never execute another iteration afterwards, the compiled machine code is never run, so we are measuring only the overhead of the meta-tracing interpreter. If we set the environment variable `PYPYLOG=jit-summary:outputfilename`, we can see how much time the meta-tracing interpreter spent tracing the loops.
To make the measured times slightly larger and to reduce measurement noise, I actually generate and run 100 copies of the above main function:
```
from __future__ import print_function
import sys

all_funcs = []
all_results = set()

s = """
class A%s(object):
    pass

def main%s():
    res = 0
    for i in range(num):
        a = A%s()
        a.x = i
        d = {"a": a.x}
        l = [0, 1, d["a"]]
        res += l[-1]
    all_results.add(res)

all_funcs.append(main%s)
"""

fullstr = "\n".join(s % (i, i, i, i) for i in range(100))
exec(fullstr)

import time
num = int(sys.argv[1])
t1 = time.time()
for fn in all_funcs:
    fn()
t2 = time.time()
print(num, t2 - t1)
```
Now we can run this benchmark 100 times with bash to get a good average across many process executions:
for i in {0..100..1}; do PYPYLOG=jit-summary:out%d pypy -S x.py 1041 >> times; done
Afterwards, we can look at the `out%d` files to see how much time the meta-tracing interpreter spent tracing the loops (`%d` is expanded to the process id by the PyPy logging system). The content of these files will look roughly like this:
```
[c8ca76f3db881] {jit-summary
Tracing:            100   0.033022
Backend:            100   0.009439
TOTAL:                    0.072409
ops:                      66200
heapcached ops:           20200
recorded ops:             11200
  calls:                  2000
guards:                   2200
opt ops:                  9700
opt guards:               1800
opt guards shared:        1300
forcings:                 0
abort: trace too long:    0
abort: compiling:         0
abort: vable escape:      0
abort: bad loop:          0
abort: force quasi-immut: 0
abort: segmenting trace:  0
virtualizables forced:    0
nvirtuals:                2500
nvholes:                  0
nvreused:                 1100
vecopt tried:             0
vecopt success:           0
Total # of loops:         100
Total # of bridges:       0
Freed # of loops:         0
Freed # of bridges:       0
[c8ca76f3e9618] jit-summary}
```
Let's extract the tracing times with grep:
grep Tracing out* | grep -o "0[.].*" > tracingtimes
Now we can compute the slowdown factor like this:
```
>>> with open('tracingtimes') as f:
...     f = [float(l) for l in f]
...     tracingtime = sum(f) / len(f)
...
>>> with open('times') as f:
...     f = [float(l.split()[1]) for l in f]
...     exectime = sum(f) / len(f)
...
>>> regulartime = (exectime - tracingtime) / 1040
>>> print(regulartime)
3.550879183375572e-05
>>> print(tracingtime)
0.03153696039603962
>>> slowdown = tracingtime / regulartime
>>> print(slowdown)
888.1451259645404
```
(Random side-note: I want to complain about the fact that the pygments lexer does not support the `>>>>` syntax of PyPy's interactive interpreter.)
This means that the meta-tracing interpreter is roughly 900x slower than the regular interpreter for this benchmark. (Note that the tracing time alone is not the full warmup time of the JIT, since it doesn't include the time spent optimizing the trace and producing machine code.)
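To get a rough sense of the full per-loop JIT compilation cost, we can also include the backend (machine code generation) time reported in the jit-summary output above. A small back-of-the-envelope sketch using the numbers from the example log shown earlier (your own log files will show slightly different values):

```
# numbers from the "Tracing:" and "Backend:" lines of the jit-summary example above,
# each covering all 100 traced loops of one process
tracing_time = 0.033022  # seconds
backend_time = 0.009439  # seconds

per_loop = (tracing_time + backend_time) / 100
print(per_loop * 1000, "ms of JIT compilation per traced loop")  # ~0.42 ms
```

So the total compilation cost of all 100 loop copies is on the order of 40 ms, in the same ballpark as the warmup cost we'll see in the plots below.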
How many instructions does the meta-tracing interpreter execute?
Let's understand the work the Python interpreter does in one iteration of the loop. We can disassemble one of the `main` functions to see what it does:
```
>>> import dis
>>> dis.dis(main0)
  6           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (res)

  7           6 SETUP_LOOP              87 (to 96)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_GLOBAL              1 (num)
             15 CALL_FUNCTION            1
             18 GET_ITER
        >>   19 FOR_ITER                73 (to 95)
             22 STORE_FAST               1 (i)

  8          25 LOAD_GLOBAL              2 (A0)
             28 CALL_FUNCTION            0
             31 STORE_FAST               2 (a)

  9          34 LOAD_FAST                1 (i)
             37 LOAD_FAST                2 (a)
             40 STORE_ATTR               3 (x)

 10          43 BUILD_MAP                0
             46 LOAD_FAST                2 (a)
             49 LOAD_ATTR                3 (x)
             52 LOAD_CONST               2 ('a')
             55 STORE_MAP
             56 STORE_FAST               3 (d)

 11          59 LOAD_CONST               1 (0)
             62 LOAD_CONST               3 (1)
             65 LOAD_FAST                3 (d)
             68 LOAD_CONST               2 ('a')
             71 BINARY_SUBSCR
             72 BUILD_LIST               3
             75 STORE_FAST               4 (l)

 12          78 LOAD_FAST                0 (res)
             81 LOAD_FAST                4 (l)
             84 LOAD_CONST               4 (-1)
             87 BINARY_SUBSCR
             88 INPLACE_ADD
             89 STORE_FAST               0 (res)
             92 JUMP_ABSOLUTE           19
        >>   95 POP_BLOCK

 13     >>   96 LOAD_GLOBAL              4 (all_results)
             99 LOOKUP_METHOD            5 (add)
            102 LOAD_FAST                0 (res)
            105 CALL_METHOD              1
            108 POP_TOP
            109 LOAD_CONST               0 (None)
            112 RETURN_VALUE
```
The loop body starts at bytecode offset 19 and ends just before offset 95. It consists of 28 Python bytecode instructions per iteration (which matches the 28 calls to `dispatch_bytecode` in the table below).
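If you want to double-check that count programmatically, here's a small sketch that walks the Python 2 bytecode of `main0` and counts the instructions in the loop's offset range. It assumes the Python 2 bytecode encoding used by PyPy2, ignores EXTENDED_ARG (which doesn't occur here), and has to run in the same process that defined `main0` via the exec above:

```
import dis

def count_ops(code, start, stop):
    # count the bytecode instructions of `code` whose offset lies in [start, stop]
    count = 0
    i = 0
    bytecode = code.co_code  # a str in Python 2
    while i < len(bytecode):
        if start <= i <= stop:
            count += 1
        op = ord(bytecode[i])
        # opcodes >= HAVE_ARGUMENT take a two-byte argument in Python 2
        i += 3 if op >= dis.HAVE_ARGUMENT else 1
    return count

# offsets 19 (FOR_ITER) through 92 (JUMP_ABSOLUTE) form the loop body
print(count_ops(main0.__code__, 19, 92))  # -> 28
```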
When the JIT compiler traces this loop, the Python interpreter executes one iteration of the loop while itself running on top of the meta-tracing interpreter. I built a special version of PyPy that counts the number of opcodes executed by the meta-tracing interpreter while tracing. Here are the results:
| function name in Python interpreter | number of operations executed | number of times called | number of operations traced |
|---|---|---|---|
| `dispatch_bytecode` | 864 | 28 | 0 |
| `PyFrame.dispatch` | 708 | 1 | 0 |
| `W_TypeObject.lookup_where_with_method_cache` | 142 | 6 | 0 |
| `handle_bytecode__AccessDirect_None` | 112 | 28 | 0 |
| `list_BINARY_SUBSCR__AccessDirect_None` | 79 | 2 | 0 |
| `W_TypeObject.descr_call` | 77 | 1 | 3 |
| `_init_from_list_w_helper` | 66 | 1 | 12 |
| `W_TypeObject.issubtype` | 61 | 2 | 0 |
| `popvalues__AccessDirect_None` | 54 | 1 | 3 |
| `getattr` | 45 | 1 | 0 |
| `setattr` | 43 | 1 | 0 |
| `Function.call_obj_args` | 42 | 3 | 1 |
| `int_INPLACE_ADD` | 41 | 1 | 4 |
| `jump_absolute` | 37 | 1 | 2 |
| `BytesDictStrategy.setitem` | 35 | 1 | 0 |
| `STORE_MAP` | 35 | 1 | 0 |
| `_reorder_and_add` | 34 | 1 | 0 |
| `call_args` | 34 | 1 | 0 |
| `ll_dict_lookup` | 33 | 2 | 5 |
| `setitem` | 32 | 2 | 0 |
| ... | ... | ... | ... |
| total | 3675 | 173 | 103 |
The names of these functions aren't too important; what matters is that the meta-tracing interpreter executes 3675 operations to trace one iteration of the loop in `main`, which is a lot. It also records 103 trace operations, most of which are then removed again by the trace optimizer, leaving a trace with just 22 operations:
```
label(p0, p1, p6, p7, i45, i32, p20, p31, p35, i50, i36)
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #19 FOR_ITER')
i52 = int_lt(i50, 0)
guard_false(i52)
i53 = int_ge(i50, i36)
guard_false(i53)
i55 = int_add(i50, 1)
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #25 LOAD_GLOBAL')
# update iterator object
setfield_gc(p20, i55, descr=<FieldS pypy.objspace.std.iterobject.W_AbstractSeqIterObject.inst_index 8>)
guard_not_invalidated()
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #88 INPLACE_ADD')
# perform the addition, checking for machine integer overflow
i56 = int_add_ovf(i45, i50)
guard_no_overflow(descr=<Guard0x729830a52170>) [p0, p6, p7, p20, p1, i50, i32, i45]
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #92 JUMP_ABSOLUTE')
# check whether ctrl-c was pressed
i58 = getfield_raw_i(125998063394752, descr=<FieldS pypysig_long_struct.c_value 0>)
i60 = int_lt(i58, 0)
guard_false(i60)
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #19 FOR_ITER')
jump(p0, p1, p6, p7, i56, i50, p20, p31, p35, i55, i36)
```
Note how, in particular, no allocations are left in this trace. The optimizer recognized that neither the `A` instance, the dictionary, nor the list needs to be allocated.
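If you want to look at the optimized traces of your own programs, the same PYPYLOG mechanism can be asked for them. A sketch, assuming the usual jit-log-opt/jit-backend log categories (the exact section markers in the log file may vary between PyPy versions):

```
# dump optimized traces and backend info into trace.log
PYPYLOG=jit-log-opt,jit-backend:trace.log pypy -S x.py 10000

# the optimized loops appear in sections delimited by jit-log-opt-loop markers
grep -n "jit-log-opt-loop" trace.log
```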
How long does it take for the JIT compilation time to be amortized?
To find out after how many iterations the JIT compilation time is amortized, we can run the benchmark with varying `num` values and measure the execution time, both with and without the JIT. We do this by starting many processes using more bash `for` loops. We can then plot the execution time against the number of iterations to see when the time with the JIT becomes lower than the time without the JIT.
```
for i in {0..20000..1}; do pypy -S x.py "$i" >> datajit ; done
for i in {0..20000..1}; do pypy --jit off -S x.py "$i" >> datanojit ; done
```
And we can also run the same script on CPython, while we are at it:
for i in {1..10000..1}; do python3.13 -S x.py "$i" >> datacpy ; done;
(This is a bit of an apples-to-oranges comparison, because I am comparing PyPy2 against CPython 3.13, but oh well.)
Both will run for a while and then we can plot the data.
Here are the gory matplotlib details of the plotting script:
```
# Co-created by Copilot
import matplotlib.pyplot as plt
import numpy as np

THRESHOLD = 1041  # JIT threshold

def plot_trace_costs(maximum, save_as, baseline_filename='datanojit', baseline_legend='PyPy No JIT'):
    with open('datajit', 'r') as f:
        datajit = [tuple(map(float, line.strip().split())) for line in f.readlines()]
    with open(baseline_filename, 'r') as f:
        datanojit = [tuple(map(float, line.strip().split())) for line in f.readlines()]
    # scale the y values by 1000 (seconds -> milliseconds)
    datajit = [(x, y * 1000) for x, y in datajit]
    datanojit = [(x, y * 1000) for x, y in datanojit]
    origdatajit = datajit.copy()
    # sort the data by x value
    datajit.sort(key=lambda x: x[0])
    datanojit.sort(key=lambda x: x[0])
    # filter out lines with x value greater than maximum
    datajit = [x for x in datajit if x[0] <= maximum]
    datanojit = [x for x in datanojit if x[0] <= maximum]

    # create new figure
    plt.figure(figsize=(10, 6))
    plt.clf()
    # unpack data
    jit_x, jit_y = zip(*datajit) if datajit else ([], [])
    nojit_x, nojit_y = zip(*datanojit) if datanojit else ([], [])

    # plot the data points, using a scatter plot. use circle size of 2
    plt.scatter(jit_x, jit_y, label='PyPy w/ JIT', alpha=0.3, s=1)
    plt.scatter(nojit_x, nojit_y, label=baseline_legend, alpha=0.3, s=1)

    # add thin vertical line at x=THRESHOLD
    plt.axvline(x=THRESHOLD, color='gray', linestyle='--', linewidth=1)
    # add text to the left of the line, near the top saying "JIT threshold"
    all_y = list(jit_y) + list(nojit_y)
    if all_y:
        plt.text(THRESHOLD, max(all_y) * 0.9, 'JIT threshold', color='gray', fontsize=10, ha='left')

    # fit a line to the JIT data, starting from x=THRESHOLD
    jit_x_fit = [x for x, _ in origdatajit if x >= THRESHOLD]
    jit_y_fit = [y for x, y in origdatajit if x >= THRESHOLD]
    assert len(jit_x_fit) == len(jit_y_fit), "JIT x and y data lengths do not match"
    assert jit_x_fit and jit_y_fit
    jit_fit = np.polyfit(jit_x_fit, jit_y_fit, 1)
    jit_fit_line = np.polyval(jit_fit, jit_x)
    plt.plot(jit_x, jit_fit_line, color='black', linestyle='--', linewidth=1)

    assert nojit_x and nojit_y
    nojit_x_fit = [x for x in nojit_x if x >= THRESHOLD]
    nojit_y_fit = [y for x, y in zip(nojit_x, nojit_y) if x >= THRESHOLD]
    assert nojit_x_fit and nojit_y_fit
    nojit_fit = np.polyfit(nojit_x_fit, nojit_y_fit, 1)
    nojit_fit_line = np.polyval(nojit_fit, nojit_x)
    plt.plot(nojit_x, nojit_fit_line, color='black', linestyle='--', linewidth=1)

    # compute intersection point of the two lines
    intersection_x = (nojit_fit[1] - jit_fit[1]) / (jit_fit[0] - nojit_fit[0])
    intersection_y = jit_fit[0] * intersection_x + jit_fit[1]
    # also compute the speedup of the JIT and No JIT fits by dividing the slopes
    jit_speedup = 1 / (jit_fit[0] / nojit_fit[0])

    # add text box to the upper right corner with the fit parameters
    fit_text = f'PyPy w/ JIT fit: y = {jit_fit[0]:.6f}x + {jit_fit[1]:.6f}\n{baseline_legend} fit: y = {nojit_fit[0]:.6f}x + {nojit_fit[1]:.6f}'
    speedup_text = f'Speedup: {jit_speedup:.2f}x\nIntersection: ({intersection_x:.2f}, {intersection_y:.2f})'
    fit_text += "\n" + speedup_text
    plt.text(0.95, 0.95, fit_text, transform=plt.gca().transAxes,
             fontsize=10, verticalalignment='top', horizontalalignment='right',
             bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))

    # compute the y-values of both fits at the JIT threshold
    jit_y_at_threshold = jit_fit[0] * THRESHOLD + jit_fit[1]
    nojit_y_at_threshold = nojit_fit[0] * THRESHOLD + nojit_fit[1]
    diff = jit_y_at_threshold - nojit_y_at_threshold
    diff_text = f'Difference at threshold: {diff:.2f} ms'
    print(diff_text)

    plt.xlabel('Iterations')
    plt.ylabel('time taken\n(milliseconds)')  # y values were scaled to ms above
    plt.grid()
    plt.title(f'Warmup behaviour: PyPy w/ JIT vs {baseline_legend}')
    # legend on the top left corner
    plt.legend(loc='upper left')
    plt.savefig(save_as)

if __name__ == '__main__':
    plot_trace_costs(maximum=20000, save_as='trace_costs_start.svg')
    plot_trace_costs(maximum=2000, save_as='trace_costs_very_start.svg')
    plot_trace_costs(maximum=10000, save_as='trace_costs_cpy.svg',
                     baseline_filename='datacpy', baseline_legend='CPython 3.13')
```
This is what the plot looks like:
The plot shows that the JIT compilation time is amortized after about 2600 iterations. The JIT compilation time is about 39 ms (which consists of the meta-tracing time plus the code generation time). If the program runs for a very long time, the speedup of the JIT approaches more than 200x (but the example was picked to make things easy for the JIT, so it's not indicative of the performance of realistic Python code).
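As a rough sanity check on the break-even point, we can combine the numbers measured earlier. A back-of-the-envelope sketch, where the assumed 200x speedup is the asymptotic figure from the plot and all times are per process (i.e. for all 100 function copies):

```
compile_time = 0.039        # ~39 ms total JIT compilation time (tracing + codegen)
interp_per_iter = 3.55e-05  # ~35.5 us per loop iteration in the regular interpreter
jit_speedup = 200           # assumed asymptotic speedup of the compiled loop
jit_per_iter = interp_per_iter / jit_speedup

# extra iterations after the threshold needed until compilation has paid for itself
extra = compile_time / (interp_per_iter - jit_per_iter)
print(1041 + extra)  # ~2100-2200, the same ballpark as the ~2600 read off the plot
```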
We can zoom in on the first 2000 data points:
And here's the comparison with CPython 3.13:
Compared to CPython the speedup is ~160x, and PyPy catches up with CPython only at around 4800 iterations. I don't know why CPython's performance is so much noisier across the 10,000 iterations.
Conclusion
After these quick measurements we conclude that the meta-tracing interpreter is roughly three orders of magnitude (about 900x on this microbenchmark) slower than the regular interpreter. We (Yusuke Izawa and I) plan to improve the performance of the meta-tracing interpreter in the future, and I look forward to seeing how the numbers change then.
Acknowledgements
Thanks to Yusuke Izawa for his work on instrumenting the meta-tracing interpreter, and for our previous collaboration.