// Deepin Developer — feat: Init commit
// Commit 9df0a864, created on 2022-12-20 (commit history)
/*++
Copyright (c) 2011 Microsoft Corporation

Module Name:

    symmetry_reduce.cpp

Abstract:

    Add symmetry breaking predicates to goals.

Author:

    Nikolaj (nbjorner) 2011-05-31

Notes:

    This is a straight-forward and literal
    adaption of the algorithms proposed for veriT.

--*/
#include "tactic/tactical.h"
#include "ast/for_each_expr.h"
#include "util/map.h"
#include "ast/rewriter/expr_replacer.h"
#include "ast/rewriter/rewriter_def.h"
#include "ast/ast_pp.h"

class symmetry_reduce_tactic : public tactic {
    // All of the actual work is done by the nested class `imp`, defined later
    // in this file; the tactic object only owns a pointer to it.
    class imp;
    imp *  m_imp;   // allocated by the constructor, released by the destructor

public:
    symmetry_reduce_tactic(ast_manager & mgr);
    ~symmetry_reduce_tactic() override;

    // Create a new, independent instance of this tactic for manager `mgr`.
    tactic * translate(ast_manager & mgr) override {
        return alloc(symmetry_reduce_tactic, mgr);
    }

    // Add symmetry-breaking predicates to `in`; the result is pushed to `out`.
    void operator()(goal_ref const & in,
                    goal_ref_buffer & out) override;
    void cleanup() override;
};

class ac_rewriter {
    ast_manager& m_manager;

    ast_manager& m() const { return m_manager; }

public:
    ac_rewriter(ast_manager& m): m_manager(m) {}

    // Bring the arguments of an application into a canonical order.
    //  - associative-commutative symbols and `distinct`: sort all arguments
    //    with ast_lt_proc;
    //  - other commutative binary symbols: order the two arguments by id.
    // Returns BR_DONE (and sets `result`) only when the order actually
    // changed; otherwise BR_FAILED and `result` is left untouched.
    br_status mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) {
        bool const is_ac = f->is_associative() && f->is_commutative();
        if (is_ac || m_manager.is_distinct(f)) {
            ptr_buffer<expr> sorted;
            sorted.append(num_args, args);
            std::sort(sorted.begin(), sorted.end(), ast_lt_proc());
            // Already canonical: report that no rewrite took place.
            if (std::equal(args, args + num_args, sorted.begin()))
                return BR_FAILED;
            result = m().mk_app(f, num_args, sorted.begin());
            return BR_DONE;
        }
        bool const needs_swap = f->is_commutative() && num_args == 2 &&
                                args[0]->get_id() > args[1]->get_id();
        if (needs_swap) {
            expr* swapped[2] = { args[1], args[0] };
            result = m().mk_app(f, num_args, swapped);
            return BR_DONE;
        }
        return BR_FAILED;
    }

    // Like mk_app_core, but always produces a term: falls back to building
    // the application with the arguments in their given order.
    void mk_app(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) {
        if (mk_app_core(f, num_args, args, result) != BR_FAILED)
            return;
        result = m().mk_app(f, num_args, args);
    }
};


// Rewriter configuration that delegates every application to ac_rewriter,
// so the generic rewriter canonicalizes argument order bottom-up.
struct ac_rewriter_cfg : public default_rewriter_cfg {
    ac_rewriter m_r;

    ac_rewriter_cfg(ast_manager & mgr): m_r(mgr) {}

    // Configuration hooks queried by rewriter_tpl; presumably these keep
    // quantifier patterns untouched and disable flattening of associative
    // operators — confirm against rewriter_tpl.
    bool rewrite_patterns() const { return false; }
    bool flat_assoc(func_decl * f) const { return false; }

    br_status reduce_app(func_decl * f, unsigned num, expr * const * args, expr_ref & result, proof_ref & result_pr) {
        // No proofs are produced by this rewriter.
        result_pr = nullptr;
        return m_r.mk_app_core(f, num, args, result);
    }
};

