From fb0adfbf46ca6a0be7e29436a75afebe9db9623c Mon Sep 17 00:00:00 2001 From: Joshua Moerman Date: Mon, 24 Jun 2024 08:28:07 +0200 Subject: [PATCH] Additional optimisation of components --- other/decompose_fsm_optimise.py | 268 ++++++++++++++++++++------------ 1 file changed, 167 insertions(+), 101 deletions(-) diff --git a/other/decompose_fsm_optimise.py b/other/decompose_fsm_optimise.py index 36e4f3b..357083f 100644 --- a/other/decompose_fsm_optimise.py +++ b/other/decompose_fsm_optimise.py @@ -11,28 +11,36 @@ import argparse # Stap 2: python3 decompose_fsm.py -h keep_log = True +record_file = './results/log.txt' if keep_log else None + +filename = None def main(): - record_file = './results/log.txt' if keep_log else None - parser = argparse.ArgumentParser(description='Decomposes a FSM into smaller components by remapping its outputs. Uses a SAT solver.') parser.add_argument('-c', '--components', type=int, default=2, help='number of components') parser.add_argument('-w', '--weak', default=False, action='store_true', help='look for weak decomposition') parser.add_argument('--add-state-trans', default=False, action='store_true', help='adds state transitivity constraints') - parser.add_argument('-v', '--verbose', default=False, action='store_true', help='prints more info') parser.add_argument('-t', '--timeout', type=int, default=None, help='timeout (in seconds)') parser.add_argument('filename', help='path to .dot file') args = parser.parse_args() + global filename + filename = args.filename + # Aantal componenten. c = 1 is zinloos, maar zou moeten werken c = args.components assert c >= 1 with open(args.filename) as file: machine = parse_dot_file(file) - if args.verbose: - print(machine) + + # Als er maar 1 state is, valt er niks te ontbinden, maar het script zou niet + # moeten crashen. Het aantal outputs moet minstens 3 zijn voor een zinvolle + # decompositie. + assert len(machine.states) >= 1 + assert len(machine.inputs) >= 1 + assert len(machine.outputs) >= 1 print(f'Input FSM: {len(machine.states)} states, {len(machine.inputs)} inputs, and {len(machine.outputs)} outputs') @@ -48,8 +56,8 @@ def main(): signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(args.timeout) - encoder = Encoder(machine, args, record_file=record_file) - encoder.solve() + with Encoder(machine, c, args.weak, add_state_trans=args.add_state_trans, record_file=record_file) as encoder: + encoder.solve() ################################### @@ -134,10 +142,10 @@ def print_eqrel(rel, xs): class Progress: - def __init__(self, name, guess): + def __init__(self, name: str, guess: int): self.reset(name, guess, show=False) - def reset(self, name, guess, show=True): + def reset(self, name: str, guess: int, show: bool = True): self.name = name self.guess = math.ceil(guess) self.count = 0 @@ -146,7 +154,7 @@ class Progress: if show: print(name) - def add(self, n=1): + def add(self, n: int = 1): self.count += n percentage = math.floor(100 * self.count / self.guess) @@ -159,21 +167,40 @@ class Progress: ################################### # Main logic class Encoder: - def __init__(self, machine, args, record_file=None, progress=None): + def __init__(self, machine: FSM, components: int = 2, weak: bool = False, add_state_trans: bool = False, record_file: str = None, progress: Progress = None): self.machine = machine - self.args = args + self.c = components + self.weak = weak self.record_file = record_file self.progress = progress if progress else Progress('', 1) - self.c = self.args.components + # optionally add state transitivity constraints. This is not necessary for + # the decomposition and it is cubic in the number of states, so it's off + # by default. + self.add_state_trans = add_state_trans + self.N = len(self.machine.states) - self.os = list(machine.outputs) # outputs - self.rids = [i for i in range(self.c)] # components + self.rids = [i for i in range(self.c)] self.vpool = IDPool() self.solver = Solver() + def __enter__(self): + self.solver = Solver() + self.solver.__enter__() + self.encode() + if self.weak: + self.encode_weak_size_constraints() + else: + self.encode_strict_size_constraints() + assert hasattr(self, 'rhs') + + return self + + def __exit__(self, *args): + return self.solver.__exit__(*args) + def add_clause(self, cls, no_return=True): self.solver.add_clause(cls, no_return) @@ -181,34 +208,37 @@ class Encoder: self.solver.append_formula(clss, no_return) # Een hulp variabele voor False en True, maakt de andere variabelen eenvoudiger - def var_const(self, b): + def var_const(self, b) -> int: return self.vpool.id(('const', b)) - # Voor elke relatie en elke twee elementen o1 en o2, is er een variabele die - # aangeeft of o1 en o2 gerelateerd zijn. Er is 1 variabele voor xRy en yRx, dus - # symmetrie is al ingebouwd. Reflexiviteit is ook ingebouwd. - def var_rel(self, rid, o1, o2): - if o1 == o2: + # Een variabele die aangeeft of x en y gerelateerd zijn. Deze variabele is + # symmetrisch en reflexief. De id is een object die de relatie identificeert. + # Zo kunnen we meerdere relaties encoderen. + def var_rel_abs(self, id, x, y) -> int: + if x == y: return self.var_const(True) - [so1, so2] = sorted([o1, o2]) - return self.vpool.id(('rel', rid, so1, so2)) + [sx, sy] = sorted([x, y]) + return self.vpool.id(('r', id, sx, sy)) + + # Een relatie op de output-elementen. + def var_rel(self, rid, o1, o2) -> int: + return self.var_rel_abs(('output', rid), o1, o2) # De relatie op outputs geeft een relaties op states. Deze relatie moet ook een # bisimulatie zijn. - def var_state_rel(self, rid, s1, s2): - if s1 == s2: - return self.var_const(True) - - [ss1, ss2] = sorted([s1, s2]) - return self.vpool.id(('state_rel', rid, ss1, ss2)) + def var_state_rel(self, rid, s1, s2) -> int: + return self.var_rel_abs(('state', rid), s1, s2) # Voor elke relatie, en elke equivalentie-klasse, kiezen we precies 1 state # als representant. Deze variabele geeft aan welk element. - def var_state_rep(self, rid, s): + def var_state_rep(self, rid, s) -> int: return self.vpool.id(('state_rep', rid, s)) def encode(self): + # lokale variabelen om het wat leesbaarder te maken + rids, os, states, inputs = self.rids, list(self.machine.outputs), list(self.machine.states), list(self.machine.inputs) + print('===============') print('Start encoding') self.add_clause([self.var_const(True)]) @@ -216,41 +246,41 @@ class Encoder: # Contraints zodat de relatie een equivalentie relatie is. We hoeven alleen # maar transitiviteit te encoderen, want refl en symm zijn ingebouwd in de var. - self.progress.reset('transitivity (o)', guess=len(self.rids) * len(self.os) ** 3) - for rid in self.rids: - for xo in self.os: - for yo in self.os: - for zo in self.os: + self.progress.reset('transitivity (o)', guess=len(rids) * len(os) ** 3) + for rid in rids: + for xo in os: + for yo in os: + for zo in os: # als xo R yo en yo R zo dan xo R zo self.add_clause([-self.var_rel(rid, xo, yo), -self.var_rel(rid, yo, zo), self.var_rel(rid, xo, zo)]) self.progress.add() - if self.args.add_state_trans: - self.progress.reset('transitivity (s)', guess=len(self.rids) * len(self.machine.states) ** 3) - for rid in self.rids: - for sx in self.machine.states: - for sy in self.machine.states: - for sz in self.machine.states: + if self.add_state_trans: + self.progress.reset('transitivity (s)', guess=len(rids) * len(states) ** 3) + for rid in rids: + for sx in states: + for sy in states: + for sz in states: # als sx R sy en sy R sz dan sx R sz self.add_clause([-self.var_state_rel(rid, sx, sy), -self.var_state_rel(rid, sy, sz), self.var_state_rel(rid, sx, sz)]) self.progress.add() # Constraint zodat de relaties samen alle elementen kunnen onderscheiden. # (Aka: the bijbehorende quotienten zijn joint-injective.) - self.progress.reset('injectivity', guess=len(self.os) * (len(self.os) - 1) / 2) - for xi, xo in enumerate(self.os): - for yo in self.os[xi + 1 :]: + self.progress.reset('injectivity', guess=len(os) * (len(os) - 1) / 2) + for xi, xo in enumerate(os): + for yo in os[xi + 1 :]: # Tenminste een rid moet een verschil maken self.add_clause([-self.var_rel(rid, xo, yo) for rid in self.rids]) self.progress.add() # sx ~ sy => for each input: (1) outputs equivalent AND (2) successors related # Momenteel hebben we niet de inverse implicatie, is misschien ook niet nodig? - self.progress.reset('bisimulation modulo rel', guess=len(self.rids) * len(self.machine.states) * len(self.machine.states) * len(self.machine.inputs)) - for rid in self.rids: - for sx in self.machine.states: - for sy in self.machine.states: - for i in self.machine.inputs: + self.progress.reset('bisimulation modulo rel', guess=len(rids) * len(states) * len(states) * len(inputs)) + for rid in rids: + for sx in states: + for sy in states: + for i in inputs: # sx ~ sy => output(sx, i) ~ output(sy, i) ox = self.machine.output(sx, i) oy = self.machine.output(sy, i) @@ -264,9 +294,8 @@ class Encoder: self.progress.add() # De constraints die zorgen dat representanten ook echt representanten zijn. - states = list(self.machine.states) - self.progress.reset('representatives', guess=len(self.rids) * len(states)) - for rid in self.rids: + self.progress.reset('representatives', guess=len(rids) * len(states)) + for rid in rids: for ix, sx in enumerate(states): # Belangrijkste: een element is een representant, of equivalent met een # eerder element. We forceren hiermee dat de solver representanten moet @@ -280,40 +309,84 @@ class Encoder: self.progress.add() - # Tot slot willen we weinig representanten. Dit doen we met een "atmost" - # formule. We gaan een binaire zoek doen met incremental sat solving. - self.rhs = None - if self.args.weak: - self.lower_bound = int(math.floor((self.N - 1) ** (1 / self.c))) - self.upper_bound = int(self.N) - print(f'weak size constraints {self.lower_bound} {self.upper_bound}') - self.rhs = [] - for rid in self.rids: - with ITotalizer([self.var_state_rep(rid, sx) for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim: - self.vpool.occupy(self.vpool.top + 1, cnf_optim.top_id) - self.vpool.top = cnf_optim.top_id - self.add_clauses(cnf_optim.cnf.clauses) - self.rhs.append(cnf_optim.rhs) - else: - self.lower_bound = int(math.floor(self.c * (self.N - 1) ** (1 / self.c))) - self.upper_bound = int(self.N + self.c - 1) - print(f'size constraints {self.lower_bound} {self.upper_bound}') - with ITotalizer([self.var_state_rep(rid, sx) for rid in self.rids for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim: - self.add_clauses(cnf_optim.cnf.clauses) - self.rhs = cnf_optim.rhs + # Op dit punt is de encodering klaar, op de constraints voor de grootte na. + # Dit hangt af van de weak of strict decompositie. - def solve(self): - print('===============') - print('Start solving') - sat = None + def encode_weak_size_constraints(self): + self.rhs = [] + self.lower_bound = int(math.floor((self.N - 1) ** (1 / self.c))) + self.upper_bound = int(self.N) + print(f'weak size constraints {self.lower_bound} {self.upper_bound}') + # In de weak decompositie, minimaliseren we de grootte van elk component. + # Dus voor elk component voegen we een cardinality constraint toe. We + # gebruiken ITotalizer, omdat deze incrementeel is. + for rid in self.rids: + with ITotalizer([self.var_state_rep(rid, sx) for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim: + self.vpool.occupy(self.vpool.top + 1, cnf_optim.top_id) + self.vpool.top = cnf_optim.top_id + self.add_clauses(cnf_optim.cnf.clauses) + self.rhs.append(cnf_optim.rhs) + + def encode_strict_size_constraints(self): + self.lower_bound = int(math.floor(self.c * (self.N - 1) ** (1 / self.c))) + self.upper_bound = int(self.N + self.c - 1) + print(f'strict size constraints {self.lower_bound} {self.upper_bound}') + # In de sterke decompositie, minimaliseren we de som van de componenten. + # Dit komt neer op het feit dat we k representanten kiezen in de lijst + # rids * states. We gebruiken ITotalizer, omdat deze incrementeel is. + with ITotalizer([self.var_state_rep(rid, sx) for rid in self.rids for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim: + self.add_clauses(cnf_optim.cnf.clauses) + self.rhs = cnf_optim.rhs + + def optimise_constraints(self): + if self.weak: + self.optimise_weak_constraints() + else: + self.optimise_strict_constraints() + + print(f'done searching, found bound(s) = {self.bound}') + + def optimise_weak_constraints(self): + bounds = {} + smallest = None + todo = self.rids.copy() + + while todo: + lower_bound = self.N + for rid in bounds: + lower_bound /= bounds[rid] + + lower_bound = int(math.floor((math.ceil(lower_bound) - 1) ** (1.0 / len(todo)))) + upper_bound = smallest if smallest else self.N + + while upper_bound - lower_bound >= 2: + mid_size = int((lower_bound + upper_bound) / 2) + print(f'W Trying {lower_bound} < {mid_size} < {upper_bound}', end='', flush=True) + assumptions = [-self.rhs[rid][b] for (rid, b) in bounds.items() if b < self.N] + assumptions += [-self.rhs[rid][mid_size] for rid in todo] + sat = self.solver.solve(assumptions=assumptions) + if sat: + print('\tdown') + upper_bound = mid_size + continue + else: + print('\tup') + lower_bound = mid_size + continue + + print(f'Found bound {upper_bound} for {todo[0]}') + bounds[todo.pop(0)] = upper_bound + smallest = upper_bound + + self.bound = bounds + assert len(bounds) == self.c + + def optimise_strict_constraints(self): while self.upper_bound - self.lower_bound >= 2: self.mid_size = int((self.lower_bound + self.upper_bound) / 2) - print(f'Trying {self.lower_bound} < {self.mid_size} < {self.upper_bound}', end='', flush=True) - if self.args.weak: - assumptions = [-self.rhs[rid][self.mid_size] for rid in self.rids] - sat = self.solver.solve(assumptions=assumptions) - else: - sat = self.solver.solve(assumptions=[-self.rhs[self.mid_size]]) + print(f'S Trying {self.lower_bound} < {self.mid_size} < {self.upper_bound}', end='', flush=True) + assumptions = [-self.rhs[self.mid_size]] + sat = self.solver.solve(assumptions=assumptions) if sat: print('\tdown') self.upper_bound = self.mid_size @@ -322,30 +395,25 @@ class Encoder: print('\tup') self.lower_bound = self.mid_size continue - self.bound = self.upper_bound - print(f'done searching, found bound = {self.bound}') - if self.args.weak: - assumptions = [-self.rhs[rid][self.bound] for rid in self.rids] + def solve(self): + print('===============') + print('Start solving') + self.optimise_constraints() + + if self.weak: + assumptions = [-self.rhs[rid][self.bound[rid]] for rid in self.rids if self.bound[rid] < self.N] sat = self.solver.solve(assumptions=assumptions) else: - sat = self.solver.solve(assumptions=[-self.rhs[self.bound]]) + assumptions = [-self.rhs[self.bound]] if self.bound < self.N * self.c else [] + sat = self.solver.solve(assumptions=assumptions) assert sat # Even omzetten in een makkelijkere data structuur m = self.solver.get_model() model = {abs(l): l > 0 for l in m} - if self.args.verbose: - for rid in self.rids: - print(f'Relation {rid}:') - print_eqrel(lambda x, y: model[self.var_rel(rid, x, y)], self.os) - - for rid in self.rids: - print(f'State relation {rid}:') - print_eqrel(lambda x, y: model[self.var_state_rel(rid, x, y)], self.machine.states) - # Precieze groottes van elk component tellen counts = [] for rid in self.rids: @@ -354,15 +422,13 @@ class Encoder: for s in self.machine.states: if model[self.var_state_rep(rid, s)]: count += 1 - if self.args.verbose: - print(f'comp {rid} -> representative state {s}') counts.append(count) print(f'Reduced sizes = {counts} = {sum(counts)}') if self.record_file: with open(self.record_file, 'a') as file: - last_two_comps = '/'.join(self.args.filename.split('/')[-2:]) - file.write(f'{last_two_comps}\t{self.N}\t{len(self.machine.inputs)}\t{len(self.machine.outputs)}\t{self.args.weak}\t{self.c}\t{sum(counts)}\t{sorted(counts, reverse=True)}\n') + last_two_comps = '/'.join(filename.split('/')[-2:]) + file.write(f'{last_two_comps}\t{self.N}\t{len(self.machine.inputs)}\t{len(self.machine.outputs)}\t{self.weak}\t{self.c}\t{sum(counts)}\t{sorted(counts, reverse=True)}\n') projections = {} for rid in self.rids: