1
Fork 0
mirror of https://git.cs.ou.nl/joshua.moerman/mealy-decompose.git synced 2025-04-30 02:07:44 +02:00

Reorganised the Python script; it is slightly more efficient now

This commit is contained in:
Joshua Moerman 2024-06-11 09:59:13 +02:00
parent a067d158f4
commit 2bd081ba37

View file

@ -11,22 +11,44 @@ import argparse
# Stap 1: pip3 install python-sat # Stap 1: pip3 install python-sat
# Stap 2: python3 decompose_fsm.py -h # Stap 2: python3 decompose_fsm.py -h
timeout = False keep_log = True
timeout_seconds = 3*60
record_sizes = False
record_file = './results/log.txt'
def main():
    """Parse the command line, load the FSM and run the decomposition.

    Command-line options:
      -c/--components   number of components (default 2, must be >= 1)
      -w/--weak         look for a weak decomposition
      --add-state-trans add state transitivity constraints
      -v/--verbose      print more info
      -t/--timeout      timeout in seconds (default: no timeout)
      filename          path to the .dot file describing the FSM

    When a timeout is given, a SIGALRM handler is installed that prints
    'TIMEOUT', appends a TIMEOUT row to the record file (only when logging
    is enabled via the module-level `keep_log` flag) and exits.
    """
    # Local import so this function does not depend on the file's import block.
    import sys

    parser = argparse.ArgumentParser(description="Decomposes a FSM into smaller components by remapping its outputs. Uses a SAT solver.")
    parser.add_argument('-c', '--components', type=int, default=2, help='number of components')
    parser.add_argument('-w', '--weak', default=False, action="store_true", help='look for weak decomposition')
    parser.add_argument('--add-state-trans', default=False, action="store_true", help='adds state transitivity constraints')
    parser.add_argument('-v', '--verbose', default=False, action="store_true", help='prints more info')
    parser.add_argument('-t', '--timeout', type=int, default=None, help='timeout (in seconds)')
    parser.add_argument('filename', help='path to .dot file')
    args = parser.parse_args()

    # Number of components. c = 1 is pointless, but should work.
    # Validated with parser.error instead of `assert`, because asserts are
    # stripped when Python runs with -O.
    c = args.components
    if c < 1:
        parser.error('number of components must be at least 1')

    record_file = './results/log.txt' if keep_log else None

    with open(args.filename) as file:
        machine = parse_dot_file(file)

    if args.verbose:
        print(machine)

    print(f'Input FSM: {len(machine.states)} states, {len(machine.inputs)} inputs, and {len(machine.outputs)} outputs')

    if args.timeout is not None:
        def timeout_handler(signum, frame):
            # Only write the log row when logging is enabled; opening a
            # `None` path would raise TypeError inside the signal handler.
            if record_file is not None:
                with open(record_file, 'a') as log:
                    last_two_comps = '/'.join(args.filename.split('/')[-2:])
                    log.write(f'{last_two_comps}\t{len(machine.states)}\t{len(machine.inputs)}\t{len(machine.outputs)}\t{args.weak}\t{c}\tTIMEOUT\n')
            print('TIMEOUT')
            sys.exit()

        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(args.timeout)

    encoder = Encoder(machine, args, record_file=record_file)
    encoder.solve()
################################### ###################################
@ -86,28 +108,9 @@ def parse_dot_file(lines):
return FSM(initial_state, states, inputs, outputs, transition_map, output_map) return FSM(initial_state, states, inputs, outputs, transition_map, output_map)
with open(args.filename) as file:
machine = parse_dot_file(file)
if args.verbose:
print(machine)
N = len(machine.states)
print(f'Input FSM: {N} states, {len(machine.inputs)} inputs, and {len(machine.outputs)} outputs')
if timeout:
def timeout_handler(signum, frame):
with open(record_file, 'a') as file:
last_two_comps = '/'.join(args.filename.split('/')[-2:])
file.write(f'{last_two_comps}\t{N}\t{len(machine.inputs)}\t{len(machine.outputs)}\t{args.weak}\t{c}\tTIMEOUT\n')
print('TIMEOUT')
exit()
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout_seconds)
################################### ###################################
# Utility functions # Utility functions
def print_table(cell, rs, cs): def print_table(cell, rs, cs):
first_col_size = max([len(str(r)) for r in rs]) first_col_size = max([len(str(r)) for r in rs])
col_size = 1 + max([len(str(c)) for c in cs] + [len(cell(r, c)) for c in cs for r in rs]) col_size = 1 + max([len(str(c)) for c in cs] + [len(cell(r, c)) for c in cs for r in rs])
@ -126,6 +129,7 @@ def print_table(cell, rs, cs):
def print_eqrel(rel, xs):
    """Print the binary relation `rel` over `xs` as a square table.

    Each cell shows 'Y' when `rel(row, col)` is truthy and '·' otherwise.
    """
    def cell(row, col):
        if rel(row, col):
            return 'Y'
        return '·'

    print_table(cell, xs, xs)
class Progress: class Progress:
def __init__(self, name, guess): def __init__(self, name, guess):
self.reset(name, guess, show=False) self.reset(name, guess, show=False)
@ -148,206 +152,221 @@ class Progress:
print(f'{self.percentage}%', end='', flush=True) print(f'{self.percentage}%', end='', flush=True)
print('\r', end='') print('\r', end='')
progress = Progress('', 1)
######################## ###################################
# Encodering naar logica # Main logic
print('Start encoding') class Encoder:
def __init__(self, machine, args, record_file=None, progress=None):
    """Store the problem data, create the SAT solver state and encode the problem.

    Parameters:
      machine      the FSM to decompose (its `states` and `outputs` are read here)
      args         parsed command-line arguments (`components` is read here)
      record_file  path of the log file to append results to, or None for no log
      progress     progress reporter to use; a fresh `Progress('', 1)` is
                   created when None

    The constructor ends by calling `self.encode()`, so all clauses are
    added to the solver before it returns.
    """
    self.machine = machine
    self.args = args
    self.record_file = record_file
    # Compare against None explicitly: a caller-supplied progress object
    # must be used even if it happens to evaluate as falsy.
    self.progress = Progress('', 1) if progress is None else progress

    self.c = self.args.components
    self.N = len(self.machine.states)

    self.os = list(machine.outputs)  # outputs
    self.rids = list(range(self.c))  # components

    self.vpool = IDPool()
    self.solver = Solver()

    self.encode()
cnf.append([-var_const(False)])
def add_clause(self, cls, no_return=True):
    """Forward a single clause `cls` to the underlying solver."""
    solver = self.solver
    solver.add_clause(cls, no_return=no_return)
# symmetrie is al ingebouwd. Reflexiviteit is ook ingebouwd.
def add_clauses(self, clss, no_return=True):
    """Forward a collection of clauses `clss` to the underlying solver."""
    solver = self.solver
    solver.append_formula(clss, no_return=no_return)
def var_const(self, b):
    """Return the variable id that stands for the constant `b`.

    A helper variable for False and True; it keeps the other variable
    helpers simpler.
    """
    key = ('const', b)
    return self.vpool.id(key)
def var_rel(self, rid, o1, o2):
    """Return the variable stating that outputs `o1` and `o2` are related in relation `rid`.

    For every relation and every two elements there is one variable. The
    same variable is used for xRy and yRx, so symmetry is built in;
    reflexivity is built in as well, because equal elements map to the
    constant-True variable.
    """
    if o1 == o2:
        return self.var_const(True)

    # Put the pair in a canonical order so (o1, o2) and (o2, o1) share an id.
    if o2 < o1:
        o1, o2 = o2, o1
    return self.vpool.id(('rel', rid, o1, o2))