// Full rewriter that applies ac_rewriter_cfg throughout an expression,
// with proof generation disabled.
class ac_rewriter_star : public rewriter_tpl<ac_rewriter_cfg> {
    typedef rewriter_tpl<ac_rewriter_cfg> base_rewriter;
    ac_rewriter_cfg m_cfg;
public:
    // The base is handed a reference to m_cfg before m_cfg itself is
    // constructed (bases are initialized first). This is the usual idiom
    // for rewriter_tpl; it relies on the base constructor only storing the
    // reference, not using it.
    ac_rewriter_star(ast_manager & mgr):
        base_rewriter(mgr, false, m_cfg),
        m_cfg(mgr) {}
};

// Instantiate the rewriter template for this configuration in this file.
template class rewriter_tpl<ac_rewriter_cfg>;

class symmetry_reduce_tactic::imp {
    typedef ptr_vector<app>     permutation;
    typedef vector<permutation> permutations;
    typedef ptr_vector<app>     term_set;
    typedef obj_map<app, unsigned> app_map;
    typedef u_map<ptr_vector<app> > inv_app_map;
    ast_manager&                m_manager;
    ac_rewriter_star            m_rewriter;   // canonicalizes argument order (see normalize)
    scoped_ptr<expr_replacer>   m_replace;    // applies candidate permutations (see check_substitution)
    
    ast_manager& m() const { return m_manager; }
public:
    imp(ast_manager& m) : m_manager(m), m_rewriter(m) {
        m_replace = mk_default_expr_replacer(m, false);
    }

    ~imp() {}

    // Entry point.
    // 1. Look for groups of uninterpreted constants that are interchangeable in g.
    // 2. For each group under which the conjunction of g is invariant, assert
    //    membership predicates (t = c1 or ... or t = ck) into g.
    // Returns immediately if g is already inconsistent.
    void operator()(goal & g) {
        if (g.inconsistent())
            return;
        tactic_report report("symmetry-reduce", g);
        vector<ptr_vector<app> > P;    
        expr_ref fml(m());
        to_formula(g, fml);
        app_map occs;
        compute_occurrences(fml, occs);
        find_candidate_permutations(fml, occs, P);
        if (P.empty()) {
            return;
        }
        term_set T, cts;        
        unsigned num_sym_break_preds = 0;
        for (unsigned i = 0; i < P.size(); ++i) {
            term_set& consts = P[i];
            if (invariant_by_permutations(fml, consts)) {
                cts.reset();
                select_terms(fml, consts, T);
                while (!T.empty() && cts.size() < consts.size()) {
                    app* t = select_most_promising_term(fml, T, cts, consts, occs);
                    T.erase(t);                    
                    compute_used_in(t, cts, consts);
                    app* c = select_const(consts, cts);
                    if (!c) break;
                    cts.push_back(c);
                    // Hold the fresh predicate in a reference-counted handle so
                    // it stays alive for the TRACE and mk_and below, regardless
                    // of whether the goal keeps a reference to it.
                    expr_ref mem(mk_member(t, cts), m());
                    g.assert_expr(mem.get()); 
                    num_sym_break_preds++;
                    TRACE("symmetry_reduce", tout << "member predicate: " << mk_pp(mem.get(), m()) << "\n";);
                    fml = m().mk_and(fml.get(), mem.get());
                    normalize(fml);
                }
            }
        }
        report_tactic_progress(":num-symmetry-breaking ", num_sym_break_preds);
    }

private:
    // Conjoin all formulas of g into fml and bring it into normal form.
    void to_formula(goal const & g, expr_ref& fml) {
        ptr_vector<expr> conjs;
        for (unsigned i = 0; i < g.size(); ++i) {
            conjs.push_back(g.form(i));
        }
        fml = m().mk_and(conjs.size(), conjs.data());
        normalize(fml);
    }

    // Find candidate permutations.
    // Uninterpreted applications are colored by their sort. The coloring is
    // then refined by occurrence count and by maximal depth, and the
    // applications are grouped by color. A group becomes a candidate when:
    // - it has at least two members;
    // - its first member occurs at least twice;
    // - every member is a constant (no arguments).
    void find_candidate_permutations(expr* fml, app_map const& occs, permutations& P) {
        app_map coloring;
        app_map depth;
        inv_app_map inv_color;
        unsigned num_occs = 0;
        compute_sort_colors(fml, coloring);
        compute_max_depth(fml, depth);
        merge_colors(occs, coloring);        
        merge_colors(depth, coloring);     
        // compute_siblings(fml, coloring);
        compute_inv_app(coloring, inv_color);
        
        inv_app_map::iterator it = inv_color.begin(), end = inv_color.end();
        for (; it != end; ++it) {
            if (it->m_value.size() < 2) {
                continue;
            }
            VERIFY(occs.find(it->m_value[0], num_occs));
            if (num_occs < 2) {
                continue;
            }
            bool is_const = true;
            for (unsigned j = 0; is_const && j < it->m_value.size(); ++j) {
                is_const = it->m_value[j]->get_num_args() == 0;
            }
            if (!is_const) {
                continue;
            }
            P.push_back(it->m_value);
            TRACE("symmetry_reduce",
                for (unsigned i = 0; i < it->m_value.size(); ++i) {
                    tout << mk_pp(it->m_value[i], m()) << " ";
                }
            tout << "\n";);
        }
    }

    //
    // Refine a coloring by taking the most specific generalization:
    // a |-> c1, b |-> c2 <c1,c2> |-> c
    // 
    struct u_pair {
        unsigned m_first;
        unsigned m_second;
        u_pair(unsigned f, unsigned s) : m_first(f), m_second(s) {}
        u_pair(): m_first(0), m_second(0) {}

        struct hash {
            unsigned operator()(u_pair const& p) const {
                return mk_mix(p.m_first, p.m_second, 23);
            }
        };
        struct eq {
            bool operator()(u_pair const& p, u_pair const& q) const {
                return p.m_first == q.m_first && p.m_second == q.m_second;
            }
        };
    };
    typedef map<u_pair, unsigned, u_pair::hash, u_pair::eq> pair_map;

    // Replace each color in colors2 by a fresh, densely numbered color for the
    // pair (old color in colors2, color in colors1).
    // Every key of colors2 must also be present in colors1 (checked by VERIFY).
    // Returns true iff the number of resulting colors exceeds old_max + 1,
    // where old_max is the largest color found in colors2 before refinement.
    bool merge_colors(app_map const& colors1, app_map& colors2) {
        pair_map recolor;
        unsigned num_colors = 0, v1 = 0, v2 = 0, w = 0, old_max = 0;
        for (auto & kv : colors2) {
            app* a = kv.m_key;
            v1 = kv.m_value;
            VERIFY(colors1.find(a, v2));
            if (recolor.find(u_pair(v1, v2), w)) {
                kv.m_value = w;
            }
            else {
                kv.m_value = num_colors;
                recolor.insert(u_pair(v1, v2), num_colors++);
            }
            if (v1 > old_max) old_max = v1;
        }
        return num_colors > old_max + 1;
    }

    // for_each_expr visitor: assigns each application a dense id for its sort.
    class sort_colors {
        app_map& m_app2sortid;
        obj_map<sort,unsigned>  m_sort2id;
        unsigned m_max_id;

    public:
        sort_colors(app_map& app2sort) : m_app2sortid(app2sort), m_max_id(0) {}

        void operator()(app* n) {
            sort* s = n->get_sort();
            unsigned id;
            if (!m_sort2id.find(s, id)) {
                id = m_max_id++;
                m_sort2id.insert(s, id);
            }
            m_app2sortid.insert(n, id);
        }
        void operator()(quantifier * n) {
        }
        void operator()(var * n) {}
    };

    // Initial coloring: every application in fml is colored by its sort.
    void compute_sort_colors(expr* fml, app_map& app2sortId) {
        app2sortId.reset();
        sort_colors sc(app2sortId);
        for_each_expr(sc, fml);
    }

    // Invert a coloring, keeping only uninterpreted applications:
    // color -> list of applications with that color.
    void compute_inv_app(app_map const& map, inv_app_map& inv_map) {
        for (auto & kv : map) {
            app* t = kv.m_key;
            unsigned n = kv.m_value;
            if (is_uninterpreted(t)) {
                inv_map.insert_if_not_there(n, ptr_vector<app>()).push_back(t);
            }
        }
    }
    // An application is uninterpreted when its symbol belongs to no theory family.
    bool is_uninterpreted(app* t) const {
        return t->get_family_id() == null_family_id;
    }

    // Compute the maximal depth at which each application occurs below e
    // (e itself has depth 0; a quantifier body is one level below the quantifier).
    void compute_max_depth(expr* e, app_map& depth) {
        ptr_vector<expr> todo;
        unsigned_vector  depths;
        unsigned d, d1;
        todo.push_back(e);
        depths.push_back(0);
        while (!todo.empty()) {
            e = todo.back();
            d = depths.back();
            todo.pop_back();
            depths.pop_back();
            if (is_var(e)) {
                // nothing
            }
            else if (is_quantifier(e)) {
                todo.push_back(to_quantifier(e)->get_expr());
                depths.push_back(d+1);
            }
            else if (is_app(e)) {
                app* a = to_app(e);
                // Only revisit a node if it has been reached at a greater depth.
                if (depth.find(a, d1) && d <= d1) {
                    continue;
                }
                depth.insert(a, d);                
                ++d;
                for (unsigned i = 0; i < a->get_num_args(); ++i) {
                    todo.push_back(a->get_arg(i));
                    depths.push_back(d);
                }
            }
            else {
                UNREACHABLE();
            }

        }
    }

    // Color nodes according to the function symbols they appear in.
    typedef obj_hashtable<func_decl> fun_set;
    typedef obj_map<app, fun_set*> app_parents;

    // for_each_expr visitor: collects, for each application argument, the set
    // of function symbols it is an argument of.
    // NOTE: currently unused in this file.
    class parents {
        app_parents m_use_funs;        
    public:
        parents() {}
        // Release the sets allocated in operator().
        ~parents() {
            for (auto & kv : m_use_funs) {
                dealloc(kv.m_value);
            }
        }

        app_parents const& get_parents() { return m_use_funs; }

        void operator()(app* n) {
            func_decl* f = n->get_decl();
            for (expr* e : *n) {
                if (is_app(e)) {
                    auto& value = m_use_funs.insert_if_not_there(to_app(e), 0);
                    if (!value) value = alloc(fun_set);
                    value->insert(f); 
                }
            }
        }
        void operator()(quantifier *n) {}
        void operator()(var* n) {}
    };
    // Unimplemented stub; not called anywhere in this file.
    void compute_parents(expr* e, app_map& parents) {
    }

    typedef hashtable<unsigned, u_hash, u_eq> uint_set;
    typedef obj_map<app, uint_set*> app_siblings;

    // for_each_expr visitor: for each application argument, collects the
    // colors of its sibling arguments (the other arguments of the same parent).
    // The allocated sets are released by compute_siblings_rec.
    class siblings {
        app_map const& m_colors;
        app_siblings m_sibs;
    public:
        siblings(app_map const& colors): m_colors(colors) {}

        app_siblings const& get() { return m_sibs; }
        void operator()(app* n) {
            unsigned sz = n->get_num_args();
            for (unsigned i = 0; i < sz; ++i) {
                expr* e = n->get_arg(i);
                if (!is_app(e)) continue;
                auto& value = m_sibs.insert_if_not_there(to_app(e), 0);
                if (!value) value = alloc(uint_set);
                for (unsigned j = 0; j < sz; ++j) {
                    expr* f = n->get_arg(j);
                    if (is_app(f) && i != j) {
                        unsigned c1 = 0;
                        m_colors.find(to_app(f), c1);
                        value->insert(c1);
                    }
                }
            }
        }
        void operator()(quantifier *n) {}
        void operator()(var* n) {}        
    };
    // Refine coloring by taking colors of siblings into account.
    // Returns the result of merge_colors, i.e. whether the refinement
    // increased the number of colors.
    bool compute_siblings_rec(expr* e, app_map& colors) {
        siblings sibs(colors);
        app_map colors1;
        for_each_expr(sibs, e);
        app_siblings const& s = sibs.get();
        app_siblings::iterator it = s.begin(), end = s.end();
        for (; it != end; ++it) {
            app* a = it->m_key;
            uint_set* set = it->m_value;
            uint_set::iterator it2 = set->begin(), end2 = set->end();
            unsigned c = 0;
            for(; it2 != end2; ++it2) {
                c += 1 + *it2;
            }
            colors1.insert(a, c);
            dealloc(set);
        }
        if (is_app(e)) {
            colors1.insert(to_app(e), 0);
        }
        return merge_colors(colors1, colors);
    }
    // Iterate sibling refinement until the number of colors stops growing.
    void compute_siblings(expr* fml, app_map& colors) {
        while(compute_siblings_rec(fml, colors));
    }

    // Check if the assertion set is invariant under the current permutation:
    // both swapping the first two constants and rotating the whole list
    // must leave the normalized formula unchanged.
    bool invariant_by_permutations(expr* fml, permutation& p) {

        SASSERT(p.size() >= 2);
        bool result = check_swap(fml, p[0], p[1]) && check_cycle(fml, p);
        TRACE("symmetry_reduce", 
              if (result) {
                  tout << "Symmetric: ";
              }
              else {
                  tout << "Not symmetric: ";
              }
              for (unsigned i = 0; i < p.size(); ++i) {
                  tout << mk_pp(p[i], m()) << " ";
              }
              tout << "\n";);
        return result;
    }

    // Is fml invariant under exchanging t1 and t2?
    bool check_swap(expr* fml, app* t1, app* t2) {
        expr_substitution sub(m());
        sub.insert(t1, t2);
        sub.insert(t2, t1);
        m_replace->set_substitution(&sub);
        return check_substitution(fml);
    }

    // Is fml invariant under the cyclic shift p[0] -> p[1] -> ... -> p[0]?
    bool check_cycle(expr* fml, permutation& p) {
        expr_substitution sub(m());
        for (unsigned i = 0; i + 1 < p.size(); ++i) {
            sub.insert(p[i], p[i+1]);
        }
        sub.insert(p[p.size()-1], p[0]);
        m_replace->set_substitution(&sub);
        return check_substitution(fml);
    }

    // Apply the substitution currently installed in m_replace to t, normalize,
    // and compare by pointer. This relies on t itself already being normalized.
    bool check_substitution(expr* t) {
        expr_ref r(m());
        (*m_replace)(t, r);
        normalize(r);
        return t == r.get();
    }

    // Rewrite r in place with the AC-canonicalizing rewriter.
    void normalize(expr_ref& r) {
        proof_ref pr(m());
        expr_ref  result(m());
        m_rewriter(r.get(), result, pr);
        r = result;
    }

