# Step 1: pip3 install python-sat
# Step 2: python3 decompose_fsm.py -h

# When True, one tab-separated result line per run is appended to
# ./results/log.txt (also on timeout).
keep_log = True


def main(argv=None):
    """Parse the command line, read the FSM and run the decomposition.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` makes argparse fall back to ``sys.argv[1:]``,
            which is the behaviour of the plain ``main()`` call.

    Exits with status 2 (via ``parser.error``) when ``--components`` is
    smaller than 1. When ``--timeout`` is given, a SIGALRM handler records
    the timeout (if logging is enabled), prints ``TIMEOUT`` and exits.
    """
    record_file = './results/log.txt' if keep_log else None

    parser = argparse.ArgumentParser(description="Decomposes a FSM into smaller components by remapping its outputs. Uses a SAT solver.")
    parser.add_argument('-c', '--components', type=int, default=2, help='number of components')
    parser.add_argument('-w', '--weak', default=False, action="store_true", help='look for weak decomposition')
    parser.add_argument('--add-state-trans', default=False, action="store_true", help='adds state transitivity constraints')
    parser.add_argument('-v', '--verbose', default=False, action="store_true", help='prints more info')
    parser.add_argument('-t', '--timeout', type=int, default=None, help='timeout (in seconds)')
    parser.add_argument('filename', help='path to .dot file')
    args = parser.parse_args(argv)

    # Number of components. c = 1 is pointless, but should work.
    # Reported through argparse instead of `assert`, which is stripped
    # when Python runs with -O.
    c = args.components
    if c < 1:
        parser.error('--components must be at least 1')

    with open(args.filename) as file:
        machine = parse_dot_file(file)
        if args.verbose:
            print(machine)

    print(f'Input FSM: {len(machine.states)} states, {len(machine.inputs)} inputs, and {len(machine.outputs)} outputs')

    if args.timeout is not None:
        def timeout_handler(signum, frame):
            # record_file is None when logging is disabled; opening it
            # unconditionally would raise TypeError inside the handler.
            if record_file is not None:
                with open(record_file, 'a') as file:
                    last_two_comps = '/'.join(args.filename.split('/')[-2:])
                    file.write(f'{last_two_comps}\t{len(machine.states)}\t{len(machine.inputs)}\t{len(machine.outputs)}\t{args.weak}\t{c}\tTIMEOUT\n')
            print('TIMEOUT')
            # Same effect as the site-provided exit(): SystemExit with
            # status 0, but without depending on the `site` module.
            raise SystemExit

        # NOTE(review): Python runs signal handlers between bytecodes, so a
        # long native solver call may delay this handler -- confirm that the
        # timeout fires promptly with the solver in use.
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(args.timeout)

    encoder = Encoder(machine, args, record_file=record_file)
    encoder.solve()
###################################
# Main logic
class Encoder:
    """Encode an FSM output decomposition as SAT and search for a small one.

    The constructor builds the complete encoding in ``self.solver``;
    ``solve`` then searches for the smallest size bound, prints the
    resulting component sizes and output mapping, and optionally appends a
    result line to ``record_file``.

    Args:
        machine: The FSM; the code reads ``states``, ``inputs``, ``outputs``
            and calls ``output(s, i)`` and ``transition(s, i)``.
        args: Parsed options; the code reads ``components``, ``weak``,
            ``add_state_trans``, ``verbose`` and ``filename``.
        record_file: Path of the log file to append to, or ``None`` to
            disable logging.
        progress: Optional progress reporter; a fresh ``Progress`` is
            created when omitted.
    """

    def __init__(self, machine, args, record_file=None, progress=None):
        self.machine = machine
        self.args = args
        self.record_file = record_file
        self.progress = progress if progress else Progress('', 1)

        self.c = self.args.components
        self.N = len(self.machine.states)
        self.os = list(machine.outputs)  # outputs
        self.rids = [i for i in range(self.c)]  # components
        self.vpool = IDPool()
        self.solver = Solver()

        self.encode()

    def add_clause(self, cls, no_return=True):
        """Add a single clause (list of literals) to the solver."""
        self.solver.add_clause(cls, no_return)

    def add_clauses(self, clss, no_return=True):
        """Add a collection of clauses to the solver."""
        self.solver.append_formula(clss, no_return)

    # Helper variable for False and True; keeps the other variables simple.
    def var_const(self, b):
        """Return the variable id that stands for the constant ``b``."""
        return self.vpool.id(('const', b))

    # For every relation and every two elements o1 and o2 there is one
    # variable stating whether o1 and o2 are related. xRy and yRx share a
    # single variable, so symmetry is built in. Reflexivity is built in too.
    def var_rel(self, rid, o1, o2):
        """Return the variable for "outputs o1 and o2 are related in rid"."""
        if o1 == o2:
            return self.var_const(True)

        [so1, so2] = sorted([o1, o2])
        return self.vpool.id(('rel', rid, so1, so2))

    # The relation on outputs induces a relation on states. That relation
    # must also be a bisimulation.
    def var_state_rel(self, rid, s1, s2):
        """Return the variable for "states s1 and s2 are related in rid"."""
        if s1 == s2:
            return self.var_const(True)

        [ss1, ss2] = sorted([s1, s2])
        return self.vpool.id(('state_rel', rid, ss1, ss2))

    # For every relation and every equivalence class we pick exactly one
    # state as representative. This variable marks that element.
    def var_state_rep(self, rid, s):
        """Return the variable for "state s is a representative in rid"."""
        return self.vpool.id(('state_rep', rid, s))

    def encode(self):
        """Add all constraints to the solver and set up the size bounds.

        Sets ``self.rhs`` (totalizer output literals; one list per component
        in weak mode, a single list otherwise), ``self.lower_bound`` and
        ``self.upper_bound``.
        """
        print('===============')
        print('Start encoding')
        self.add_clause([self.var_const(True)])
        self.add_clause([-self.var_const(False)])

        # Constraints making each relation an equivalence relation. Only
        # transitivity needs encoding: reflexivity and symmetry are built
        # into the variables.
        self.progress.reset('transitivity (o)', guess=len(self.rids) * len(self.os) ** 3)
        for rid in self.rids:
            for xo in self.os:
                for yo in self.os:
                    for zo in self.os:
                        # if xo R yo and yo R zo then xo R zo
                        self.add_clause([-self.var_rel(rid, xo, yo), -self.var_rel(rid, yo, zo), self.var_rel(rid, xo, zo)])
                        self.progress.add()

        if self.args.add_state_trans:
            self.progress.reset('transitivity (s)', guess=len(self.rids) * len(self.machine.states) ** 3)
            for rid in self.rids:
                for sx in self.machine.states:
                    for sy in self.machine.states:
                        for sz in self.machine.states:
                            # if sx R sy and sy R sz then sx R sz
                            self.add_clause([-self.var_state_rel(rid, sx, sy), -self.var_state_rel(rid, sy, sz), self.var_state_rel(rid, sx, sz)])
                            self.progress.add()

        # Constraint so that the relations together distinguish all
        # elements (i.e. the corresponding quotients are joint-injective).
        self.progress.reset('injectivity', guess=len(self.os) * (len(self.os) - 1) / 2)
        for xi, xo in enumerate(self.os):
            for yo in self.os[xi+1:]:
                # At least one rid has to tell the two outputs apart.
                self.add_clause([-self.var_rel(rid, xo, yo) for rid in self.rids])
                self.progress.add()

        # sx ~ sy => for each input: (1) outputs equivalent AND (2) successors related
        # The converse implication is currently not encoded; it may not be
        # needed.
        self.progress.reset('bisimulation modulo rel', guess=len(self.rids) * len(self.machine.states) * len(self.machine.states) * len(self.machine.inputs))
        for rid in self.rids:
            for sx in self.machine.states:
                for sy in self.machine.states:
                    for i in self.machine.inputs:
                        # sx ~ sy => output(sx, i) ~ output(sy, i)
                        ox = self.machine.output(sx, i)
                        oy = self.machine.output(sy, i)
                        self.add_clause([-self.var_state_rel(rid, sx, sy), self.var_rel(rid, ox, oy)])

                        # sx ~ sy => delta(sx, i) ~ delta(sy, i)
                        tx = self.machine.transition(sx, i)
                        ty = self.machine.transition(sy, i)
                        self.add_clause([-self.var_state_rel(rid, sx, sy), self.var_state_rel(rid, tx, ty)])

                        self.progress.add()

        # Constraints ensuring that representatives really are
        # representatives.
        states = list(self.machine.states)
        self.progress.reset('representatives', guess=len(self.rids) * len(states))
        for rid in self.rids:
            for ix, sx in enumerate(states):
                # Most important: an element is a representative, or it is
                # equivalent to an earlier element. This forces the solver
                # to pick representatives (towards the front of the list).
                self.add_clause([self.var_state_rep(rid, sx)] + [self.var_state_rel(rid, sx, sy) for sy in states[:ix]])

                for sy in states[:ix]:
                    # sx and sy cannot both be representatives unless they
                    # are unrelated.
                    self.add_clause([-self.var_state_rep(rid, sx), -self.var_state_rep(rid, sy), -self.var_state_rel(rid, sx, sy)])

                self.progress.add()

        # Finally we want few representatives. This is done with an
        # "at most" formula; solve() binary-searches the bound using
        # incremental SAT solving.
        self.rhs = None
        if self.args.weak:
            self.lower_bound = int(math.floor((self.N-1)**(1/self.c)))
            self.upper_bound = int(self.N)
            print(f'weak size constraints {self.lower_bound} {self.upper_bound}')
            self.rhs = []
            for rid in self.rids:
                with ITotalizer([self.var_state_rep(rid, sx) for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim:
                    # Reserve the totalizer's auxiliary variables so the
                    # next totalizer does not reuse them.
                    self.vpool.occupy(self.vpool.top + 1, cnf_optim.top_id)
                    self.vpool.top = cnf_optim.top_id
                    self.add_clauses(cnf_optim.cnf.clauses)
                    self.rhs.append(cnf_optim.rhs)
        else:
            self.lower_bound = int(math.floor(self.c * (self.N-1)**(1/self.c)))
            self.upper_bound = int(self.N + self.c - 1)
            print(f'size constraints {self.lower_bound} {self.upper_bound}')
            with ITotalizer([self.var_state_rep(rid, sx) for rid in self.rids for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim:
                self.add_clauses(cnf_optim.cnf.clauses)
                self.rhs = cnf_optim.rhs

    def _size_assumptions(self, size):
        """Return the assumption literals enforcing "at most ``size``".

        In weak mode the bound applies to every component separately,
        otherwise to the total over all components. A bound at or beyond
        the number of totalizer outputs cannot be violated, so it
        contributes no literal instead of raising IndexError.
        (Presumably a totalizer exposes at most min(len(lits), ubound + 1)
        outputs -- confirm against the PySAT documentation.)
        """
        totals = self.rhs if self.args.weak else [self.rhs]
        return [-outs[size] for outs in totals if size < len(outs)]

    def _count_representatives(self, model):
        """Return, per component, how many states are representatives."""
        counts = []
        for rid in self.rids:
            count = 0
            for s in self.machine.states:
                if model[self.var_state_rep(rid, s)]:
                    count += 1
                    if self.args.verbose:
                        print(f'comp {rid} -> representative state {s}')
            counts.append(count)
        return counts

    def _write_record(self, counts):
        """Append one tab-separated result line to the record file."""
        with open(self.record_file, 'a') as file:
            last_two_comps = '/'.join(self.args.filename.split('/')[-2:])
            file.write(f'{last_two_comps}\t{self.N}\t{len(self.machine.inputs)}\t{len(self.machine.outputs)}\t{self.args.weak}\t{self.c}\t{sum(counts)}\t{sorted(counts, reverse=True)}\n')

    def _project_outputs(self, model):
        """Map every output to a class label, per component.

        Returns ``{rid: {output: label}}`` where the label is
        ``cls_<rid>_<n>``, with suffix ``_u`` when the output is alone in
        its class.
        """
        projections = {}
        for rid in self.rids:
            local_outputs = self.machine.outputs.copy()
            projections[rid] = {}
            count = 0

            while local_outputs:
                representative = local_outputs.pop()
                if representative in projections[rid]:
                    continue

                projections[rid][representative] = f'cls_{rid}_{count}'
                others = False

                for o in local_outputs:
                    if model[self.var_rel(rid, o, representative)]:
                        others = True
                        projections[rid][o] = f'cls_{rid}_{count}'

                if not others:
                    # Mark that this output is unique in its class.
                    projections[rid][representative] = f'cls_{rid}_{count}_u'

                count += 1
        return projections

    def solve(self):
        """Binary-search the smallest feasible size bound and report it.

        Narrows ``(lower_bound, upper_bound]`` with incremental solver
        calls, stores the result in ``self.bound``, prints the component
        sizes and the output mapping, and appends a record line when a
        record file was configured.

        Raises:
            RuntimeError: if the solver reports the final bound as
                unsatisfiable.
        """
        print('===============')
        print('Start solving')
        while self.upper_bound - self.lower_bound >= 2:
            self.mid_size = int((self.lower_bound + self.upper_bound) / 2)
            print(f'Trying {self.lower_bound} < {self.mid_size} < {self.upper_bound}', end='', flush=True)
            if self.solver.solve(assumptions=self._size_assumptions(self.mid_size)):
                print('\tdown')
                self.upper_bound = self.mid_size
            else:
                print('\tup')
                self.lower_bound = self.mid_size

        self.bound = self.upper_bound
        print(f'done searching, found bound = {self.bound}')

        # Solve once more at the final bound so the model matches it.
        if not self.solver.solve(assumptions=self._size_assumptions(self.bound)):
            raise RuntimeError(f'no decomposition found for bound {self.bound}')

        # Convert the model into a simpler structure: variable id -> bool.
        model = {abs(lit): lit > 0 for lit in self.solver.get_model()}

        if self.args.verbose:
            for rid in self.rids:
                print(f'Relation {rid}:')
                print_eqrel(lambda x, y: model[self.var_rel(rid, x, y)], self.os)

            for rid in self.rids:
                print(f'State relation {rid}:')
                print_eqrel(lambda x, y: model[self.var_state_rel(rid, x, y)], self.machine.states)

        # Count the exact size of every component.
        counts = self._count_representatives(model)

        print(f'Reduced sizes = {counts} = {sum(counts)}')
        if self.record_file is not None:
            self._write_record(counts)

        projections = self._project_outputs(model)

        print('===============')
        print('Output mapping:')
        print_table(lambda o, rid: projections[rid][o], self.machine.outputs, self.rids)


###################################
# Run script
if __name__ == "__main__":
    main()