Constraint Solving

TODO

Explain all this. Rationale: We use z3 (and its Python bindings) for constraint solving whenever it gives a definite answer. However, z3 struggles with nonlinear integer arithmetic and with constraints mixing sequences and integers, which is why we implement a bridge to the KeY prover: a translation from z3py expressions to KeY problem-file definitions, and an interface class to KeY.

import collections.abc
import random
import typing
from typing import Union
import z3
def check_z3(formula: Union[z3.BoolRef, typing.Iterable[z3.BoolRef]], timeout_ms=600, tries=50) -> z3.CheckSatResult:
    """Check satisfiability of *formula* with z3, retrying on ``unknown``.

    Each attempt uses a fresh solver whose timeout is drawn at random from
    ``[min(150, timeout_ms), timeout_ms]`` milliseconds, so that repeated
    attempts run with different time budgets.

    Args:
        formula: a z3 formula, or an iterable of formulas treated as a conjunction.
        timeout_ms: upper bound of the per-attempt timeout in milliseconds.
        tries: maximum number of attempts; zero or a negative value means no attempt.

    Returns:
        ``z3.sat`` or ``z3.unsat`` from the first decisive attempt, otherwise
        ``z3.unknown``.
    """
    if not tries or tries < 0:
        return z3.unknown

    if isinstance(formula, collections.abc.Iterable):
        formula = z3_and(*formula)

    # random.randint requires lower <= upper; clamp the lower bound so that
    # timeouts below 150 ms do not raise ValueError.
    lower_ms = min(150, timeout_ms)
    for _ in range(tries):
        solver = z3.Solver()
        solver.set("timeout", random.randint(lower_ms, timeout_ms))
        solver.add(formula)
        result = solver.check()
        if result != z3.unknown:
            return result

    return z3.unknown


def z3_and(*formulas: z3.BoolRef) -> z3.BoolRef:
    """Conjoin *formulas* without building degenerate ``And`` terms.

    - no formulas  -> the literal ``True``
    - one formula  -> that formula itself, unwrapped
    - otherwise    -> ``z3.And`` over all of them
    """
    count = len(formulas)
    if count > 1:
        return z3.And(*formulas)
    if count == 1:
        (only,) = formulas
        return only
    return z3.BoolVal(True)
# Example: x > y > 0 has models (e.g. x = 2, y = 1), so check_z3 reports sat.
x, y = z3.Ints("x y")
check_z3(z3.And(x > y, y > z3.IntVal(0)))
sat
# Example: x > y > 0 contradicts x < 0, so check_z3 reports unsat.
x, y = z3.Ints("x y")
check_z3(z3.And(x > y, y > z3.IntVal(0), x < z3.IntVal(0)))
unsat
from typing import Optional
def z3_sequence(name: str, ctx=None):
    """Create an uninterpreted z3 constant named *name* of sort ``Seq(Int)``.

    Args:
        name: the constant's name.
        ctx: optional z3 context, resolved via ``z3.get_ctx``.

    Returns:
        A ``z3.SeqRef`` constant whose elements are integers.
    """
    ctx = z3.get_ctx(ctx)
    # Public z3py constructors instead of raw Z3_mk_* C-API calls; z3.Const
    # returns a SeqRef for sequence sorts.
    return z3.Const(name, z3.SeqSort(z3.IntSort(ctx)))
from typing import Union, List
def z3_sequences(names: Union[str, List[str]], ctx=None):
    """Create one integer-sequence constant per name (analogous to ``z3.Ints``).

    Args:
        names: a whitespace-separated string of names, or a list of names.
        ctx: optional z3 context, resolved via ``z3.get_ctx``.

    Returns:
        A list of sequence constants, in the order of *names*.
    """
    ctx = z3.get_ctx(ctx)
    if isinstance(names, str):
        # split() rather than split(" "): repeated, leading or trailing
        # whitespace must not produce constants with an empty name.
        names = names.split()
    return [z3_sequence(name, ctx) for name in names]
# Example property: if n occurs at an in-bounds index of seq, then n also
# occurs in the concatenation seq ++ other_seq.
n, idx = z3.Ints("n idx")
seq, other_seq = z3_sequences("seq other_seq")

seq_expansion_preserves_elements = z3.ForAll(
    [seq, n, other_seq],
    z3.Implies(
        # Antecedent: n is stored at some index 0 <= idx < Length(seq).
        z3.Exists(
            [idx],
            z3.And(
                seq[idx] == n,
                idx >= z3.IntVal(0),
                idx < z3.Length(seq),
            )
        ),
        # Conclusion: n is stored at some index of the concatenation.
        # NOTE(review): unlike the antecedent, idx is not bounded here, so the
        # conclusion is weaker than "n is an element of the concatenation" if
        # out-of-bounds Nth values are unconstrained -- confirm this is intended.
        z3.Exists([idx], (z3.Concat(seq, other_seq))[idx] == n)
    )
)

seq_expansion_preserves_elements
∀seq, n, other_seq : (∃idx : Nth(seq, idx) = n ∧ 0 ≤ idx ∧ idx < Length(seq)) ⇒ (∃idx : Nth(Concat(seq, other_seq), idx) = n)
# In the recorded run, z3 alone answers unknown for the quantified property.
check_z3(seq_expansion_preserves_elements)
unknown
# The negated property is reported unsat, i.e. the property itself is valid.
check_z3(z3.Not(seq_expansion_preserves_elements))
unsat
package de.uka.ilkd.key.core;

import de.uka.ilkd.key.proof.Proof;
import de.uka.ilkd.key.proof.io.ProblemLoader;
import de.uka.ilkd.key.prover.ProverTaskListener;
import de.uka.ilkd.key.prover.TaskFinishedInfo;
import de.uka.ilkd.key.prover.TaskStartedInfo;
import de.uka.ilkd.key.strategy.StrategyProperties;
import de.uka.ilkd.key.ui.ConsoleUserInterfaceControl;
import de.uka.ilkd.key.ui.Verbosity;
import py4j.GatewayServer;

import java.io.File;

/**
 * py4j entry point that exposes KeY to Python: {@link #proveProblem(String)}
 * loads a KeY problem file and reports whether all resulting proofs were closed.
 */
public class KeYPythonGateway {
    /**
     * Loads the KeY problem file {@code fileName} and lets KeY prove it.
     *
     * @param fileName path of a {@code .key} problem file
     * @return {@code "Success"} if the UI reports all proofs successful,
     *         {@code "<n> open goals"} otherwise, or {@code "Exception: <message>"}
     *         if loading or proving threw
     */
    public String proveProblem(String fileName) {
        try {
            final ConsoleUserInterfaceControl ui = new ConsoleUserInterfaceControl(Verbosity.HIGH, false);
            final SuccessListener successListener = new SuccessListener();
            ui.addProverTaskListener(successListener);
            ui.loadProblem(new File(fileName));

            return ui.allProofsSuccessful ? "Success" : successListener.numOpenGoals + " open goals";
        } catch (Exception e) {
            return "Exception: " + e.getMessage();
        }
    }

    /** Starts a py4j gateway server whose entry point is a {@link KeYPythonGateway}. */
    public static void main(String[] args) {
        final KeYPythonGateway gateway = new KeYPythonGateway();
        final GatewayServer server = new GatewayServer(gateway);
        server.start();
        System.out.println("KeY-Python gateway is running...");
    }

    /**
     * Configures the strategy of a freshly loaded proof and records the number
     * of open goals once a (non-loading) task has finished.
     */
    private static class SuccessListener implements ProverTaskListener {
        // Open goals of the most recently finished non-loading task.
        private int numOpenGoals = 0;

        @Override
        public void taskStarted(TaskStartedInfo info) {
        }

        @Override
        public void taskProgress(int position) {
        }

        @Override
        public void taskFinished(TaskFinishedInfo info) {
            final Proof proof = info.getProof();
            if (info.getSource() instanceof ProblemLoader) {
                // Rethrow runtime errors reported by the loader.
                if (info.getResult() instanceof RuntimeException) {
                    throw (RuntimeException) info.getResult();
                }
                if (proof == null) {
                    return;
                }

                // Select KeY's NON_LIN_ARITH_DEF_OPS nonlinear-arithmetic option.
                final StrategyProperties sp = proof.getSettings().getStrategySettings().getActiveStrategyProperties();
                sp.setProperty(StrategyProperties.NON_LIN_ARITH_OPTIONS_KEY, StrategyProperties.NON_LIN_ARITH_DEF_OPS);
                // getActiveStrategyProperties() is not guaranteed to return the live
                // object, so write the properties back and rebuild the active strategy
                // (as in KeY's own API example) to make sure the option takes effect.
                proof.getSettings().getStrategySettings().setActiveStrategyProperties(sp);
                proof.setActiveStrategy(
                        proof.getServices().getProfile().getDefaultStrategyFactory().create(proof, sp));
            } else if (proof != null) {
                numOpenGoals = proof.openGoals().size();
            }
        }
    }
}
from py4j.java_gateway import JavaObject, JavaGateway
import tempfile
#% EXPORT
# Lazily created py4j entry point (presumably the KeYPythonGateway server); set by is_unsat_key.
key_prover: Optional[JavaObject] = None


