# Provenance (header of the git patch that introduced this file):
#   From 4c8b8d9f2f905af7ff2f3d8e4ea2358e9310918c Mon Sep 17 00:00:00 2001
#   From: Joshua Moerman
#   Date: Wed, 29 May 2024 16:40:17 +0200
#   Subject: [PATCH] Incremental sat solving to find optimal solution (with fixed c)
#   Creates other/decompose_fsm_optimise.py (1 file changed, 330 insertions)

from pysat.solvers import Solver
from pysat.card import ITotalizer
from pysat.formula import IDPool
from pysat.formula import CNF

import math
import argparse

### Usage:
# Step 1: pip3 install python-sat
# Step 2: python3 decompose_fsm_optimise.py -h

parser = argparse.ArgumentParser(description="Decomposes a FSM into smaller components by remapping its outputs. Uses a SAT solver.")
parser.add_argument('-c', '--components', type=int, default=2, help='number of components')
parser.add_argument('--add-state-trans', default=False, action="store_true", help='adds state transitivity constraints')
parser.add_argument('-v', '--verbose', default=False, action="store_true", help='prints more info')
parser.add_argument('filename', help='path to .dot file')
args = parser.parse_args()

# If the total_size is too low => UNSAT => takes a long time
c = args.components

# c = 1 is pointless, but should work. Validate explicitly instead of with
# `assert`, so the check is not stripped when running with `python -O`.
if c < 1:
    parser.error('the number of components must be at least 1')


###################################
# .dot file parser (heuristic)
class FSM:
    """A finite state machine with outputs on its transitions.

    `transition_map` and `output_map` are dictionaries keyed by
    `(state, input)` pairs.
    """

    def __init__(self, initial_state, states, inputs, outputs, transition_map, output_map):
        self.initial_state = initial_state
        self.states = states
        self.inputs = inputs
        self.outputs = outputs
        self.transition_map = transition_map
        self.output_map = output_map

    def __str__(self):
        return f'FSM({self.initial_state}, {self.states}, {self.inputs}, {self.outputs}, {self.transition_map}, {self.output_map})'

    def transition(self, s, a):
        """Return the successor of state `s` on input `a`."""
        return self.transition_map[(s, a)]

    def output(self, s, a):
        """Return the output produced in state `s` on input `a`."""
        return self.output_map[(s, a)]


def parse_dot_file(lines):
    """Build an `FSM` from the lines of a .dot file.

    Only lines of the shape `s -> t [label="i/o"]` contribute; every other
    line is ignored. The source state of the first such line becomes the
    initial state.

    Raises:
        ValueError: if no transition was found, or if some (state, input)
            pair has no transition (the machine is not input-complete).
    """
    def parse_transition(line):
        # Split `s -> t [label="i/o"]` into its four parts. Parts that are
        # missing come back as empty strings.
        (l, _, r) = line.partition('->')
        s = l.strip()

        (l, _, r) = r.partition('[label="')
        t = l.strip()

        (l, _, _) = r.partition('"]')
        (i, _, o) = l.partition('/')

        return (s, i, o, t)

    initial_state = None
    states, inputs, outputs = set(), set(), set()
    transition_map, output_map = {}, {}

    for line in lines:
        (s, i, o, t) = parse_transition(line)
        if s and i and o and t:
            states.add(s)
            inputs.add(i)
            outputs.add(o)
            states.add(t)

            transition_map[(s, i)] = t
            output_map[(s, i)] = o

            if initial_state is None:
                initial_state = s

    if initial_state not in states:
        raise ValueError('no transitions found in the .dot input')

    expected = len(states) * len(inputs)
    if len(transition_map) != expected or len(output_map) != expected:
        raise ValueError(
            f'FSM is not input-complete: found {len(transition_map)} transitions, '
            f'expected {expected} ({len(states)} states x {len(inputs)} inputs)')

    return FSM(initial_state, states, inputs, outputs, transition_map, output_map)


with open(args.filename) as file:
    machine = parse_dot_file(file)
    if args.verbose:
        print(machine)

N = len(machine.states)
print(f'Initial size: {N}')


###################################
# Utility functions

def print_table(cell, rs, cs):
    """Print a right-aligned table with row labels `rs` and column labels `cs`.

    `cell(r, c)` must return the string to show in row `r`, column `c`.
    Empty `rs` or `cs` are tolerated.
    """
    first_col_size = max((len(str(r)) for r in rs), default=0)
    widths = [len(str(col)) for col in cs] + [len(cell(r, col)) for col in cs for r in rs]
    col_size = 1 + max(widths, default=0)

    print(''.rjust(first_col_size), end='')
    for col in cs:
        print(str(col).rjust(col_size), end='')
    print('')

    for r in rs:
        print(str(r).rjust(first_col_size), end='')
        for col in cs:
            print(cell(r, col).rjust(col_size), end='')
        print('')