    // Select the terms that are range-restricted by the set p. These come from
    // top-level conjuncts (nested `and`s are flattened) of the form
    // (t = c1 or ... or t = ck) with every ci in p.
    void select_terms(expr* fml, term_set const& p, term_set& T) {
        T.reset();
        ptr_vector<expr> todo;
        todo.push_back(fml);
        app* t = nullptr;
        while (!todo.empty()) {
            fml = todo.back();
            todo.pop_back();
            if (m().is_and(fml)) {
                todo.append(to_app(fml)->get_num_args(), to_app(fml)->get_args());               
            }
            else if (is_range_restriction(fml, p, t)) {
                T.push_back(t);
            }
        }
    }
    // Does `form` have the shape (t = c1 or ... or t = ck) with all ci in C?
    // Each equality may be oriented either way. On success, t is set to the
    // common term.
    bool is_range_restriction(expr* form, term_set const& C, app*& t) {
        if (!m().is_or(form)) return false;
        unsigned sz = to_app(form)->get_num_args();
        t = nullptr;
        for (unsigned i = 0; i < sz; ++i) {
            expr* e = to_app(form)->get_arg(i);
            expr* e1, *e2;
            if (!m().is_eq(e, e1, e2)) return false;
            if (!is_app(e1) || !is_app(e2)) return false;
            app* a1 = to_app(e1), *a2 = to_app(e2);
            if (C.contains(a1) && (t == nullptr || t == a2)) {
                t = a2;
            }
            else if (C.contains(a2) && (t == nullptr || t == a1)) {
                t = a1;
            }
            else {
                return false;
            }
        }
        return t != nullptr;
    }


    // Select the most promising term among T.
    // - Terms with the largest number of occurrences have higher weight.
    // - Terms that have the fewest terms of C as subterms are preferred.
    
    // for_each_expr visitor: counts how often each application occurs as an
    // argument (every visited application gets an entry, starting at 0).
    class num_occurrences {
        app_map& m_occs;
    public:
        num_occurrences(app_map& occs): m_occs(occs) {}
        void operator()(app* n) {
            m_occs.insert_if_not_there(n, 0);
            unsigned sz = n->get_num_args();
            for (unsigned i = 0; i < sz; ++i) {
                expr* arg = n->get_arg(i);
                if (is_app(arg)) {
                    m_occs.insert_if_not_there(to_app(arg), 0)++;
                }
            }
        }
        void operator()(quantifier * n) {}
        void operator()(var * n) {}
    };
    void compute_occurrences(expr* fml, app_map& occs) {
        occs.reset();
        num_occurrences num_occ(occs);
        for_each_expr(num_occ, fml);
    }

    // Pick a term from T (which must be non-empty). Preference rules:
    // - terms with fewer arguments win;
    // - among terms with the same arity, a candidate replaces the current
    //   choice if it has a higher occurrence weight or adds fewer new
    //   constants to cts.
    // cts is temporarily extended while measuring and restored afterwards.
    app* select_most_promising_term(
        expr* fml, term_set const& T, 
        term_set& cts, term_set const& consts, app_map const& occs) {
        SASSERT(!T.empty());
        app* t = T[0];
        unsigned weight = 0, weight1 = 0;
        VERIFY(occs.find(t, weight));
        unsigned cts_delta = compute_cts_delta(t, cts, consts);
        TRACE("symmetry_reduce", tout << mk_pp(t, m()) << " " << weight << " " << cts_delta << "\n";);
        for (unsigned i = 1; i < T.size(); ++i) {
            app* t1 = T[i];
            VERIFY(occs.find(t1, weight1));
            if (weight1 < weight && t->get_num_args() <= t1->get_num_args()) {
                continue;
            }
            unsigned cts_delta1 = compute_cts_delta(t1, cts, consts);
            TRACE("symmetry_reduce", tout << mk_pp(t1, m()) << " " << weight1 << " " << cts_delta1 << "\n";);
            if ((t->get_num_args() == t1->get_num_args() && (weight1 > weight || cts_delta1 < cts_delta)) || 
                t->get_num_args() > t1->get_num_args()) {
                cts_delta = cts_delta1;
                weight = weight1;
                t = t1;
            }
        }        
        return t;
    }