def is_unsat_key(formula: Union[z3.BoolRef, typing.Iterable[z3.BoolRef]]) -> bool:
    """Try to prove *formula* unsatisfiable with KeY.

    The formula (or the conjunction of an iterable of formulas) is negated,
    translated with ``z3_to_key`` into a KeY problem file and passed to the
    gateway's ``proveProblem``. Uninterpreted constants are declared in a
    functions section: arithmetic ones as ``int``, all others as ``Seq``.

    Args:
        formula: a z3 formula, or an iterable of formulas treated as a conjunction.

    Returns:
        True iff KeY answers ``"Success"``, i.e. proved the negation valid.
        False means "not proven" (open goals or an exception on the Java side),
        not "satisfiable".
    """
    import os

    if isinstance(formula, collections.abc.Iterable):
        formula = z3_and(*formula)

    global key_prover
    if key_prover is None:
        # JavaGateway() without arguments uses py4j's default connection settings.
        key_prover = JavaGateway().entry_point

    constants = [
        sub for sub in visit_z3_expr(formula)
        if z3.is_const(sub) and sub.decl().kind() == z3.Z3_OP_UNINTERPRETED
    ]

    parts = []
    if constants:
        # NOTE(review): every non-arithmetic constant is declared as Seq, so an
        # uninterpreted Bool constant would be mis-declared and would likely make
        # KeY reject the file (yielding False).
        declarations = [
            "    " + ("int" if isinstance(constant, z3.ArithRef) else "Seq") + f" {str(constant).replace('!', '_')};"
            for constant in constants
        ]
        parts.append("\\functions {\n" + "\n".join(declarations) + "\n}\n\n")
    parts.append("\\problem{\n    " + z3_to_key(z3.Not(formula)) + "\n}")
    key_file_content = "".join(parts)

    # Write through the temporary file's own handle and close it before KeY
    # reads the path; delete the file once KeY is done with it.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".key", delete=False) as tmp_file:
        tmp_file.write(key_file_content)
        tmp_file_name = tmp_file.name

    try:
        key_result: str = key_prover.proveProblem(tmp_file_name)
    finally:
        try:
            os.remove(tmp_file_name)
        except OSError:
            # Best-effort cleanup; a leftover temp file must not hide the result.
            pass

    return key_result == "Success"
from typing import Dict, Generator
def visit_z3_expr(e: Union[z3.ExprRef, z3.QuantifierRef],
                  seen: Optional[Dict[z3.ExprRef, bool]] = None) -> \
        Generator[z3.ExprRef, None, None]:
    """Yield *e* and its distinct sub-expressions in depth-first pre-order.

    Children of applications are visited left to right; the only child of a
    quantifier is its body. Every expression is yielded at most once.

    An explicit stack is used instead of recursive generators. Deeply nested
    formulas therefore cannot exceed Python's recursion limit, and a yielded
    node is not re-yielded through every enclosing generator frame.

    Args:
        e: the root expression.
        seen: optional dict of already visited expressions. It is updated in
            place, and expressions already in it (including *e*) are skipped.

    Yields:
        Each sub-expression the first time it is reached.
    """
    if seen is None:
        seen = {}

    pending = [e]
    while pending:
        node = pending.pop()
        if node in seen:
            continue
        seen[node] = True
        yield node

        if z3.is_app(node):
            # Push in reverse so that the leftmost child is visited first.
            pending.extend(reversed(node.children()))
        elif z3.is_quantifier(node):
            pending.append(node.body())
from typing import Tuple
# KeY spelling of z3's binary integer comparisons.
_KEY_COMPARISON_OPS = {
    z3.Z3_OP_LE: "<=",
    z3.Z3_OP_LT: "<",
    z3.Z3_OP_GE: ">=",
    z3.Z3_OP_GT: ">",
}

# KeY spelling of z3's n-ary, left-associative arithmetic operators.
_KEY_ARITH_OPS = {
    z3.Z3_OP_ADD: "+",
    z3.Z3_OP_SUB: "-",
    z3.Z3_OP_MUL: "*",
}