def print_eqrel(rel, xs):
    """Print the relation `rel` on `xs` as a matrix ('Y' = related)."""
    # The placeholder is a middle dot; the previous literal was the same
    # character after a bad re-encoding.
    print_table(lambda r, c: 'Y' if rel(r, c) else '·', xs, xs)


class Progress:
    """Prints a percentage that is updated in place on the terminal."""

    def __init__(self, name, guess):
        self.reset(name, guess, show=False)

    def reset(self, name, guess, show=True):
        """Start a new phase `name` that expects roughly `guess` steps."""
        self.name = name
        # At least 1, so that `add` can never divide by zero.
        self.guess = max(1, math.ceil(guess))
        self.count = 0
        self.percentage = None

        if show:
            print(name)

    def add(self, n=1):
        """Record `n` finished steps; reprint only when the percentage changes."""
        self.count += n
        percentage = math.floor(100 * self.count / self.guess)

        if percentage != self.percentage:
            self.percentage = percentage
            print(f'{self.percentage}%', end='', flush=True)
            print('\r', end='')


progress = Progress('', 1)

########################
# Encoding into logic
print('Start encoding')
os = list(machine.outputs)  # outputs
rids = list(range(c))  # components
vpool = IDPool()
cnf = CNF()


# A helper variable for False and True; it keeps the other variables simpler.
def var_const(b):
    return vpool.id(('const', b))


cnf.append([var_const(True)])
cnf.append([-var_const(False)])


# For every relation and every two elements o1 and o2 there is a variable
# stating whether o1 and o2 are related. There is a single variable for both
# xRy and yRx, so symmetry is built in. Reflexivity is built in as well.
def var_rel(rid, o1, o2):
    if o1 == o2:
        return var_const(True)

    [so1, so2] = sorted([o1, o2])
    return vpool.id(('rel', rid, so1, so2))


# The relation on outputs induces a relation on states. That relation must
# also be a bisimulation.
def var_state_rel(rid, s1, s2):
    if s1 == s2:
        return var_const(True)

    [ss1, ss2] = sorted([s1, s2])
    return vpool.id(('state_rel', rid, ss1, ss2))


# For every relation and every equivalence class we pick exactly one state as
# representative. This variable tells which element that is.
def var_state_rep(rid, s):
    return vpool.id(('state_rep', rid, s))


# Constraints that make the relation an equivalence relation. Only
# transitivity needs to be encoded, because reflexivity and symmetry are
# built into the variables.
def _encode_transitivity(label, elements, related):
    """Add `x R y and y R z => x R z` for every component and every triple.

    `related(rid, a, b)` must return the variable for "a and b are related
    in component rid".
    """
    progress.reset(label, guess=len(rids) * len(elements) ** 3)
    for rid in rids:
        for first in elements:
            for second in elements:
                for third in elements:
                    # Evaluate in this order so variable ids are handed out
                    # in a fixed order.
                    not_first_second = -related(rid, first, second)
                    not_second_third = -related(rid, second, third)
                    first_third = related(rid, first, third)
                    cnf.append([not_first_second, not_second_third, first_third])
                    progress.add()


# Transitivity of the output relations.
_encode_transitivity('transitivity (o)', os, var_rel)

# Optionally also transitivity of the induced state relations.
if args.add_state_trans:
    _encode_transitivity('transitivity (s)', machine.states, var_state_rel)

# Constraint making sure the relations together can tell all elements apart.
# (Aka: the corresponding quotients are joint-injective.)
progress.reset('injectivity', guess=len(os) * (len(os) - 1) / 2)
for next_index, left in enumerate(os, start=1):
    for right in os[next_index:]:
        # At least one component must separate the two outputs.
        separating_clause = [-var_rel(rid, left, right) for rid in rids]
        cnf.append(separating_clause)
        progress.add()

# sx ~ sy => for each input: (1) outputs equivalent AND (2) successors related
# Currently the inverse implication is not encoded; it may not be needed.
progress.reset('bisimulation modulo rel', guess=len(rids) * len(machine.states) * len(machine.states) * len(machine.inputs))
for rid in rids:
    for sx in machine.states:
        for sy in machine.states:
            for symbol in machine.inputs:
                not_related = -var_state_rel(rid, sx, sy)

                # sx ~ sy => output(sx, i) ~ output(sy, i)
                outputs_related = var_rel(rid, machine.output(sx, symbol), machine.output(sy, symbol))
                cnf.append([not_related, outputs_related])

                # sx ~ sy => delta(sx, i) ~ delta(sy, i)
                successors_related = var_state_rel(rid, machine.transition(sx, symbol), machine.transition(sy, symbol))
                cnf.append([not_related, successors_related])

                progress.add()