    // for_each_expr visitor: appends to m_r every visited application that
    // belongs to m_S and is not already in m_r.
    class member_of {
        term_set const& m_S;
        term_set&       m_r;
    public:
        member_of(term_set const& S, term_set& r) : m_S(S), m_r(r) {}
        void operator()(app* n) {
            if (m_S.contains(n) && !m_r.contains(n)) {
                m_r.push_back(n);
            }
        }
        void operator()(quantifier * n) {}
        void operator()(var * n) {}
    };
    // Add to cts the subterms of t that are members of consts.
    void compute_used_in(app* t, term_set& cts, term_set const& consts) {
        member_of mem(consts, cts);
        for_each_expr(mem, t);
        TRACE("symmetry_reduce",
              tout << "Term: " << mk_pp(t, m()) << "\n";
              tout << "Support set: ";
              for (unsigned i = 0; i < consts.size(); ++i) {
                  tout << mk_pp(consts[i], m()) << " ";
              }
              tout << "\n";
              tout << "Constants: ";
              for (unsigned i = 0; i < cts.size(); ++i) {
                  tout << mk_pp(cts[i], m()) << " ";
              }
              tout << "\n";
              );
    }

    // Number of constants of consts that t would newly add to cts.
    // cts is restored to its original size before returning.
    unsigned compute_cts_delta(app* t, term_set& cts, term_set const& consts) {
        unsigned cts_size = cts.size();
        if (cts_size == consts.size()) {
            return 0;
        }
        compute_used_in(t, cts, consts);
        unsigned cts_delta = cts.size() - cts_size;
        cts.resize(cts_size);
        return cts_delta;
    }

    // Select the first element of A that is not in B; nullptr if there is none.
    app* select_const(term_set const& A, term_set const& B) {
        unsigned j;
        for (j = 0; j < A.size() && B.contains(A[j]); ++j);
        return (j == A.size()) ? nullptr : A[j];
    }

    // Build the disjunction (t = C[0] or ... or t = C[n-1]).
    app* mk_member(app* t, term_set const& C) {
        expr_ref_vector eqs(m());
        for (unsigned i = 0; i < C.size(); ++i) {
            eqs.push_back(m().mk_eq(t, C[i]));
        }
        return m().mk_or(eqs.size(), eqs.data());
    }
};

// Create the private implementation object bound to manager `m`.
symmetry_reduce_tactic::symmetry_reduce_tactic(ast_manager & m):
    m_imp(alloc(imp, m)) {
}

// Release the implementation object created by the constructor.
symmetry_reduce_tactic::~symmetry_reduce_tactic() {
    dealloc(m_imp);
}
    
// Reject goals that need proofs, unsat cores, or contain quantifiers. Then
// add symmetry-breaking predicates to `g` in place and return that same goal
// as the only subgoal.
void symmetry_reduce_tactic::operator()(goal_ref const & g, 
                                        goal_ref_buffer & result) {
    // The tactic name is kept as a literal in each call: the fail_if_* helpers
    // may rely on string-literal concatenation.
    fail_if_proof_generation("symmetry_reduce", g);
    fail_if_unsat_core_generation("symmetry_reduce", g);
    fail_if_has_quantifiers("symmetry_reduce", g);
    result.reset();

    goal & target = *(g.get());
    (*m_imp)(target);
    target.inc_depth();

    result.push_back(g.get());
}

// Intentionally a no-op: the implementation object lives until destruction.
void symmetry_reduce_tactic::cleanup() {
}

// Factory for the symmetry-reduction tactic.
// The parameter set `p` is not consulted by this tactic.
tactic * mk_symmetry_reduce_tactic(ast_manager & m, params_ref const & p) {
    tactic * fresh = alloc(symmetry_reduce_tactic, m);
    return fresh;
}