def z3_to_key(expr: z3.ExprRef, variables: Tuple[str, ...] = ()) -> str:
    """Translate a z3 expression into KeY concrete syntax.

    Supported:
    - Boolean connectives, ``=``/``<->`` and ``Distinct``.
    - Integer arithmetic and comparisons.
    - Integer sequences: length, nth, concat, extract, unit.
    - If-then-else.
    - Quantifiers over ``Int``/``Seq`` variables.
    - The ``ToInt(ToReal(x) / ToReal(y))`` floor-division encoding.

    Arithmetic terms are always parenthesised so that nesting (e.g.
    ``(x + y) * z``) survives the translation.

    Args:
        expr: the expression to translate.
        variables: names of the quantifier-bound variables in scope, outermost
            first. z3 refers to them by de Bruijn index; index 0 is the most
            recently bound variable.

    Returns:
        The KeY source text for *expr*.

    Raises:
        NotImplementedError: for constants or operators without a translation.
    """
    def sub(child: z3.ExprRef) -> str:
        # Translate a direct sub-term within the same quantifier scope.
        return z3_to_key(child, variables)

    if z3.is_var(expr):
        return variables[len(variables) - z3.get_var_index(expr) - 1]

    if z3.is_const(expr):
        if expr.decl().kind() == z3.Z3_OP_UNINTERPRETED or isinstance(expr, z3.ArithRef):
            return str(expr).replace("!", "_")
        if isinstance(expr, z3.BoolRef):
            return "true" if z3.is_true(expr) else "false"
        raise NotImplementedError(f"Translation for constant {expr} not implemented.")

    if z3.is_quantifier(expr):
        binder = "\\forall" if expr.is_forall() else "\\exists"
        prefix = ""
        bound = variables
        for i in range(expr.num_vars()):
            # Only int and Seq variables are supported; every non-arithmetic sort becomes Seq.
            sort = "int" if isinstance(expr.var_sort(i), z3.ArithSortRef) else "Seq"
            name = expr.var_name(i)
            prefix += f"({binder} {sort} {name}; "
            bound += (name,)
        return prefix + z3_to_key(expr.body(), bound) + ")" * expr.num_vars()

    if not z3.is_app(expr):
        raise NotImplementedError(f"Translation for expression {expr} not implemented.")

    kind = expr.decl().kind()
    children = expr.children()

    if kind == z3.Z3_OP_AND:
        return "(" + " & ".join(sub(c) for c in children) + ")"
    if kind == z3.Z3_OP_OR:
        return "(" + " | ".join(sub(c) for c in children) + ")"
    if kind == z3.Z3_OP_IMPLIES:
        return f"({sub(children[0])} -> {sub(children[1])})"
    if kind == z3.Z3_OP_NOT:
        return "!(" + sub(children[0]) + ")"
    if kind == z3.Z3_OP_IDIV:
        # z3/SMT-LIB integer div and mod are Euclidean. KeY's div/mod are too
        # (KeY book, integer semantics -- verify for the KeY version in use),
        # whereas jdiv/jmod are Java's truncating variants and disagree for
        # negative operands.
        return f"div({sub(children[0])}, {sub(children[1])})"
    if kind == z3.Z3_OP_MOD:
        return f"mod({sub(children[0])}, {sub(children[1])})"
    if kind == z3.Z3_OP_DIV:
        # NOTE(review): real division has no exact KeY counterpart; it is kept
        # as truncating jdiv. Its intended use is inside the Z3_OP_TO_INT
        # pattern handled below.
        return f"jdiv({sub(children[0])}, {sub(children[1])})"
    if kind in _KEY_COMPARISON_OPS:
        return f"{sub(children[0])} {_KEY_COMPARISON_OPS[kind]} {sub(children[1])}"
    if kind in _KEY_ARITH_OPS:
        # z3 arithmetic applications may be n-ary; translate every operand.
        return "(" + f" {_KEY_ARITH_OPS[kind]} ".join(sub(c) for c in children) + ")"
    if kind == z3.Z3_OP_UMINUS:
        return f"(-({sub(children[0])}))"
    if kind == z3.Z3_OP_EQ:
        operator = "<->" if isinstance(children[0], z3.BoolRef) else "="
        return f"({sub(children[0])} {operator} {sub(children[1])})"
    if kind == z3.Z3_OP_DISTINCT:
        # z3py builds `a != b` as Distinct(a, b): all operands pairwise different.
        operator = "<->" if isinstance(children[0], z3.BoolRef) else "="
        terms = [sub(c) for c in children]
        conjuncts = [f"!({terms[i]} {operator} {terms[j]})"
                     for i in range(len(terms)) for j in range(i + 1, len(terms))]
        if not conjuncts:
            return "true"
        return conjuncts[0] if len(conjuncts) == 1 else "(" + " & ".join(conjuncts) + ")"
    if kind == z3.Z3_OP_SEQ_LENGTH:
        return f"seqLen({sub(children[0])})"
    if kind == z3.Z3_OP_ITE:
        return f"\\if ({sub(children[0])}) " \
               f"\\then ({sub(children[1])}) " \
               f"\\else ({sub(children[2])})"
    if str(expr.decl()) == "seq.nth_i" or str(expr.decl()) == "seq.nth_u" or kind == z3.Z3_OP_SEQ_NTH:
        return f"int::seqGet({sub(children[0])}, {sub(children[1])})"
    if kind == z3.Z3_OP_SEQ_CONCAT:
        # z3 Concat may be n-ary; fold it into nested binary seqConcat terms.
        result = sub(children[0])
        for child in children[1:]:
            result = f"seqConcat({result}, {sub(child)})"
        return result
    if kind == z3.Z3_OP_SEQ_EXTRACT:
        # z3 Extract(s, offset, length) vs. KeY seqSub(s, from, to).
        return f"seqSub({sub(children[0])}, " \
               f"{sub(children[1])}, " \
               f"({sub(children[1])} + {sub(children[2])}))"
    if kind == z3.Z3_OP_SEQ_UNIT:
        return f"seqSingleton({sub(children[0])})"
    if kind == z3.Z3_OP_TO_REAL:
        # Only meaningful inside the Z3_OP_TO_INT pattern below.
        return sub(children[0])
    if kind == z3.Z3_OP_TO_INT:
        # Hack: The SE engine transforms "x // y" into "ToInt(ToReal(x) / ToReal(y))" to model
        #       Python's floor division semantics. For KeY, we have to handle this differently:
        #       It gets "\if (x / y >= 0 | x % y == 0) \then (x / y) \else (x / y - 1)".
        quotient = children[0]
        operands = quotient.children() if z3.is_app_of(quotient, z3.Z3_OP_DIV) else []
        if len(operands) != 2 or not all(z3.is_app_of(o, z3.Z3_OP_TO_REAL) for o in operands):
            raise NotImplementedError(f"Translation for application {expr} not implemented.")

        x = sub(operands[0].children()[0])
        y = sub(operands[1].children()[0])

        x_div_y = f"jdiv({x}, {y})"
        x_mod_y = f"jmod({x}, {y})"

        # Truncating division equals floor division when the signs agree or
        # the division is exact; otherwise floor is one less.
        return f"(\\if ({x} >= 0 & {y} >= 0 | {x} < 0 & {y} < 0 | " \
               f"{x_mod_y} = 0) \\then ({x_div_y}) \\else ({x_div_y} - 1))"

    raise NotImplementedError(f"Translation for application {expr} not implemented.")
# Show the KeY syntax produced for the example property.
z3_to_key(seq_expansion_preserves_elements)
'(\\forall Seq seq; (\\forall int n; (\\forall Seq other_seq; ((\\exists int idx; ((int::seqGet(seq, idx) = n) & 0 <= idx & idx < seqLen(seq))) -> (\\exists int idx; (int::seqGet(seqConcat(seq, other_seq), idx) = n))))))'
# KeY proves the negated property unsatisfiable, i.e. the property is valid.
is_unsat_key(z3.Not(seq_expansion_preserves_elements))
True
# The property itself is not unsatisfiable, so KeY cannot prove its negation.
is_unsat_key(seq_expansion_preserves_elements)
False
def is_unsat(formula: z3.BoolRef, timeout_ms=500) -> bool:
    """Return True iff *formula* can be shown unsatisfiable.

    Literal ``False``/``True`` are answered immediately. Otherwise z3 is asked
    first (via ``check_z3``). The KeY bridge (``is_unsat_key``) is consulted
    only when z3 answers ``unknown``. After a KeY fallback, False means "not
    shown unsatisfiable" rather than a proof of satisfiability.

    Args:
        formula: the z3 formula to check.
        timeout_ms: upper bound of the per-attempt z3 timeout in milliseconds.
    """
    if z3.is_false(formula):
        return True
    if z3.is_true(formula):
        return False

    verdict = check_z3(formula, timeout_ms=timeout_ms)
    if verdict == z3.unknown:
        # z3 gave up: let KeY try to prove the negation valid.
        return is_unsat_key(formula)
    return verdict == z3.unsat
# Combined entry point: z3 first, KeY only if z3 answers unknown.
is_unsat(z3.Not(seq_expansion_preserves_elements))
True
# Same linear constraints as the earlier unsat example (x, y from that cell).
is_unsat(z3.And(x > y, y > z3.IntVal(0), x < z3.IntVal(0)))
True