Library compcert.common.Switch
Multi-way branches ("switch" statements) and their compilation
to comparison trees.
Require Import EqNat.
Require Import Coqlib.
Require Import Maps.
Require Import Integers.
Require Import Values.
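A multi-way branch is composed of a list of (key, action) pairs,
plus a default action. The action of the first pair whose key equals
the switch argument is selected. If no key matches, the default action
is selected (see switch_target below).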
(** A switch table: an ordered list of (key, action) pairs.
    Earlier pairs take precedence over later ones with the same key. *)
Definition table : Type :=
  list (prod Z nat).
(** [switch_target n dfl cases] is the action selected for the key [n]:
    the action of the first pair of [cases] whose key is [n], or [dfl]
    if no pair has key [n]. *)
Fixpoint switch_target (n: Z) (dfl: nat) (cases: table)
                       {struct cases} : nat :=
  match cases with
  | nil ⇒ dfl
  | (key, action) :: rem ⇒
      if zeq n key then action else switch_target n dfl rem
  end.
(** [switch_argument is64 v n] relates a run-time value [v] to the integer
    key [n] used to select a switch case.
    - [false] pairs with a [Vint] ([switch_argument_32]).
    - [true] pairs with a [Vlong] ([switch_argument_64]).
    In both cases the key is the unsigned interpretation of the integer. *)
Inductive switch_argument: bool → val → Z → Prop :=
| switch_argument_32: ∀ i,
switch_argument false (Vint i) (Int.unsigned i)
| switch_argument_64: ∀ i,
switch_argument true (Vlong i) (Int64.unsigned i).
Multi-way branches are translated to comparison trees.
Each internal node of the tree performs one of the following:
- an equality test against one of the keys;
- a "less than" test against one of the keys;
- a computed branch (jump table) over a range of key values.
The leaves of the tree (CTaction) designate the action to take.
(** Comparison trees. The meaning of each node is given by
    [comptree_match] below:
    - [CTaction act]: leaf; selects action [act].
    - [CTifeq key act cne]: selects [act] if the argument equals [key];
      otherwise continues with [cne].
    - [CTiflt key clt cge]: continues with [clt] if the argument is less
      than [key], and with [cge] otherwise.
    - [CTjumptable ofs sz acts cother]: if the offset
      [(argument - ofs) mod modulus] is below [sz], selects the entry of
      [acts] at that offset (reduced modulo [Int.modulus]); otherwise
      continues with [cother].
    NOTE: the proofs below use hypotheses named [key], [ofs] and [sz],
    which are generated from these constructor argument names by
    [induction t]. Renaming the arguments breaks those proofs. *)
Inductive comptree : Type :=
| CTaction (act: nat)
| CTifeq (key: Z) (act: nat) (cne: comptree)
| CTiflt (key: Z) (clt: comptree) (cge: comptree)
| CTjumptable (ofs: Z) (sz: Z) (acts: list nat) (cother: comptree).
(** Offsets into jump tables are computed modulo [modulus] (see the
    jump-table case of [comptree_match]). [validate] also checks that
    keys, offsets and sizes are below it. [modulus_pos] discharges the
    positivity side conditions of the [mod] lemmas used in the proofs
    ([Z_mod_lt], [Z_mod_plus]).
    NOTE(review): [Variable]/[Hypothesis] parameterize the following
    definitions only when they appear inside a [Section]. No Section
    header is visible in this excerpt; confirm that one encloses these
    lines. *)
Variable modulus: Z.
Hypothesis modulus_pos: modulus > 0.
(** [comptree_match n t] is the action selected by tree [t] for the switch
    argument [n]. The result is an option because the jump-table case
    reads the action through [list_nth_z]. *)
Fixpoint comptree_match (n: Z) (t: comptree) {struct t}: option nat :=
  match t with
  | CTaction act ⇒ Some act
  | CTifeq key act t' ⇒
      if zeq n key then Some act else comptree_match n t'
  | CTiflt key t1 t2 ⇒
      if zlt n key then comptree_match n t1 else comptree_match n t2
  | CTjumptable ofs sz tbl t' ⇒
      (* Offset of n from the start of the table, wrapped modulo [modulus].
         [validate] checks [ofs + sz <= modulus], so keys in [0, ofs)
         wrap to offsets >= sz and fall through to [t']. *)
      let delta := (n - ofs) mod modulus in
      if zlt delta sz
      then list_nth_z tbl (delta mod Int.modulus)
      else comptree_match n t'
  end.
The translation from a table to a comparison tree is performed
by untrusted OCaml code (function compile_switch, in the OCaml part
of the compiler). In Coq, we validate the result of this function
a posteriori. In other words, we develop and prove correct Coq
functions that take a table and a comparison tree, and check that
their semantics are equivalent.
(** [split_lt pivot cases] partitions [cases] into the entries whose key
    is below [pivot] and the remaining entries. The relative order of
    entries is preserved in both halves. *)
Fixpoint split_lt (pivot: Z) (cases: table)
                  {struct cases} : table × table :=
  match cases with
  | nil ⇒ (nil, nil)
  | (key, act) :: rem ⇒
      let (l, r) := split_lt pivot rem in
      if zlt key pivot
      then ((key, act) :: l, r)
      else (l, (key, act) :: r)
  end.
(** [split_eq pivot cases] returns a pair:
    - the action of the first entry whose key is [pivot] ([None] if
      there is none);
    - [cases] with every entry of key [pivot] removed.
    The first occurrence wins because the outer (earlier) entry
    overrides the result computed for [rem]. *)
Fixpoint split_eq (pivot: Z) (cases: table)
                  {struct cases} : option nat × table :=
  match cases with
  | nil ⇒ (None, nil)
  | (key, act) :: rem ⇒
      let (same, others) := split_eq pivot rem in
      if zeq key pivot
      then (Some act, others)
      else (same, (key, act) :: others)
  end.
(** [split_between dfl ofs sz cases] separates the entries whose key falls
    in the jump-table window ([(key - ofs) mod modulus < sz]) from the
    others.
    - In-window entries are collected into a map that returns [dfl] for
      absent keys. When a key occurs several times, the first occurrence
      wins, since it is [set] last.
    - Out-of-window entries are returned as a table, in their original
      order. *)
Fixpoint split_between (dfl: nat) (ofs sz: Z) (cases: table)
                       {struct cases} : ZMap.t nat × table :=
  match cases with
  | nil ⇒ (ZMap.init dfl, nil)
  | (key, act) :: rem ⇒
      let (inside, outside) := split_between dfl ofs sz rem in
      if zlt ((key - ofs) mod modulus) sz
      then (ZMap.set key act inside, outside)
      else (inside, (key, act) :: outside)
  end.
(** Lower bound of the remaining key range once key [v] has been
    excluded. If [v] was exactly [lo], the range now starts at [lo + 1]. *)
Definition refine_low_bound (v lo: Z) : Z :=
  if zeq v lo
  then lo + 1
  else lo.
(** Upper bound of the remaining key range once key [v] has been
    excluded. If [v] was exactly [hi], the range now ends at [hi - 1]. *)
Definition refine_high_bound (v hi: Z) : Z :=
  if zeq v hi
  then hi - 1
  else hi.
(** [validate_jumptable cases tbl n] checks, for every position [i] of
    [tbl], that the [i]-th action of [tbl] is [ZMap.get (n + i) cases]. *)
Fixpoint validate_jumptable (cases: ZMap.t nat)
                            (tbl: list nat) (n: Z) {struct tbl} : bool :=
  match tbl with
  | nil ⇒ true
  | act :: rem ⇒
      beq_nat act (ZMap.get n cases)
      && validate_jumptable cases rem (Z.succ n)
  end.
(** [validate default cases t lo hi] checks that tree [t] implements the
    table [cases] with default [default] for the keys in [lo, hi].
    Soundness is [validate_correct_rec]: for [0 <= v < modulus] and
    [lo <= v <= hi], [comptree_match v t = Some (switch_target v default cases)].
    [lo] and [hi] bound the keys that can still reach [t]:
    - a [CTiflt] node narrows them;
    - a [CTifeq] node drops an endpoint equal to its pivot.
    The range checks on pivots, offsets and sizes establish [wf_comptree]
    (see [validate_wf]). *)
Fixpoint validate (default: nat) (cases: table) (t: comptree)
                  (lo hi: Z) {struct t} : bool :=
  match t with
  | CTaction act ⇒
      match cases with
      | nil ⇒
          (* No case left: the leaf must be the default action. *)
          beq_nat act default
      | (key1, act1) :: _ ⇒
          (* Otherwise the range must be the single key of the first
             remaining case, and the leaf must be its action. *)
          zeq key1 lo && zeq lo hi && beq_nat act act1
      end
  | CTifeq pivot act t' ⇒
      zle 0 pivot && zlt pivot modulus &&
      match split_eq pivot cases with
      | (None, _) ⇒
          (* No table entry has key [pivot]: the equality node cannot be
             justified, so reject. *)
          false
      | (Some act', others) ⇒
          beq_nat act act'
          && validate default others t'
                      (refine_low_bound pivot lo)
                      (refine_high_bound pivot hi)
      end
  | CTiflt pivot t1 t2 ⇒
      zle 0 pivot && zlt pivot modulus &&
      match split_lt pivot cases with
      | (lcases, rcases) ⇒
          validate default lcases t1 lo (pivot - 1)
          && validate default rcases t2 pivot hi
      end
  | CTjumptable ofs sz tbl t' ⇒
      let tbl_len := list_length_z tbl in
      (* Bounds needed by [validate_jumptable_correct] and by the
         [delta mod Int.modulus] index used in [comptree_match]. *)
      zle 0 ofs && zlt ofs modulus &&
      zle 0 sz && zlt sz modulus &&
      zle (ofs + sz) modulus && zle sz tbl_len && zlt sz Int.modulus &&
      match split_between default ofs sz cases with
      | (inside, outside) ⇒
          validate_jumptable inside tbl ofs
          && validate default outside t' lo hi
      end
  end.
(** Entry point of the validator: check [t] against [cases] and [default]
    over the whole key range [0, modulus - 1]. *)
Definition validate_switch (default: nat) (cases: table) (t: comptree) : bool :=
  validate default cases t 0 (modulus - 1).
Structural properties checked by validation
(** [wf_comptree t] holds when the following values of [t] lie in
    [0, modulus):
    - every key tested by a [CTifeq] or [CTiflt] node;
    - every jump-table offset [ofs] and size [sz].
    The action lists of jump tables are not constrained. *)
Inductive wf_comptree: comptree → Prop :=
| wf_action: ∀ act,
wf_comptree (CTaction act)
| wf_ifeq: ∀ key act cne,
0 ≤ key < modulus → wf_comptree cne → wf_comptree (CTifeq key act cne)
| wf_iflt: ∀ key clt cge,
0 ≤ key < modulus → wf_comptree clt → wf_comptree cge → wf_comptree (CTiflt key clt cge)
| wf_jumptable: ∀ ofs sz acts cother,
0 ≤ ofs < modulus → 0 ≤ sz < modulus →
wf_comptree cother →
wf_comptree (CTjumptable ofs sz acts cother).
(** A tree accepted by [validate] is well formed. *)
Lemma validate_wf:
  ∀ default t cases lo hi,
  validate default cases t lo hi = true →
  wf_comptree t.
Proof.
  induction t; simpl; intros; InvBooleans.
- constructor.
- (* The [None] branch of [split_eq] evaluates to [false]: discriminate it. *)
  destruct (split_eq key cases) as [[act'|] others]; try discriminate; InvBooleans.
  constructor; eauto.
- destruct (split_lt key cases) as [lc rc]; InvBooleans.
  constructor; eauto.
- destruct (split_between default ofs sz cases) as [ins out]; InvBooleans.
  constructor; eauto.
Qed.
Semantic correctness proof for validation.
(** Splitting on key [n]: the answer for [v = n] is the extracted action
    (or [default] if none); otherwise it is the answer on the remaining
    cases. *)
Lemma split_eq_prop:
  ∀ v default n cases optact cases',
  split_eq n cases = (optact, cases') →
  switch_target v default cases =
  (if zeq v n
   then match optact with Some act ⇒ act | None ⇒ default end
   else switch_target v default cases').
Proof.
  induction cases; simpl; intros until cases'.
- intros. inv H. simpl. destruct (zeq v n); auto.
- destruct a as [key act].
  destruct (split_eq n cases) as [same other] eqn:SEQ.
  rewrite (IHcases same other) by auto.
  destruct (zeq key n); intros EQ; inv EQ.
  + destruct (zeq v n); auto.
  + simpl. destruct (zeq v key).
    × subst v. rewrite zeq_false by auto. auto.
    × auto.
Qed.
(** Splitting on pivot [n]: keys below [n] are answered by [lcases], the
    others by [rcases]. *)
Lemma split_lt_prop:
  ∀ v default n cases lcases rcases,
  split_lt n cases = (lcases, rcases) →
  switch_target v default cases =
  (if zlt v n
   then switch_target v default lcases
   else switch_target v default rcases).
Proof.
  induction cases; intros until rcases; simpl.
- intros. inv H. simpl. destruct (zlt v n); auto.
- destruct a as [key act].
  destruct (split_lt n cases) as [lc rc] eqn:SEQ.
  rewrite (IHcases lc rc) by auto.
  destruct (zlt key n); intros EQ; inv EQ; simpl.
  + destruct (zeq v key). rewrite zlt_true by omega. auto. auto.
  + destruct (zeq v key). rewrite zlt_false by omega. auto. auto.
Qed.
(** Splitting on the jump-table window: keys inside the window are
    answered by the map [inside], the others by the table [outside]. *)
Lemma split_between_prop:
  ∀ v default ofs sz cases inside outside,
  split_between default ofs sz cases = (inside, outside) →
  switch_target v default cases =
  (if zlt ((v - ofs) mod modulus) sz
   then ZMap.get v inside
   else switch_target v default outside).
Proof.
  induction cases; intros until outside; simpl; intros SEQ.
- (* Empty table: the map is [ZMap.init default], which returns [default]
     everywhere. *)
  inv SEQ. rewrite ZMap.gi. simpl. destruct (zlt ((v - ofs) mod modulus) sz); auto.
- destruct a as [key act].
  destruct (split_between default ofs sz cases) as [ins outs].
  erewrite IHcases; eauto.
  destruct (zlt ((key - ofs) mod modulus) sz); inv SEQ.
  + rewrite ZMap.gsspec. unfold ZIndexed.eq.
    destruct (zeq v key).
    subst v. rewrite zlt_true by auto. auto.
    (* v <> key: both sides are syntactically equal. *)
    auto.
  + simpl. destruct (zeq v key).
    subst v. rewrite zlt_false by auto. auto.
    (* v <> key: both sides are syntactically equal. *)
    auto.
Qed.
(** A validated jump table starting at [base] holds, at position [v], the
    action that [cases] maps [base + v] to. *)
Lemma validate_jumptable_correct_rec:
  ∀ cases tbl base v,
  validate_jumptable cases tbl base = true →
  0 ≤ v < list_length_z tbl →
  list_nth_z tbl v = Some(ZMap.get (base + v) cases).
Proof.
  induction tbl; simpl; intros.
- unfold list_length_z in H0. simpl in H0. omegaContradiction.
- InvBooleans. rewrite list_length_z_cons in H0. apply beq_nat_true in H1.
  destruct (zeq v 0).
  + replace (base + v) with base by omega. congruence.
  + (* Shift one position: the tail is validated from [Z.succ base]. *)
    replace (base + v) with (Z.succ base + Z.pred v) by omega.
    apply IHtbl. auto. omega.
Qed.
(** Jump-table lookup at the wrapped offset [(v - ofs) mod modulus]
    returns the action for [v], provided the offset is inside the window
    and the window fits below [modulus]. *)
Lemma validate_jumptable_correct:
  ∀ cases tbl ofs v sz,
  validate_jumptable cases tbl ofs = true →
  (v - ofs) mod modulus < sz →
  0 ≤ sz → 0 ≤ ofs → ofs + sz ≤ modulus →
  0 ≤ v < modulus →
  sz ≤ list_length_z tbl →
  list_nth_z tbl ((v - ofs) mod modulus) = Some(ZMap.get v cases).
Proof.
  (* [cases], [tbl] and [ofs] must be in the context before they are
     named in the [rewrite] below. *)
  intros.
  rewrite (validate_jumptable_correct_rec cases tbl ofs); auto.
- f_equal. f_equal. rewrite Zmod_small. omega.
  destruct (zle ofs v). omega.
  (* v < ofs: the offset wraps to v - ofs + modulus >= sz, contradiction. *)
  assert (M: ((v - ofs) + 1 × modulus) mod modulus = (v - ofs) + modulus).
  { rewrite Zmod_small. omega. omega. }
  rewrite Z_mod_plus in M by auto. rewrite M in H0. omega.
- generalize (Z_mod_lt (v - ofs) modulus modulus_pos). omega.
Qed.
(** Main soundness lemma for [validate], by induction on the tree, one
    bullet per [comptree] constructor. *)
Lemma validate_correct_rec:
  ∀ default v,
  0 ≤ v < modulus →
  ∀ t cases lo hi,
  validate default cases t lo hi = true →
  lo ≤ v ≤ hi →
  comptree_match v t = Some (switch_target v default cases).
Proof.
  intros default v VRANGE. induction t; simpl; intros until hi.
- (* CTaction *)
  destruct cases as [ | [key1 act1] cases1]; intros.
  + apply beq_nat_true in H. subst act. reflexivity.
  + InvBooleans. apply beq_nat_true in H2. subst. simpl.
    destruct (zeq v hi). auto. omegaContradiction.
- (* CTifeq *)
  destruct (split_eq key cases) as [optact others] eqn:EQ. intros.
  destruct optact as [act1|]; InvBooleans; try discriminate.
  apply beq_nat_true in H.
  rewrite (split_eq_prop v default _ _ _ _ EQ).
  destruct (zeq v key).
  + congruence.
  + eapply IHt; eauto.
    unfold refine_low_bound, refine_high_bound. split.
    destruct (zeq key lo); omega.
    destruct (zeq key hi); omega.
- (* CTiflt *)
  destruct (split_lt key cases) as [lcases rcases] eqn:EQ; intros; InvBooleans.
  rewrite (split_lt_prop v default _ _ _ _ EQ). destruct (zlt v key).
  eapply IHt1. eauto. omega.
  eapply IHt2. eauto. omega.
- (* CTjumptable *)
  destruct (split_between default ofs sz cases) as [ins outs] eqn:EQ; intros; InvBooleans.
  rewrite (split_between_prop v _ _ _ _ _ _ EQ).
  assert (0 ≤ (v - ofs) mod modulus < modulus) by (apply Z_mod_lt; omega).
  destruct (zlt ((v - ofs) mod modulus) sz).
  (* The [mod Int.modulus] in [comptree_match] is the identity here,
     because the offset is below [sz] and [validate] checked
     [sz < Int.modulus]. *)
  rewrite Zmod_small by omega. eapply validate_jumptable_correct; eauto.
  eapply IHt; eauto.
Qed.
(** [table_tree_agree default cases t]: on every key in [0, modulus), the
    tree [t] selects exactly the action chosen by the table [cases] with
    default [default]. *)
Definition table_tree_agree
    (default: nat) (cases: table) (t: comptree) : Prop :=
  ∀ x: Z,
  0 ≤ x < modulus →
  comptree_match x t = Some (switch_target x default cases).
(** Correctness of the validator: an accepted tree is well formed and
    agrees with the table on the whole key range. *)
Theorem validate_switch_correct:
  ∀ default t cases,
  validate_switch default cases t = true →
  wf_comptree t ∧ table_tree_agree default cases t.
Proof.
  unfold validate_switch, table_tree_agree; split.
  eapply validate_wf; eauto.
  (* The remaining range side condition is 0 <= v <= modulus - 1. *)
  intros; eapply validate_correct_rec; eauto. omega.
Qed.