How slow is the tracing interpreter of PyPy's meta-tracing JIT?

I wanted to investigate the warmup behavior of the PyPy interpreter, so I wrote a somewhat arbitrary microbenchmark:

import sys

all_results = set()
num = int(sys.argv[1])

class A(object):
    pass

def main():
    res = 0
    for i in range(num):
        a = A()
        a.x = i
        d = {"a": a.x}
        l = [0, 1, d["a"]]
        res += l[-1]
    all_results.add(res)

main()

This function is the ideal case for PyPy's JIT compiler, as it has a tight loop with many object allocations that all have predictable lifetimes. There is no control flow inside the loop. The JIT compiler can optimize this loop body extremely well, to just an integer addition.

However, what I really wanted to investigate is the warmup behavior of the JIT compiler. PyPy's tracing JIT will trace this loop at its 1041st iteration. Tracing a loop is rather expensive in PyPy, due to its meta-tracing architecture. While the loop is being traced, there is actually a stack of two interpreters: the meta-tracing interpreter executes the Python interpreter, which in turn executes one iteration of the Python loop. The meta-tracing interpreter not only executes the operations of PyPy's Python interpreter, it also keeps a log of all of them. This log – the trace – is then optimized and compiled to machine code. From then on, that chunk of machine code is executed instead of interpreting the loop body.
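To make the double interpretation a bit more concrete, here is a toy sketch of what "execute and record at the same time" means. This is in no way PyPy's actual code (all names are made up for illustration); it only shows the principle that every operation is logged in addition to being executed:

def run(ops, env, tracing=False):
    # interpret a list of (opname, arg, ...) tuples; when tracing, every
    # operation is appended to a log *in addition to* being executed
    trace = []
    for op in ops:
        name, args = op[0], op[1:]
        if tracing:
            trace.append(op)              # record the operation...
        if name == 'const':
            env[args[0]] = args[1]        # ...and still execute it
        elif name == 'add':
            env[args[0]] = env[args[1]] + env[args[2]]
    return trace

env = {}
logged = run([('const', 'x', 1), ('add', 'y', 'x', 'x')], env, tracing=True)
print(env['y'], len(logged))   # prints: 2 2

In PyPy's case, the operations being run and logged are those of the Python interpreter itself, which is why a single iteration of the Python-level loop involves thousands of executed operations, as we will count below.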

What really interests me here is just how costly this double interpretation is. I had no intuition about the slowdown factor: is it 10x, 100x, 1,000x, 10,000x, 100,000x? My goal is to understand the order of magnitude of that slowdown factor.

All the numbers in this blog post were measured on PyPy 2.7.18 (because I had a nicely instrumented version of PyPy2 around); I expect the results to be similar (i.e. at least the same order of magnitude) on PyPy 3.11. Timing results are from my AMD Ryzen 7 PRO 7840U, running Ubuntu Linux 24.04.2.

Estimating the slowdown factor of the meta-tracing interpreter

To estimate the slowdown factor we can run this benchmark with num set to exactly 1041. Since the JIT traces the 1041st iteration, this is the maximally bad case for the JIT: the loop gets traced, meaning the meta-tracing interpreter executes the Python interpreter for one iteration of the loop, but since there is no further iteration, the compiled machine code is never actually run. We pay the tracing overhead without getting any benefit, which is exactly the cost we want to measure. If we set the environment variable PYPYLOG=jit-summary:outputfilename, we can see how much time the meta-tracing interpreter spent tracing the loops.

To make the measured times somewhat larger and to reduce measurement noise, I actually generate and run 100 copies of the above main function.

Click to expand!
from __future__ import print_function
import sys

all_funcs = []
all_results = set()
s = """
class A%s(object):
    pass

def main%s():
    res = 0
    for i in range(num):
        a = A%s()
        a.x = i
        d = {"a": a.x}
        l = [0, 1, d["a"]]
        res += l[-1]
    all_results.add(res)
all_funcs.append(main%s)
"""
fullstr = "\n".join(s % (i, i, i, i) for i in range(100))
exec(fullstr)

import time

num = int(sys.argv[1])

t1 = time.time()
for fn in all_funcs:
    fn()
t2 = time.time()
print(num, t2 - t1)

Now we can run this benchmark 100 times with bash, to get a good average across many process executions:

for i in {1..100};
    do PYPYLOG=jit-summary:out%d pypy -S x.py 1041 >> times;
done

Afterwards, we can look at the out%d files to see how much time the meta-tracing interpreter spent tracing the loops (%d is expanded to the process id by the PyPy logging system). The content of these files will look roughly like this:

[c8ca76f3db881] {jit-summary
Tracing:        100     0.033022
Backend:        100     0.009439
TOTAL:                  0.072409
ops:                    66200
heapcached ops:         20200
recorded ops:           11200
  calls:                2000
guards:                 2200
opt ops:                9700
opt guards:             1800
opt guards shared:      1300
forcings:               0
abort: trace too long:  0
abort: compiling:       0
abort: vable escape:    0
abort: bad loop:        0
abort: force quasi-immut:       0
abort: segmenting trace:        0
virtualizables forced:  0
nvirtuals:              2500
nvholes:                0
nvreused:               1100
vecopt tried:           0
vecopt success:         0
Total # of loops:       100
Total # of bridges:     0
Freed # of loops:       0
Freed # of bridges:     0
[c8ca76f3e9618] jit-summary}

Let's extract the tracing times with grep:

grep Tracing out* | grep -o "0[.].*" > tracingtimes

Now we can compute the slowdown factor. exectime covers 100 functions running 1041 iterations each, while tracingtime covers only the 100 traced iterations (one per function). Subtracting tracingtime and dividing by 1040 therefore gives the time the regular interpreter needs for the equivalent of 100 single iterations, which is directly comparable to tracingtime:

>>> with open('tracingtimes') as f:
...     f = [float(l) for l in f]
...     tracingtime = sum(f) / len(f)
...     
>>> with open('times') as f:
...     f = [float(l.split()[1]) for l in f]
...     exectime = sum(f) / len(f)
...     
>>> regulartime = (exectime - tracingtime) / 1040
>>> print(regulartime)
3.550879183375572e-05
>>> print(tracingtime)
0.03153696039603962
>>> slowdown = tracingtime / regulartime
>>> print(slowdown)
888.1451259645404

(random side-note: I want to complain about the fact that the pygments lexer does not support the >>>> syntax of PyPy's interactive interpreter.)

This means that the meta-tracing interpreter is roughly 900x slower than the regular interpreter for this benchmark. (Note that the tracing time alone is not the full warmup time of the JIT, since it doesn't include the time spent optimizing the trace and producing machine code.)
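The jit-summary output also reports the backend time (generating machine code from the optimized trace), so we can sketch a slightly fuller picture of the warmup cost by averaging both lines across the out* files written by the loop above:

import glob

# average the Tracing and Backend times over all out* files written by the
# bash loop above; the last column of those lines is the time in seconds
tracing, backend = [], []
for fname in glob.glob('out*'):
    with open(fname) as f:
        for line in f:
            if line.startswith('Tracing:'):
                tracing.append(float(line.split()[-1]))
            elif line.startswith('Backend:'):
                backend.append(float(line.split()[-1]))

print('mean tracing time:', sum(tracing) / len(tracing))
print('mean backend time:', sum(backend) / len(backend))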

How many instructions does the meta-tracing interpreter execute?

Let's understand the work the Python interpreter does in one iteration of the loop. We can disassemble one of the main functions to see what it does:

>>> import dis
>>> dis.dis(main0)
  6           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (res)

  7           6 SETUP_LOOP              87 (to 96)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_GLOBAL              1 (num)
             15 CALL_FUNCTION            1
             18 GET_ITER            
        >>   19 FOR_ITER                73 (to 95)
             22 STORE_FAST               1 (i)

  8          25 LOAD_GLOBAL              2 (A0)
             28 CALL_FUNCTION            0
             31 STORE_FAST               2 (a)

  9          34 LOAD_FAST                1 (i)
             37 LOAD_FAST                2 (a)
             40 STORE_ATTR               3 (x)

 10          43 BUILD_MAP                0
             46 LOAD_FAST                2 (a)
             49 LOAD_ATTR                3 (x)
             52 LOAD_CONST               2 ('a')
             55 STORE_MAP           
             56 STORE_FAST               3 (d)

 11          59 LOAD_CONST               1 (0)
             62 LOAD_CONST               3 (1)
             65 LOAD_FAST                3 (d)
             68 LOAD_CONST               2 ('a')
             71 BINARY_SUBSCR       
             72 BUILD_LIST               3
             75 STORE_FAST               4 (l)

 12          78 LOAD_FAST                0 (res)
             81 LOAD_FAST                4 (l)
             84 LOAD_CONST               4 (-1)
             87 BINARY_SUBSCR       
             88 INPLACE_ADD         
             89 STORE_FAST               0 (res)
             92 JUMP_ABSOLUTE           19
        >>   95 POP_BLOCK           

 13     >>   96 LOAD_GLOBAL              4 (all_results)
             99 LOOKUP_METHOD            5 (add)
            102 LOAD_FAST                0 (res)
            105 CALL_METHOD              1
            108 POP_TOP             
            109 LOAD_CONST               0 (None)
            112 RETURN_VALUE  

The loop body starts with the FOR_ITER at bytecode offset 19 and ends with the JUMP_ABSOLUTE at offset 92 (offset 95 is the first instruction after the loop). That's 28 Python bytecode instructions per iteration, matching the 28 calls to dispatch_bytecode in the table below.
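As an aside, we don't have to count the instructions by hand. Something like the following works with the dis module on CPython 3 (a sketch: the exact opcodes and offsets differ between interpreter versions, but the approach of counting from the FOR_ITER up to the jump back to it carries over):

import dis

def loop_body_size(fn):
    # count the instructions from the FOR_ITER up to (and including)
    # the backward jump that targets it
    instrs = list(dis.get_instructions(fn))
    start = next(i for i, ins in enumerate(instrs)
                 if ins.opname == 'FOR_ITER')
    for end in range(start + 1, len(instrs)):
        ins = instrs[end]
        if 'JUMP' in ins.opname and ins.argval == instrs[start].offset:
            return end - start + 1
    raise ValueError('no backward jump to the FOR_ITER found')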

When the JIT compiler traces this loop, it will execute the Python interpreter for one iteration of the loop on top of the meta-tracing interpreter. I built a special version of PyPy that counts the number of opcodes executed by the meta-tracing interpreter while tracing. Here are the results:

| function name in Python interpreter | number of operations executed | number of times called | number of operations traced |
|-------------------------------------|------------------------------:|-----------------------:|-----------------------------:|
| dispatch_bytecode | 864 | 28 | 0 |
| PyFrame.dispatch | 708 | 1 | 0 |
| W_TypeObject.lookup_where_with_method_cache | 142 | 6 | 0 |
| handle_bytecode__AccessDirect_None | 112 | 28 | 0 |
| list_BINARY_SUBSCR__AccessDirect_None | 79 | 2 | 0 |
| W_TypeObject.descr_call | 77 | 1 | 3 |
| _init_from_list_w_helper | 66 | 1 | 12 |
| W_TypeObject.issubtype | 61 | 2 | 0 |
| popvalues__AccessDirect_None | 54 | 1 | 3 |
| getattr | 45 | 1 | 0 |
| setattr | 43 | 1 | 0 |
| Function.call_obj_args | 42 | 3 | 1 |
| int_INPLACE_ADD | 41 | 1 | 4 |
| jump_absolute | 37 | 1 | 2 |
| BytesDictStrategy.setitem | 35 | 1 | 0 |
| STORE_MAP | 35 | 1 | 0 |
| _reorder_and_add | 34 | 1 | 0 |
| call_args | 34 | 1 | 0 |
| ll_dict_lookup | 33 | 2 | 5 |
| setitem | 32 | 2 | 0 |
| ... | ... | ... | ... |
| total | 3675 | 173 | 103 |

The names of these functions aren't too important. What matters is that the meta-tracing interpreter executes 3675 operations to trace one iteration of the loop in main, which is a lot. It also records 103 trace operations, most of which are then removed again by the trace optimizer, leaving a trace with just 22 operations:

label(p0, p1, p6, p7, i45, i32, p20, p31, p35, i50, i36)
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #19 FOR_ITER')
i52 = int_lt(i50, 0)
guard_false(i52)
i53 = int_ge(i50, i36)
guard_false(i53)
i55 = int_add(i50, 1)
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #25 LOAD_GLOBAL')
# update iterator object
setfield_gc(p20, i55, descr=<FieldS pypy.objspace.std.iterobject.W_AbstractSeqIterObject.inst_index 8>)
guard_not_invalidated()
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #88 INPLACE_ADD')
# perform the addition, checking for machine integer overflow
i56 = int_add_ovf(i45, i50)
guard_no_overflow(descr=<Guard0x729830a52170>) [p0, p6, p7, p20, p1, i50, i32, i45]
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #92 JUMP_ABSOLUTE')
# check whether ctrl-c was pressed
i58 = getfield_raw_i(125998063394752, descr=<FieldS pypysig_long_struct.c_value 0>)
i60 = int_lt(i58, 0)
guard_false(i60)
debug_merge_point(0, 0, '<code object main0. file '<string>'. line 5> #19 FOR_ITER')
jump(p0, p1, p6, p7, i56, i50, p20, p31, p35, i55, i36)

Note how, in particular, no allocations are left in this trace. The optimizer recognized that neither the A instance, the dictionary, nor the list needs to be allocated at all.
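To give a flavor of how the optimizer can do that, here is a toy sketch of allocation removal. This is not PyPy's real optimizer (which, among many other things, has to rematerialize such objects when a guard fails and they are still needed); it just shows the core idea that an allocation whose fields are only written and read inside the trace can be replaced by plain value propagation:

def optimize(trace):
    virtuals = {}   # trace-local allocations: object name -> {field: value}
    aliases = {}    # result name -> the value it is known to equal
    out = []
    def resolve(v):
        return aliases.get(v, v)
    for op in trace:
        name, args = op[0], [resolve(a) for a in op[1:]]
        if name == 'new':
            virtuals[args[0]] = {}                # delay the allocation
        elif name == 'setfield' and args[0] in virtuals:
            virtuals[args[0]][args[1]] = args[2]  # just remember the value
        elif name == 'getfield' and args[1] in virtuals:
            aliases[args[0]] = virtuals[args[1]][args[2]]
        else:
            out.append((name,) + tuple(args))     # ordinary op: keep it
    return out

trace = [
    ('new', 'p1'),                  # a = A()
    ('setfield', 'p1', 'x', 'i0'),  # a.x = i
    ('getfield', 'i1', 'p1', 'x'),  # read a.x
    ('int_add', 'i2', 'res', 'i1'), # res += a.x
]
print(optimize(trace))   # [('int_add', 'i2', 'res', 'i0')] - no allocation left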

How long does it take for the JIT compilation time to be amortized?

To find out after how many iterations the JIT compilation time is amortized, we can run the benchmark with varying values of num and measure the execution time, both with and without the JIT. We do this by starting many processes using more bash for loops. We can then plot the execution time against the number of iterations to see when the time with the JIT becomes lower than the time without it.

for i in {0..20000..1}; do pypy -S x.py "$i" >> datajit ; done; for i in {0..20000..1}; do pypy --jit off -S x.py "$i" >> datanojit ; done

And we can also run the same script on CPython, while we are at it:

for i in {1..10000..1}; do python3.13 -S x.py "$i" >> datacpy ; done;

(this is a bit of an apples-to-oranges comparison because I am comparing PyPy2 against CPython 3.13, but oh well.)

Both will run for a while and then we can plot the data.

Gory matplotlib details hidden
# Co-created by Copilot
import matplotlib.pyplot as plt
import numpy as np

THRESHOLD = 1041  # JIT threshold

def plot_trace_costs(maximum, save_as, baseline_filename='datanojit', baseline_legend='PyPy No JIT'):
    with open('datajit', 'r') as f:
        datajit = [tuple(map(float, line.strip().split())) for line in f.readlines()]
    with open(baseline_filename, 'r') as f:
        datanojit = [tuple(map(float, line.strip().split())) for line in f.readlines()]
    # scale the y values from seconds to milliseconds
    datajit = [(x, y * 1000) for x, y in datajit]
    datanojit = [(x, y * 1000) for x, y in datanojit]
    origdatajit = datajit.copy()
    # sort the data by x value
    datajit.sort(key=lambda x: x[0])
    datanojit.sort(key=lambda x: x[0])
    # filter out lines with x value greater than maximum
    datajit = [x for x in datajit if x[0] <= maximum]
    datanojit = [x for x in datanojit if x[0] <= maximum]
    # create new figure
    plt.figure(figsize=(10, 6))
    plt.clf()

    # unpack data
    jit_x, jit_y = zip(*datajit) if datajit else ([], [])
    nojit_x, nojit_y = zip(*datanojit) if datanojit else ([], [])

    # plot the data points as a scatter plot with small markers (s=1)
    plt.scatter(jit_x, jit_y, label='PyPy w/ JIT', alpha=0.3, s=1)
    plt.scatter(nojit_x, nojit_y, label=baseline_legend, alpha=0.3, s=1)
    # add thin vertical line at x=THRESHOLD
    plt.axvline(x=THRESHOLD, color='gray', linestyle='--', linewidth=1)
    # add a label next to the line, near the top, saying "JIT threshold"
    all_y = list(jit_y) + list(nojit_y)
    if all_y:
        plt.text(THRESHOLD, max(all_y) * 0.9, 'JIT threshold', color='gray', fontsize=10, ha='left')

    # fit a line to the JIT data, starting from x=THRESHOLD
    jit_x_fit = [x for x, _ in origdatajit if x >= THRESHOLD]
    jit_y_fit = [y for x, y in origdatajit if x >= THRESHOLD]
    assert len(jit_x_fit) == len(jit_y_fit), "JIT x and y data lengths do not match"
    assert jit_x_fit and jit_y_fit
    jit_fit = np.polyfit(jit_x_fit, jit_y_fit, 1)
    jit_fit_line = np.polyval(jit_fit, jit_x)
    plt.plot(jit_x, jit_fit_line, color='black', linestyle='--', linewidth=1)
    assert nojit_x and nojit_y
    nojit_x_fit = [x for x in nojit_x if x >= THRESHOLD]
    nojit_y_fit = [y for x, y in zip(nojit_x, nojit_y) if x >= THRESHOLD]
    assert nojit_x_fit and nojit_y_fit
    nojit_fit = np.polyfit(nojit_x_fit, nojit_y_fit, 1)
    nojit_fit_line = np.polyval(nojit_fit, nojit_x)
    plt.plot(nojit_x, nojit_fit_line, color='black', linestyle='--', linewidth=1)

    # compute intersection point of the two lines
    intersection_x = (nojit_fit[1] - jit_fit[1]) / (jit_fit[0] - nojit_fit[0])
    intersection_y = jit_fit[0] * intersection_x + jit_fit[1]
    # also compute the speedup of the JIT and No JIT fits by dividing the slopes
    jit_speedup = 1 / (jit_fit[0] / nojit_fit[0])

    # add text box to the upper right corner with the fit parameters
    fit_text = f'PyPy w/ JIT fit: y = {jit_fit[0]:.6f}x + {jit_fit[1]:.6f}\n{baseline_legend} fit: y = {nojit_fit[0]:.6f}x + {nojit_fit[1]:.6f}'
    speedup_text = f'Speedup: {jit_speedup:.2f}x\nIntersection: ({intersection_x:.2f}, {intersection_y:.2f})'
    fit_text += "\n" + speedup_text
    plt.text(0.95, 0.95, fit_text, transform=plt.gca().transAxes, fontsize=10,
             verticalalignment='top', horizontalalignment='right',
             bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))
    # compute the y-values of both fits at the JIT threshold
    jit_y_at_threshold = jit_fit[0] * THRESHOLD + jit_fit[1]
    nojit_y_at_threshold = nojit_fit[0] * THRESHOLD + nojit_fit[1]
    diff = jit_y_at_threshold - nojit_y_at_threshold
    diff_text = f'Difference at threshold: {diff:.2f} ms'
    print(diff_text)
    plt.xlabel('Iterations')
    plt.ylabel('time taken (milliseconds)')
    plt.grid()
    plt.title(f'Warmup behaviour: PyPy w/ JIT vs {baseline_legend}')
    # legend on the top left corner
    plt.legend(loc='upper left')
    plt.savefig(save_as)

if __name__ == '__main__':
    plot_trace_costs(maximum=20000, save_as='trace_costs_start.svg')
    plot_trace_costs(maximum=2000, save_as='trace_costs_very_start.svg')
    plot_trace_costs(maximum=10000, save_as='trace_costs_cpy.svg', baseline_filename='datacpy', baseline_legend='CPython 3.13')

This is what the plot looks like:

[Figure: Trace Costs: PyPy w/ JIT vs PyPy No JIT]

The plot shows that the JIT compilation time is amortized after about 2600 iterations. The JIT compilation time is about 39ms (which consists of the meta-tracing time plus the time for optimizing the trace and generating machine code). If the program runs for a very long time, the speedup of the JIT approaches more than 200x (but the example was picked to make things easy for the JIT, so it's not indicative of the performance of realistic Python code).
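As a rough cross-check that these numbers are consistent with the ~900x slowdown measured earlier, here is some back-of-the-envelope arithmetic (all measured values cover the 100 loop copies):

# the ~39ms compilation overhead is spread over 100 loop copies
compile_time_per_loop = 0.039 / 100        # ~0.39 ms per loop
# regulartime was ~35.5us for 100 single iterations
interp_time_per_iter = 3.55e-05 / 100      # ~0.36 us per interpreted iteration
# compiled iterations are >200x faster, so treat them as nearly free
payback_iters = compile_time_per_loop / interp_time_per_iter   # ~1100
print(1041 + payback_iters)   # ~2100, the same order as the observed ~2600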

We can zoom in on the first 2000 data points:

[Figure: Trace Costs: PyPy w/ JIT vs PyPy No JIT, first 2000 points]

And here's the comparison with CPython 3.13:

[Figure: Trace Costs: PyPy w/ JIT vs CPython 3.13]

Compared to CPython the speedup is ~160x, and PyPy catches up with CPython only at around 4800 iterations. I don't know why CPython's performance is so much noisier across the 10,000 iterations.

Conclusion

After these quick measurements we can conclude that the meta-tracing interpreter is roughly three orders of magnitude slower than the regular interpreter. We (Yusuke Izawa and I) plan to improve the performance of the meta-tracing interpreter in the future, and I look forward to seeing how these numbers change.

Acknowledgements

Thanks to Yusuke Izawa for his work on instrumenting the meta-tracing interpreter, and for our previous collaboration.