# The constraints that make sure representatives really are representatives.
states = list(machine.states)
progress.reset('representatives', guess=len(rids) * len(states))
for rid in rids:
    for ix, sx in enumerate(states):
        earlier = states[:ix]

        # Most important: an element is a representative, or equivalent to
        # an earlier element. This forces the solver to pick representatives
        # (towards the front of the list).
        cnf.append([var_state_rep(rid, sx)] + [var_state_rel(rid, sx, sy) for sy in earlier])

        for sy in earlier:
            # sx and sy cannot both be representatives, unless they are not
            # related.
            cnf.append([-var_state_rep(rid, sx), -var_state_rep(rid, sy), -var_state_rel(rid, sx, sy)])

        progress.add()

# Finally we want few representatives. This is done with an incremental
# totalizer over all representative variables; the best total size is then
# found by binary search between the two bounds below.
# NOTE(review): lower_bound is treated as infeasible by the search; this
# presumably relies on the input machine being minimal -- confirm.
lower_bound = int(math.floor(c * (N-1)**(1/c)))
upper_bound = int(N + c - 1)
print(f'size constraints {lower_bound} {upper_bound}')
cnf_optim = ITotalizer([var_state_rep(rid, sx) for rid in rids for sx in machine.states], ubound=upper_bound, top_id=vpool.top)
cnf.extend(cnf_optim.cnf.clauses)


def _at_most(bound):
    """Return solver assumptions enforcing at most `bound` representatives.

    The totalizer only has min(#inputs, ubound + 1) output literals. When
    `bound` is not smaller than that (e.g. with c = 1), the constraint holds
    trivially and indexing `rhs` would raise IndexError, so no assumption is
    needed.
    """
    if bound >= len(cnf_optim.rhs):
        return []
    return [-cnf_optim.rhs[bound]]


##################################
# Solve the problem with the solver :-)
print('Start solving')
print('- copying formula')
with Solver(bootstrap_with=cnf) as solver:
    print('===============')
    # Invariant: lower_bound is (assumed) infeasible, upper_bound feasible.
    while upper_bound - lower_bound >= 2:
        mid_size = (lower_bound + upper_bound) // 2
        print(f'Trying {lower_bound} < {mid_size} < {upper_bound}')
        sat = solver.solve(assumptions=_at_most(mid_size))
        if sat:
            upper_bound = mid_size
        else:
            lower_bound = mid_size

    total_size = upper_bound
    print(f'done searching, found size = {total_size}')
    # Solve once more at the optimum so that the model belongs to it.
    sat = solver.solve(assumptions=_at_most(total_size))
    if not sat:
        raise RuntimeError(f'no decomposition of total size {total_size} exists')

    # Convert the model into a more convenient data structure
    print('- get model')
    m = solver.get_model()

model = {abs(l): l > 0 for l in m}

if args.verbose:
    for rid in rids:
        print(f'Relation {rid}:')
        print_eqrel(lambda x, y: model[var_rel(rid, x, y)], os)

    for rid in rids:
        print(f'State relation {rid}:')
        print_eqrel(lambda x, y: model[var_state_rel(rid, x, y)], machine.states)

# print equivalence classes
count = 0
for rid in rids:
    if args.verbose:
        print(f'component {rid}')
    # First we collect the representatives
    for s in machine.states:
        if model[var_state_rep(rid, s)]:
            count += 1
            if args.verbose:
                print(f'- representative state {s}')

# count must be equal to the cost (or smaller)
print(f'Reduced size = {count}')

# For every component, map each output to the name of its equivalence class.
# An output that is alone in its class keeps its own name.
projections = {}
for rid in rids:
    local_outputs = machine.outputs.copy()
    projections[rid] = {}
    class_index = 0

    while local_outputs:
        rep_output = local_outputs.pop()
        if rep_output in projections[rid]:
            # Already assigned as a member of an earlier class.
            continue

        projections[rid][rep_output] = f'cls_{rid}_{class_index}'
        others = False

        for o in local_outputs:
            if model[var_rel(rid, o, rep_output)]:
                others = True
                projections[rid][o] = f'cls_{rid}_{class_index}'

        if not others:
            projections[rid][rep_output] = f'{rep_output}'

        class_index += 1

print('===============')
print('Output mapping:')
print_table(lambda o, rid: projections[rid][o], machine.outputs, rids)