KnownBits in Lean
In January I was at the excellent E-Graphs Dagstuhl, where I met some really amazing people. One of the things I did during that week was to talk to Marcus Rossel and Andrés Goens about proving the soundness of Knownbits transfer functions in Lean. I don't really know any Lean, so it was really useful to have two people who know a lot about it to hold my hand through my attempts. Thanks a lot, Marcus and Andrés! The resulting Lean code is below; I'm sure it's not really that great.
Why Knownbits in Lean?
In the really long blog post about Knownbits on the
PyPy blog I had checked the soundness of the knownbits transfer functions with
Z3. However, that is a bit unsatisfying, because bitvectors in Z3 require a
fixed width. The goal with Lean was to get bitwidth-independent proofs and to
learn a bit more about Lean. This succeeded only partly. We managed to prove
the soundness of the transfer functions for bitwise not and bitwise and, but
not for addition (which requires a much more complicated proof). Z3 checks
addition without trouble by bit-blasting, but again only for fixed bitvector
widths.
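To make the fixed-width limitation concrete: at very small widths such a check does not even need a solver. The following Lean sketch enumerates every well-formed knownbits pair of width 4 together with all of their member values. It tests that the addition transfer function always contains the concrete sum. That function uses the same formula as add in the Lean code at the end of this post, which follows the tristate paper. The names KB, KB.add, allVals, allKB and addSound are hypothetical and exist only for this sketch. Like the Z3 check, a true result says nothing about any other width.

```lean
-- Hypothetical names (KB, allVals, allKB, addSound); a fixed-width sketch, not a proof.
structure KB (n : Nat) where
  ones : BitVec n
  unknowns : BitVec n

-- v is a member if it agrees with `ones` on every known bit
def KB.contains {n : Nat} (kb : KB n) (v : BitVec n) : Bool :=
  (v &&& ~~~kb.unknowns) == kb.ones

-- same formula as `add` in the post's Lean code (tnum addition from the tristate paper)
def KB.add {n : Nat} (a b : KB n) : KB n :=
  let sumOnes := a.ones + b.ones
  let sumUnknowns := a.unknowns + b.unknowns
  let onesCarrier := (sumOnes + sumUnknowns) ^^^ sumOnes
  let unknowns := a.unknowns ||| b.unknowns ||| onesCarrier
  ⟨sumOnes &&& ~~~unknowns, unknowns⟩

-- every bitvector of width n
def allVals (n : Nat) : List (BitVec n) :=
  (List.range (2 ^ n)).map (BitVec.ofNat n)

-- every well-formed knownbits of width n (ones and unknowns do not overlap)
def allKB (n : Nat) : List (KB n) :=
  (allVals n).foldr (fun o acc =>
    (allVals n).foldr (fun u acc' =>
      if o &&& u == 0 then KB.mk o u :: acc' else acc') acc) []

-- soundness: for all members v₁ of a and v₂ of b, v₁ + v₂ is a member of a.add b
def addSound (n : Nat) : Bool :=
  (allKB n).all fun a =>
    let va := (allVals n).filter a.contains
    (allKB n).all fun b =>
      let vb := (allVals n).filter b.contains
      let s := a.add b
      va.all fun v₁ => vb.all fun v₂ => s.contains (v₁ + v₂)

#eval addSound 4  -- true
```

Every extra bit multiplies the number of cases, so enumeration stops being practical long before realistic widths like 64. That is the gap a bitwidth-independent proof closes.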
Proving Precision
On the positive side, I also managed to prove the precision of the transfer
functions for bitwise not and bitwise and. The precision theorem for bitwise not
states that the result of the not transfer function is the smallest knownbits
value (with respect to the subset relation) that is a sound bitwise not of the
input knownbits value (and analogously for and). Z3 can check this for bitwise
not, but times out for bitwise and.
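The precision statements can be tested the same way at a single tiny width. This Lean sketch enumerates all well-formed knownbits of width 3. It checks that every sound result for not, and every sound result for and, is a superset of what the transfer function computes. KB, KB.invert, KB.and, KB.subset, allKB, invertPrecise and andPrecise are hypothetical names that mirror the definitions in the post's Lean code. The check illustrates one width only and is not a replacement for the proofs.

```lean
-- Hypothetical names mirroring the post's definitions; exhaustive at one width, not a proof.
structure KB (n : Nat) where
  ones : BitVec n
  unknowns : BitVec n

def KB.contains {n : Nat} (kb : KB n) (v : BitVec n) : Bool :=
  (v &&& ~~~kb.unknowns) == kb.ones

def KB.zeros {n : Nat} (kb : KB n) : BitVec n :=
  ~~~kb.unknowns &&& ~~~kb.ones

-- the not and and transfer functions, as in the post
def KB.invert {n : Nat} (kb : KB n) : KB n := ⟨kb.zeros, kb.unknowns⟩

def KB.and {n : Nat} (a b : KB n) : KB n :=
  let ones := a.ones &&& b.ones
  ⟨ones, ~~~(a.zeros ||| b.zeros ||| ones)⟩

def allVals (n : Nat) : List (BitVec n) :=
  (List.range (2 ^ n)).map (BitVec.ofNat n)

-- every well-formed knownbits of width n
def allKB (n : Nat) : List (KB n) :=
  (allVals n).foldr (fun o acc =>
    (allVals n).foldr (fun u acc' =>
      if o &&& u == 0 then KB.mk o u :: acc' else acc') acc) []

def KB.subset {n : Nat} (a b : KB n) : Bool :=
  (allVals n).all fun v => !a.contains v || b.contains v

-- every sound inversion `inv` of `kb` must be a superset of kb.invert
def invertPrecise (n : Nat) : Bool :=
  (allKB n).all fun kb => (allKB n).all fun inv =>
    let sound := (allVals n).all fun v => !kb.contains v || inv.contains (~~~v)
    !sound || kb.invert.subset inv

-- every sound `and` result `c` of a and b must be a superset of a.and b
def andPrecise (n : Nat) : Bool :=
  (allKB n).all fun a => (allKB n).all fun b =>
    let va := (allVals n).filter a.contains
    let vb := (allVals n).filter b.contains
    (allKB n).all fun c =>
      let sound := va.all fun v₁ => vb.all fun v₂ => c.contains (v₁ &&& v₂)
      !sound || (a.and b).subset c

#eval (invertPrecise 3, andPrecise 3)  -- (true, true)
```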
The QEMU smask domain
In the last couple of months I have also been interested in QEMU.
QEMU's Tiny Code Generator IR has a number of optimizations and also
contains a knownbits-style abstract domain. In addition, there is the
smask domain.
It took me a while to understand how it works. It's basically a very coarse
range domain that can only represent ranges of the form [-2**n, 2**n - 1].
A domain value is represented by a single bitvector, which is a sequence
of 1s followed by all 0s. A concrete bitvector matches an smask if all
the masked bits of the concrete bitvector are equal to its sign bit (i.e. all 0
or all 1). For example, smask=0b11110000 for 8-bit bitvectors
denotes the range [-16, 15]: the top four bits are either 0000, giving the
values 0 to 15, or 1111, giving the values -16 to -1.
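The example can be checked directly. This Lean sketch uses a hypothetical helper, smaskContains, which mirrors the Contains definition from the SMask part of the post's Lean code but works on raw bitvectors. It compares that helper against the signed range [-16, 15] for all 256 8-bit values. It also evaluates the lower and upper bounds that the post's code defines: the mask itself and its complement.

```lean
-- Hypothetical helper mirroring SMask.Contains on raw bitvectors:
-- a value matches if the bits selected by the mask are all 0 or all 1.
def smaskContains {n : Nat} (mask v : BitVec n) : Bool :=
  (v &&& mask == 0) || (v &&& mask == mask)

def mask8 : BitVec 8 := 0b11110000

-- the matching values are exactly the signed range [-16, 15]
#eval (List.range 256).all fun i =>
  let v := BitVec.ofNat 8 i
  smaskContains mask8 v == decide (-16 ≤ v.toInt ∧ v.toInt ≤ 15)
-- true

-- lower bound (the mask) and upper bound (its complement) as signed integers;
-- these evaluate to -16 and 15
#eval (mask8.toInt, (~~~mask8).toInt)
```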
At the end of the Lean code below there is a simple sketch for this smask
domain from QEMU.
I did not actually get very far with the smask domain, but I managed to
prove the soundness of the transfer function for bitwise and, and I defined
the addition transfer function. (I've also implemented a transfer function for
addition in QEMU for their full abstract domain, but never submitted it,
because I don't understand how they unit-test their optimizer and am
additionally scared by mailing-list-based patch submission.)
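The soundness proof for the smask add transfer function is still a sorry in the code below, so a brute-force check at one small width is a cheap sanity test. This Lean sketch enumerates all non-top masks of width 6 and all pairs of 6-bit values. It checks the and and add transfer functions as the post's Lean code defines them: a bitwise and of the masks, followed for add by a left shift by one. All names are hypothetical. The add definition mirrors the post's Lean sketch rather than QEMU's actual implementation.

```lean
-- Hypothetical names; brute force over one small width, not a proof.
def smaskContains {n : Nat} (mask v : BitVec n) : Bool :=
  (v &&& mask == 0) || (v &&& mask == mask)

def allVals (n : Nat) : List (BitVec n) :=
  (List.range (2 ^ n)).map (BitVec.ofNat n)

-- valid masks are all-ones shifted left by x < n; x = n - 1 is top,
-- which the post's add excludes, so only x < n - 1 is enumerated here
def nonTopMasks (n : Nat) : List (BitVec n) :=
  (List.range (n - 1)).map fun x => (~~~(0 : BitVec n)) <<< x

-- the two transfer functions as defined in the post's Lean code
def smaskAnd {n : Nat} (m₁ m₂ : BitVec n) : BitVec n := m₁ &&& m₂
def smaskAdd {n : Nat} (m₁ m₂ : BitVec n) : BitVec n := (m₁ &&& m₂) <<< 1

-- tf is sound for op if op v₁ v₂ matches tf m₁ m₂ whenever v₁ matches m₁ and v₂ matches m₂
def sound (n : Nat) (op tf : BitVec n → BitVec n → BitVec n) : Bool :=
  (nonTopMasks n).all fun m₁ => (nonTopMasks n).all fun m₂ =>
    (allVals n).all fun v₁ => (allVals n).all fun v₂ =>
      !(smaskContains m₁ v₁ && smaskContains m₂ v₂) ||
        smaskContains (tf m₁ m₂) (op v₁ v₂)

#eval (sound 6 (fun a b => a &&& b) smaskAnd, sound 6 (fun a b => a + b) smaskAdd)
-- (true, true)
```

Only non-top masks are enumerated because the post's add requires both inputs to be non-top. The and transfer function is sound for top masks too.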
Conclusion and Code
In conclusion, it was fun to learn some more about Lean! Being able to
formulate bitwidth-independent proofs, compared to checking fixed bitwidths
with Z3, feels like progress. I was also happy that I managed to prove
the precision of the transfer functions for not and for and. However, add would
be quite a bit more effort (the appendix of the "Tristate" paper contains a
proof, but it's not simple at all). Given that I work a lot on other things
at the moment, it's unlikely that I'll return to it in the near future. I hope
that some of the ongoing work on bitwidth-independent bitvector proof
checking can automate this work (but I haven't tried out that code yet).
Hacking on Lean proofs was one of my personal highlights of the E-Graphs Dagstuhl, but some other really fun conversations happened there. I hope to blog about some of them at some point.
The actual Lean code is below. I've tried to intersperse it with at least some comments about what is going on. The code is also on GitHub.
```lean
import Std

/- bitvector lemmas -/

theorem and_xor_distrib_left {x y z : BitVec w} :
    x &&& (y ^^^ z) = (x &&& y) ^^^ (x &&& z) :=
  BitVec.eq_of_getElem_eq (by simp [Bool.and_xor_distrib_left])

theorem and_xor_distrib_right {x y z : BitVec w} :
    (x ^^^ y) &&& z = (x &&& z) ^^^ (y &&& z) :=
  BitVec.eq_of_getElem_eq (by simp [Bool.and_xor_distrib_right])

theorem xor_both_sides (x y z : BitVec n) (h : x = y ^^^ z) : x ^^^ z = y := by
  grind

theorem and_minus_one (x : BitVec n) : x &&& (~~~0) = x := by
  simp

/- KnownBits -/

/- representation is like in the "tristate" paper https://arxiv.org/abs/2105.05398

   table of values:

   | unknowns | ones | meaning   |
   |----------|------|-----------|
   | 0        | 0    | 0         |
   | 0        | 1    | 1         |
   | 1        | 0    | unknown   |
   | 1        | 1    | forbidden |

   in addition to the unknowns and ones bitvectors, we also have a
   well-formedness condition that the ones and unknowns bitvectors cannot
   overlap, which means that the forbidden state is not allowed. when defining
   a function that creates a knownbits, we need to prove that the
   well-formedness condition is satisfied. -/
structure KnownBits (n : Nat) where
  ones : BitVec n
  unknowns : BitVec n
  wf : ones &&& unknowns = 0 := by first | bv_decide | grind

namespace KnownBits

/- extensionality: two knownbits are equal if their ones and unknowns are equal -/
theorem ext {n} (a b : KnownBits n) (h : a.ones = b.ones ∧ a.unknowns = b.unknowns) :
    a = b := by
  cases a; cases b; simp at *; exact h;

/- constructor from a constant value.
   the resulting knownbits has no unknowns and is a singleton set -/
@[grind, simp]
def fromConst (bv : BitVec n) : KnownBits n where
  ones := bv
  unknowns := 0

/- access the known bits, just flip the unknowns -/
@[grind, simp]
def knowns (kb : KnownBits n) : BitVec n := ~~~kb.unknowns

/- access the zero bits, which are the known bits that are zero -/
@[grind, simp]
def zeros (kb : KnownBits n) : BitVec n := kb.knowns &&& ~~~kb.ones

/- membership of a concrete value in the knownbits set -/
@[grind, simp]
def Contains (kb : KnownBits n) (val : BitVec n) : Prop := val &&& kb.knowns = kb.ones

/- subset relation, mathematically formulated -/
@[grind, simp]
def Subset (kb₁ kb₂ : KnownBits n) : Prop := ∀ bv, kb₁.Contains bv → kb₂.Contains bv

/- a knownbits is a const set if it has no unknowns -/
def IsConst {n} (kb : KnownBits n) : Prop := kb.unknowns = 0

/- a const set contains only one value -/
theorem const_implies_singleton_set {n} {kb : KnownBits n} (h : kb.IsConst) :
    ∀ {bv}, kb.Contains bv → bv = kb.ones := by
  intro bv hcontains;
  simp [Contains] at hcontains;
  simp [IsConst] at h;
  rw [h] at hcontains;
  simp at hcontains;
  exact hcontains

/- membership instance, allows using `∈` to check membership (except that it
   doesn't work for lean reasons that I don't understand) -/
instance : Membership (BitVec n) (KnownBits n) where
  mem := Contains

theorem knowns_and_ones_equal_ones (kb : KnownBits n) :
    kb.knowns &&& kb.ones = kb.ones := by
  simp [knowns];
  have h := kb.wf;
  have h₂ : 0 = kb.ones ^^^ kb.ones := by simp;
  rw [h₂] at h;
  have h₃ : kb.ones &&& kb.unknowns ^^^ kb.ones = kb.ones := by grind;
  have h₄ : kb.ones &&& kb.unknowns ^^^ kb.ones &&& (~~~0) = kb.ones := by grind;
  rw [←and_xor_distrib_left] at h₄;
  have h₅ : kb.ones &&& (~~~kb.unknowns) = kb.ones := by grind;
  rw [BitVec.and_comm] at h₅;
  exact h₅

/- every knownbits contains its ones (which means that all unknown bits are set to zero) -/
theorem contains_ones {n} (kb : KnownBits n) : kb.Contains kb.ones := by
  grind [knowns_and_ones_equal_ones, Contains]

/- given any bitvector, we can force it to be a member of the knownbits set by
   setting the known bits to the desired value -/
def force_membership {n} (kb : KnownBits n) (bv : BitVec n) : BitVec n :=
  kb.ones ||| (bv &&& kb.unknowns)

theorem force_membership_is_member {n} (kb : KnownBits n) (bv : BitVec n) :
    kb.Contains (force_membership kb bv) := by
  have h := kb.wf;
  rw [force_membership];
  rw [Contains];
  have h₁ := knowns_and_ones_equal_ones kb;
  rw [← h₁];
  grind

/- invert (bitwise not) -/

/- let's start with the simplest transfer function, bitwise not.
   the definition is simple, all known bits are flipped -/
@[grind, simp]
def invert (kb : KnownBits n) : KnownBits n where
  ones := kb.zeros
  unknowns := kb.unknowns

/- mathematical definition of inversion. inv is a sound inversion of orig if,
   for all values bv in orig, ~~~bv is in inv -/
@[grind, simp]
def IsInversionOf (inv orig : KnownBits n) :=
  ∀ {bv}, orig.Contains bv → inv.Contains (~~~bv)

/- now we can prove that invert as defined above is indeed a sound inversion -/
theorem invert_sound (kb : KnownBits n) : kb.invert.IsInversionOf kb := by
  grind

/- two helper theorems -/
theorem invert_invert_is_identity (kb : KnownBits n) : kb.invert.invert = kb := by
  rw [ext kb.invert.invert kb];
  have h := kb.wf;
  simp [invert];
  have h₁ : ~~~kb.unknowns &&& ~~~(~~~kb.unknowns &&& ~~~kb.ones) = ~~~kb.unknowns &&& kb.ones := by grind;
  rw [h₁];
  rw [← knowns];
  exact knowns_and_ones_equal_ones kb;

theorem invert_implies_original_set_contains_invert {n} (kb : KnownBits n) (x : BitVec n)
    (h : kb.invert.Contains x) : kb.Contains (~~~x) := by
  rw [← invert_invert_is_identity kb];
  grind

/- theorem that invert is the most precise inversion. if we have any other
   sound inversion of the original set kb, then the result of invert is a
   subset of that inversion -/
theorem invert_precise {kb inv : KnownBits n} (h : inv.IsInversionOf kb) :
    kb.invert.Subset inv := by
  intro bv hv;
  have h₁ : kb.Contains (~~~bv) := invert_implies_original_set_contains_invert kb bv hv;
  have h₂ : _ := h h₁;
  rw [BitVec.not_not] at h₂;
  exact h₂

/- and -/

/- the next simplest transfer function, for bitwise and. it's more complicated
   because it's a binary function, but it's still a bitwise function. the
   definition is simple, here's a table:

   | kb₁ bit | kb₂ bit | result bit |
   |---------|---------|------------|
   | 0       | 0       | 0          |
   | 0       | 1       | 0          |
   | 0       | ?       | 0          |
   | 1       | 0       | 0          |
   | 1       | 1       | 1          |
   | 1       | ?       | ?          |
   | ?       | 0       | 0          |
   | ?       | 1       | ?          |
   | ?       | ?       | ?          |
-/
@[grind, simp]
def and (kb₁ kb₂ : KnownBits n) : KnownBits n :=
  let ones := kb₁.ones &&& kb₂.ones
  let knowns := kb₁.zeros ||| kb₂.zeros ||| ones
  { ones, unknowns := ~~~knowns }

/- mathematical criterion for when a knownbits is the sound bitwise and of two
   other knownbits -/
@[grind, simp]
def IsAndOf (and a b : KnownBits n) :=
  ∀ {bv₁ bv₂}, a.Contains bv₁ ∧ b.Contains bv₂ → and.Contains (bv₁ &&& bv₂)

/- proof that and is sound -/
theorem and_sound {n} (kb₁ kb₂ : KnownBits n) : (kb₁.and kb₂).IsAndOf kb₁ kb₂ := by
  simp only [IsAndOf, and, Contains];
  grind

/- proof that and constant-folds, ie if the two arguments are constants, the
   result is also a constant -/
theorem and_constfolds {n} (kb₁ kb₂ : KnownBits n) (h₁ : kb₁.IsConst) (h₂ : kb₂.IsConst) :
    (kb₁.and kb₂).IsConst := by
  simp [IsConst, and] at *;
  rw [h₁, h₂];
  grind;

/- helper lemmas for proving precision of and -/
theorem and_force_membership {n} (kb₁ kb₂ : KnownBits n) (bv : BitVec n)
    (h : (kb₁.and kb₂).Contains bv) :
    (force_membership kb₁ bv &&& force_membership kb₂ bv) = bv := by
  simp [force_membership];
  simp [and] at h;
  -- By simplifying the expression using the properties of bitwise operations,
  -- we can show that the result is indeed `bv`.
  have h_simp : (kb₁.ones ||| bv &&& kb₁.unknowns) &&& (kb₂.ones ||| bv &&& kb₂.unknowns) = (kb₁.ones &&& kb₂.ones) ||| (bv &&& (kb₁.unknowns ||| kb₂.unknowns)) := by grind;
  grind

theorem apply_is_and_of {n} (and a b : KnownBits n) (h : and.IsAndOf a b) {bv₁ bv₂}
    (h₁ : a.Contains bv₁) (h₂ : b.Contains bv₂) : and.Contains (bv₁ &&& bv₂) := by
  grind;

/- prove precision of and: if we have any other sound bitwise and of two
   knownbits, then the result of the and function is a subset of it -/
theorem and_precise {and a b : KnownBits n} (h : and.IsAndOf a b) : (a.and b).Subset and := by
  intro bv hv;
  simp only [IsAndOf] at h;
  have h₁ := force_membership_is_member a bv;
  have h₂ := force_membership_is_member b bv;
  have h₃ : (force_membership a bv &&& force_membership b bv) = bv := and_force_membership a b bv hv;
  have h₄ := apply_is_and_of and a b h h₁ h₂;
  rw [h₃] at h₄;
  exact h₄

/- add -/

/- add is complicated: in the tristate paper there's a long and tricky proof of
   soundness of add. I didn't manage to port the proof to lean yet -/
@[grind, simp]
def add (kb₁ kb₂ : KnownBits n) : KnownBits n :=
  let sumOnes := kb₁.ones + kb₂.ones
  let sumUnknowns := kb₁.unknowns + kb₂.unknowns
  let allCarriers := sumOnes + sumUnknowns
  let onesCarrier := allCarriers ^^^ sumOnes
  let unknowns := kb₁.unknowns ||| kb₂.unknowns ||| onesCarrier
  let ones := sumOnes &&& ~~~unknowns
  { ones, unknowns }

def IsAddOf (sum a b : KnownBits n) :=
  ∀ {bv₁ bv₂}, a.Contains bv₁ ∧ b.Contains bv₂ → sum.Contains (bv₁ + bv₂)

theorem add_sound {bv₁ bv₂} {kb₁ kb₂ : KnownBits n} (h₁ : kb₁.Contains bv₁) (h₂ : kb₂.Contains bv₂) :
    (kb₁.add kb₂).Contains (bv₁ + bv₂) := by
  simp only [add, Contains, knowns] at *;
  sorry

theorem add_constfolds {n} (kb₁ kb₂ : KnownBits n) (h₁ : kb₁.IsConst) (h₂ : kb₂.IsConst) :
    (kb₁.add kb₂).IsConst := by
  simp [IsConst] at h₁ h₂;
  simp [add, IsConst];
  rw [h₁, h₂];
  simp;

theorem add_precise {sum a b : KnownBits n} (h : sum.IsAddOf a b) : (a.add b).Subset sum := by
  intro bv hv;
  simp only [Contains];
  simp only [IsAddOf] at h;
  sorry

/- SMask -/

/- bonus abstract domain. this one comes from QEMU. it took me a while to wrap
   my head around it: it's basically a coarse range that uses only a single
   bitvector to represent a range [-2^n..2^n-1]. that range is encoded as a
   sequence of 1 bits followed by a sequence of 0 bits. it's called "smask"
   ("signed mask") because the one bits in the mask denote the bits in the
   concrete values that all need to be equal to the sign bit (ie all 0 or all 1) -/
structure SMask (n : Nat) where
  smask : BitVec n
  wf : ∃ x : Nat, smask = (~~~(0 : BitVec n)) <<< x ∧ x < n := by first | bv_decide | grind

namespace SMask

theorem ext {n} (a b : SMask n) (h : a.smask = b.smask) : a = b := by
  cases a; cases b; simp at *; exact h

def top (n : Nat) (hn : n > 0) : SMask n where
  smask := (~~~(0 : BitVec n)) <<< (n - 1);
  wf := by apply Exists.intro (n - 1); grind

def isTop (mask : SMask n) : Prop := ∃ hn : n > 0, mask = top n hn

def lower_bound {n} (mask : SMask n) : (BitVec n) := mask.smask

def upper_bound {n} (mask : SMask n) : (BitVec n) := ~~~(mask.smask)

theorem smask_is_negative {n} (mask : SMask n) (h : n > 0) : BitVec.slt mask.smask 0 := by
  obtain ⟨x, ⟨ h1a, h1b⟩ ⟩ := mask.wf;
  rw [h1a];
  apply BitVec.slt_zero_iff_msb_cond.mpr
  simp [BitVec.msb, BitVec.getMsbD_eq_getLsbD, BitVec.getLsbD_shiftLeft]
  omega

/- contains predicate checks the definition: a value is contained in the smask
   if either all bits selected by the mask are zero, or all bits selected by
   the mask are one. this means that the smask represents a range of values
   where, in every value, the masked positions are either all 0 or all 1. -/
@[grind, simp]
def Contains (mask : SMask n) (val : BitVec n) : Prop :=
  val &&& mask.smask = 0 ∨ val &&& mask.smask = mask.smask

theorem lower_bound_is_lower_bound {n} (mask : SMask n) (bv : BitVec n) (h : mask.Contains bv) :
    BitVec.sle mask.lower_bound bv := by
  sorry;

theorem contains_zero {n} (mask : SMask n) : mask.Contains 0 := by
  simp [Contains]

theorem contains_mask {n} (mask : SMask n) : mask.Contains (mask.smask) := by
  simp [Contains]

theorem TopContainsEverything {n} (hn : n > 0) (val : BitVec n) : (top n hn).Contains val := by
  simp [top, Contains];
  grind;

/- and transfer function. and works by doing a bitwise and of the two masks,
   which is basically returning the larger of the two implied ranges (the mask
   with fewer one bits) -/
@[grind, simp]
def and (mask₁ mask₂ : SMask n) : SMask n where
  smask := mask₁.smask &&& mask₂.smask
  wf := by
    have h₁ := mask₁.wf;
    have h₂ := mask₂.wf;
    grind;

@[grind, simp]
def IsAndOf (and a b : SMask n) :=
  ∀ {bv₁ bv₂}, a.Contains bv₁ ∧ b.Contains bv₂ → and.Contains (bv₁ &&& bv₂)

theorem and_sound {n} (sm₁ sm₂ : SMask n) : (sm₁.and sm₂).IsAndOf sm₁ sm₂ := by
  grind

theorem allOnes_shl_inj {n} {x₁ x₂ : Nat} (hx₁ : x₁ < n) (hx₂ : x₂ < n)
    (h : (~~~(0 : BitVec n) <<< x₁) = (~~~(0 : BitVec n) <<< x₂)) : x₁ = x₂ := by
  have hne (h_ne : x₁ ≠ x₂) : False := by
    have hlt : x₁ < x₂ ∨ x₂ < x₁ := Nat.lt_or_gt_of_ne h_ne
    cases hlt with
    | inl hlt =>
      have h1 : (~~~(0 : BitVec n) <<< x₁).getLsbD x₁ = true := by
        simp [BitVec.getLsbD_shiftLeft]
        omega
      have h2 : (~~~(0 : BitVec n) <<< x₂).getLsbD x₁ = false := by
        simp [BitVec.getLsbD_shiftLeft]
        omega
      rw [h] at h1
      rw [h1] at h2
      contradiction
    | inr hlt =>
      have h1 : (~~~(0 : BitVec n) <<< x₂).getLsbD x₂ = true := by
        simp [BitVec.getLsbD_shiftLeft]
        omega
      have h2 : (~~~(0 : BitVec n) <<< x₁).getLsbD x₂ = false := by
        simp [BitVec.getLsbD_shiftLeft]
        omega
      rw [← h] at h1
      rw [h1] at h2
      contradiction
  apply Classical.byContradiction
  intro h_ne
  exact hne h_ne

theorem all_ones_shift_and_all_ones_shift_is_all_ones_shift {n} (x₁ x₂ : Nat) :
    (~~~(0 : BitVec n)) <<< x₁ &&& (~~~(0 : BitVec n)) <<< x₂ = (~~~(0 : BitVec n)) <<< max x₁ x₂ := by
  grind

/- add transfer function

   add is defined by doing a bitwise and of the two masks, and then shifting
   left by one. the intuition is that the bitwise and of the masks gives the
   larger of the two ranges, and the sum of two values from ranges of that
   size needs at most one more bit, so overflow can occur by one bit. the
   well-formedness condition is that the resulting mask is still a valid
   smask, which is a bit tricky to prove. it needs both masks to be non-top:
   if either mask is top, their bitwise and is top as well, and shifting that
   left by one gives 0, which is not a valid smask. -/
@[grind, simp]
def add (mask₁ mask₂ : SMask n) (h1 : ¬ mask₁.isTop) (h2 : ¬ mask₂.isTop) : SMask n where
  smask := (mask₁.smask &&& mask₂.smask) <<< 1
  wf := by
    have ⟨x₁, h1a, h1b⟩ := mask₁.wf
    have ⟨x₂, h2a, h2b⟩ := mask₂.wf
    rw [h1a, h2a, all_ones_shift_and_all_ones_shift_is_all_ones_shift]
    apply Exists.intro (max x₁ x₂ + 1)
    apply And.intro
    . simp [BitVec.shiftLeft_add]
    . have hn : n > 0 := by
        apply Classical.byContradiction
        intro hn'
        simp at hn'
        subst hn'
        omega
      have hx₁_lt : x₁ < n - 1 := by
        apply Classical.byContradiction
        intro h_ge
        have : x₁ = n - 1 := by omega
        have : mask₁ = top n hn := by
          apply SMask.ext
          simp [top, h1a, this]
        exact h1 ⟨hn, this⟩
      have hx₂_lt : x₂ < n - 1 := by
        apply Classical.byContradiction
        intro h_ge
        have : x₂ = n - 1 := by omega
        have : mask₂ = top n hn := by
          apply SMask.ext
          simp [top, h2a, this]
        exact h2 ⟨hn, this⟩
      omega

def IsAddOf (sum a b : SMask n) :=
  ∀ {bv₁ bv₂}, a.Contains bv₁ ∧ b.Contains bv₂ → sum.Contains (bv₁ + bv₂)

theorem add_top_sound {n} (sm₁ sm₂ : SMask n) (hn : n > 0) : (top n hn).IsAddOf sm₁ sm₂ := by
  simp only [top, IsAddOf, Contains] at *
  intro bv₁ bv₂ h; grind

theorem add_sound {n} (sm₁ sm₂ : SMask n) (h1 : ¬ sm₁.isTop) (h2 : ¬ sm₂.isTop) :
    (sm₁.add sm₂ h1 h2).IsAddOf sm₁ sm₂ := by
  sorry
```