def var_state_rel(self, rid, s1, s2):
    """Return the variable stating that states `s1` and `s2` are related in relation `rid`.

    The relation on outputs induces a relation on states; that relation
    must also be a bisimulation. As with the output relation, one variable
    serves both orientations of a pair, and equal states map to the
    constant-True variable.
    """
    if s1 == s2:
        return self.var_const(True)

    # Put the pair in a canonical order so (s1, s2) and (s2, s1) share an id.
    if s2 < s1:
        s1, s2 = s2, s1
    return self.vpool.id(('state_rel', rid, s1, s2))
def var_state_rep(self, rid, s):
    """Return the variable stating that state `s` is a representative in relation `rid`.

    For every relation and every equivalence class exactly one state is
    chosen as representative; this variable marks which one.
    """
    key = ('state_rep', rid, s)
    return self.vpool.id(key)
def encode(self):
    """Add all constraints of the decomposition problem to the solver.

    In order, this adds:
      1. unit clauses fixing the constant True/False helper variables;
      2. transitivity of every output relation (reflexivity and symmetry
         are built into `var_rel`), and optionally transitivity of the
         state relations when `args.add_state_trans` is set;
      3. joint injectivity: every pair of distinct outputs is separated by
         at least one relation;
      4. the bisimulation condition linking state relations to output
         relations and to successor states;
      5. representative constraints for the state relations;
      6. a totalizer over the representative variables, whose output
         literals are stored in `self.rhs` for later use as assumptions.

    Also sets `self.lower_bound` and `self.upper_bound`, the initial
    bounds for the size search.
    """
    print('===============')
    print('Start encoding')
    self.add_clause([self.var_const(True)])
    self.add_clause([-self.var_const(False)])

    # Constraints making each relation an equivalence relation. Only
    # transitivity has to be encoded, because reflexivity and symmetry are
    # built into the variables.
    self.progress.reset('transitivity (o)', guess=len(self.rids) * len(self.os) ** 3)
    for rid in self.rids:
        for xo in self.os:
            for yo in self.os:
                for zo in self.os:
                    # if xo R yo and yo R zo then xo R zo
                    self.add_clause([-self.var_rel(rid, xo, yo), -self.var_rel(rid, yo, zo), self.var_rel(rid, xo, zo)])
                    self.progress.add()

    if self.args.add_state_trans:
        self.progress.reset('transitivity (s)', guess=len(self.rids) * len(self.machine.states) ** 3)
        for rid in self.rids:
            for sx in self.machine.states:
                for sy in self.machine.states:
                    for sz in self.machine.states:
                        # if sx R sy and sy R sz then sx R sz
                        self.add_clause([-self.var_state_rel(rid, sx, sy), -self.var_state_rel(rid, sy, sz), self.var_state_rel(rid, sx, sz)])
                        self.progress.add()

    # Constraint so that the relations together can distinguish all
    # elements. (Aka: the corresponding quotients are joint-injective.)
    self.progress.reset('injectivity', guess=len(self.os) * (len(self.os) - 1) / 2)
    for xi, xo in enumerate(self.os):
        for yo in self.os[xi+1:]:
            # At least one rid has to make a difference
            self.add_clause([-self.var_rel(rid, xo, yo) for rid in self.rids])
            self.progress.add()

    # sx ~ sy => for each input: (1) outputs equivalent AND (2) successors related
    # Currently the inverse implication is not encoded; it may not be needed.
    self.progress.reset('bisimulation modulo rel', guess=len(self.rids) * len(self.machine.states) * len(self.machine.states) * len(self.machine.inputs))
    for rid in self.rids:
        for sx in self.machine.states:
            for sy in self.machine.states:
                for i in self.machine.inputs:
                    # sx ~ sy => output(sx, i) ~ output(sy, i)
                    ox = self.machine.output(sx, i)
                    oy = self.machine.output(sy, i)
                    self.add_clause([-self.var_state_rel(rid, sx, sy), self.var_rel(rid, ox, oy)])

                    # sx ~ sy => delta(sx, i) ~ delta(sy, i)
                    tx = self.machine.transition(sx, i)
                    ty = self.machine.transition(sy, i)
                    self.add_clause([-self.var_state_rel(rid, sx, sy), self.var_state_rel(rid, tx, ty)])
                    self.progress.add()

    # The constraints that make sure representatives really are representatives.
    states = list(self.machine.states)
    self.progress.reset('representatives', guess=len(self.rids) * len(states))
    for rid in self.rids:
        for ix, sx in enumerate(states):
            # Most important: an element is a representative, or it is
            # equivalent to an earlier element. This forces the solver to
            # pick representatives (towards the front of the list).
            self.add_clause([self.var_state_rep(rid, sx)] + [self.var_state_rel(rid, sx, sy) for sy in states[:ix]])
            for sy in states[:ix]:
                # sx and sy cannot both be representatives, unless they
                # are not related.
                self.add_clause([-self.var_state_rep(rid, sx), -self.var_state_rep(rid, sy), -self.var_state_rel(rid, sx, sy)])
            self.progress.add()

    # Finally we want few representatives. This is done with an "atmost"
    # formula; the search over the bound uses incremental SAT solving.
    self.rhs = None
    if self.args.weak:
        self.lower_bound = int(math.floor((self.N-1)**(1/self.c)))
        self.upper_bound = int(self.N)
        print(f'weak size constraints {self.lower_bound} {self.upper_bound}')

        # One totalizer per component; self.rhs becomes a list indexed by rid.
        self.rhs = []
        for rid in self.rids:
            with ITotalizer([self.var_state_rep(rid, sx) for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim:
                # Reserve the totalizer's auxiliary variables so the next
                # totalizer (and any later vpool.id call) does not reuse them.
                self.vpool.occupy(self.vpool.top + 1, cnf_optim.top_id)
                self.vpool.top = cnf_optim.top_id
                self.add_clauses(cnf_optim.cnf.clauses)
                self.rhs.append(cnf_optim.rhs)
    else:
        self.lower_bound = int(math.floor(self.c * (self.N-1)**(1/self.c)))
        self.upper_bound = int(self.N + self.c - 1)
        print(f'size constraints {self.lower_bound} {self.upper_bound}')

        # A single totalizer over the representatives of all components.
        with ITotalizer([self.var_state_rep(rid, sx) for rid in self.rids for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim:
            # Reserve the auxiliary variables here too, consistent with the
            # weak branch, so a later vpool.id call cannot collide with them.
            self.vpool.occupy(self.vpool.top + 1, cnf_optim.top_id)
            self.vpool.top = cnf_optim.top_id
            self.add_clauses(cnf_optim.cnf.clauses)
            self.rhs = cnf_optim.rhs
def solve(self):
print('===============')
print('Start solving')
sat = None
while self.upper_bound - self.lower_bound >= 2:
self.mid_size = int((self.lower_bound + self.upper_bound) / 2)
print(f'Trying {self.lower_bound} < {self.mid_size} < {self.upper_bound}', end='', flush=True)
if self.args.weak:
assumptions = [-self.rhs[rid][self.mid_size] for rid in self.rids]
sat = self.solver.solve(assumptions=assumptions)
else:
sat = self.solver.solve(assumptions=[-self.rhs[self.mid_size]])
if sat: if sat:
print('\tdown') print('\tdown')
upper_bound = mid_size self.upper_bound = self.mid_size
continue continue
else: else:
print('\tup') print('\tup')
lower_bound = mid_size self.lower_bound = self.mid_size
continue continue
bound = upper_bound self.bound = self.upper_bound
print(f'done searching, found bound = {bound}') print(f'done searching, found bound = {self.bound}')
if args.weak: if self.args.weak:
assumptions = [-rhs[rid][bound] for rid in rids] assumptions = [-self.rhs[rid][self.bound] for rid in self.rids]
sat = solver.solve(assumptions=assumptions) sat = self.solver.solve(assumptions=assumptions)
else: else:
sat = solver.solve(assumptions=[-rhs[bound]]) sat = self.solver.solve(assumptions=[-self.rhs[self.bound]])
assert sat assert sat
# Even omzetten in een makkelijkere data structuur # Even omzetten in een makkelijkere data structuur
m = solver.get_model() m = self.solver.get_model()
model = {} model = {}
for l in m: for l in m:
if l < 0: model[-l] = False if l < 0: model[-l] = False
else: model[l] = True else: model[l] = True
if args.verbose: if self.args.verbose:
for rid in rids: for rid in self.rids:
print(f'Relation {rid}:') print(f'Relation {rid}:')
print_eqrel(lambda x, y: model[var_rel(rid, x, y)], os) print_eqrel(lambda x, y: model[self.var_rel(rid, x, y)], self.os)
for rid in rids: for rid in self.rids:
print(f'State relation {rid}:') print(f'State relation {rid}:')
print_eqrel(lambda x, y: model[var_state_rel(rid, x, y)], machine.states) print_eqrel(lambda x, y: model[self.var_state_rel(rid, x, y)], self.machine.states)
# Precieze groottes van elk component tellen # Precieze groottes van elk component tellen
counts = [] counts = []
for rid in rids: for rid in self.rids:
count = 0 count = 0
# Eerst verzamelen we de representanten # Eerst verzamelen we de representanten
for s in machine.states: for s in self.machine.states:
if model[var_state_rep(rid, s)]: if model[self.var_state_rep(rid, s)]:
count += 1 count += 1
if args.verbose: if self.args.verbose:
print(f'comp {rid} -> representative state {s}') print(f'comp {rid} -> representative state {s}')
counts.append(count) counts.append(count)
print(f'Reduced sizes = {counts} = {sum(counts)}') print(f'Reduced sizes = {counts} = {sum(counts)}')
if record_sizes: if self.record_file != None:
with open(record_file, 'a') as file: with open(self.record_file, 'a') as file:
last_two_comps = '/'.join(args.filename.split('/')[-2:]) last_two_comps = '/'.join(self.args.filename.split('/')[-2:])
file.write(f'{last_two_comps}\t{N}\t{len(machine.inputs)}\t{len(machine.outputs)}\t{args.weak}\t{c}\t{sum(counts)}\t{sorted(counts, reverse=True)}\n') file.write(f'{last_two_comps}\t{self.N}\t{len(self.machine.inputs)}\t{len(self.machine.outputs)}\t{self.args.weak}\t{self.c}\t{sum(counts)}\t{sorted(counts, reverse=True)}\n')
projections = {} projections = {}
for rid in rids: for rid in self.rids:
local_outputs = machine.outputs.copy() local_outputs = self.machine.outputs.copy()
projections[rid] = {} projections[rid] = {}
count = 0 count = 0
@ -360,7 +379,7 @@ with Solver(bootstrap_with=cnf) as solver:
others = False others = False
for o in local_outputs: for o in local_outputs:
if model[var_rel(rid, o, repr)]: if model[self.var_rel(rid, o, repr)]:
others = True others = True
projections[rid][o] = f'cls_{rid}_{count}' projections[rid][o] = f'cls_{rid}_{count}'
@ -372,4 +391,10 @@ with Solver(bootstrap_with=cnf) as solver:
print('===============') print('===============')
print('Output mapping:') print('Output mapping:')
print_table(lambda o, rid: projections[rid][o], machine.outputs, rids) print_table(lambda o, rid: projections[rid][o], self.machine.outputs, self.rids)
###################################
# Run script
if __name__ == "__main__":
main()