1
Fork 0
mirror of https://git.cs.ou.nl/joshua.moerman/mealy-decompose.git synced 2025-04-29 17:57:44 +02:00
mealy-decompose/other/decompose_fsm_optimise.py
2024-06-26 09:13:56 +02:00

462 lines
16 KiB
Python

from pysat.solvers import Solver
from pysat.card import ITotalizer
from pysat.formula import IDPool
import signal
import math
import argparse
### Usage:
# Step 1: pip3 install python-sat
# Step 2: python3 decompose_fsm_optimise.py -h
# Whether result lines are appended to a log file.
keep_log = True
# Log file that receives one tab-separated result line per run; None disables logging.
if keep_log:
    record_file = './results/log.txt'
else:
    record_file = None
# Path of the input .dot file; main() fills it in so log lines can name the input.
filename = None
def main():
    """Command-line entry point.

    Parses the arguments, reads the Mealy machine from the given .dot file and
    runs the SAT-based decomposition, printing the result. With ``--timeout``
    an alarm is armed whose handler appends a TIMEOUT line to ``record_file``
    (only when logging is enabled) and exits.
    """
    import sys

    parser = argparse.ArgumentParser(description='Decomposes a FSM into smaller components by remapping its outputs. Uses a SAT solver.')
    parser.add_argument('-c', '--components', type=int, default=2, help='number of components')
    parser.add_argument('-w', '--weak', default=False, action='store_true', help='look for weak decomposition')
    parser.add_argument('--add-state-trans', default=False, action='store_true', help='adds state transitivity constraints')
    parser.add_argument('-t', '--timeout', type=int, default=None, help='timeout (in seconds)')
    parser.add_argument('filename', help='path to .dot file')
    args = parser.parse_args()

    # Remember the input path; Encoder.solve() uses it in the result log line.
    global filename
    filename = args.filename

    # Number of components. c = 1 is pointless, but should still work.
    c = args.components
    # Validate explicitly: an `assert` would be skipped under `python -O`, and
    # c < 1 would then silently produce an unsatisfiable encoding.
    if c < 1:
        parser.error('the number of components must be at least 1')

    with open(args.filename) as file:
        machine = parse_dot_file(file)

    # With only 1 state there is nothing to decompose, but the script should
    # not crash. At least 3 outputs are needed for a meaningful decomposition.
    assert len(machine.states) >= 1
    assert len(machine.inputs) >= 1
    assert len(machine.outputs) >= 1

    print(f'Input FSM: {len(machine.states)} states, {len(machine.inputs)} inputs, and {len(machine.outputs)} outputs')

    if args.timeout:
        def timeout_handler(*_):
            # record_file is None when logging is disabled (keep_log = False);
            # opening None would raise inside the signal handler.
            if record_file:
                with open(record_file, 'a') as file:
                    last_two_comps = '/'.join(args.filename.split('/')[-2:])
                    file.write(f'{last_two_comps}\t{len(machine.states)}\t{len(machine.inputs)}\t{len(machine.outputs)}\t{args.weak}\t{c}\tTIMEOUT\n')
            print('TIMEOUT')
            # sys.exit instead of the site-provided exit(), which is missing
            # when Python runs without the site module (python -S).
            sys.exit()

        # NOTE(review): Python runs signal handlers between bytecodes, so the
        # alarm presumably only takes effect once a native SAT solver call
        # returns control -- confirm with the pysat backend in use.
        if hasattr(signal, 'SIGALRM'):
            signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(args.timeout)
        else:
            # SIGALRM does not exist on every platform (e.g. Windows).
            print('Warning: --timeout is not supported on this platform (no SIGALRM); ignoring it')

    with Encoder(machine, c, args.weak, add_state_trans=args.add_state_trans, record_file=record_file) as encoder:
        encoder.solve()
###################################
# .dot file parser (heuristic)
class FSM:
    """A Mealy machine as read from a .dot file.

    Besides the initial state and the state/input/output sets, it holds two
    dictionaries keyed by ``(state, input)``: ``transition_map`` gives the
    successor state and ``output_map`` the emitted output. Nothing is
    validated here.
    """

    def __init__(self, initial_state, states, inputs, outputs, transition_map, output_map):
        self.initial_state = initial_state
        self.states, self.inputs, self.outputs = states, inputs, outputs
        self.transition_map, self.output_map = transition_map, output_map

    def __str__(self):
        parts = (self.initial_state, self.states, self.inputs, self.outputs, self.transition_map, self.output_map)
        return 'FSM({}, {}, {}, {}, {}, {})'.format(*parts)

    def transition(self, s, a):
        """Return the successor of state ``s`` on input ``a`` (KeyError if undefined)."""
        key = (s, a)
        return self.transition_map[key]

    def output(self, s, a):
        """Return the output of state ``s`` on input ``a`` (KeyError if undefined)."""
        key = (s, a)
        return self.output_map[key]
def parse_dot_file(lines):
    """Read a complete Mealy machine from the lines of a .dot file.

    The parser is heuristic: every line of the form
    ``src -> dst [label="in/out" ...]`` is taken as a transition, and all other
    lines are ignored. The source state of the first transition becomes the
    initial state.

    :param lines: iterable of text lines, e.g. an open file object.
    :returns: the parsed :class:`FSM`.
    :raises AssertionError: if no transition was found, or if some
        (state, input) pair has no transition.
    """

    def parse_transition(line):
        # Split `src -> dst [label="in/out"...]` into (src, in, out, dst).
        # A component is '' when the line is not a labelled transition.
        (lhs, _, rest) = line.partition('->')
        source = lhs.strip()
        (lhs, _, rest) = rest.partition('[label="')
        target = lhs.strip()
        # The label runs up to the next double quote. Cutting there (instead
        # of at '"]') keeps further attributes, e.g. ` color="red"`, out of
        # the output symbol.
        (label, _, _) = rest.partition('"')
        (inp, _, out) = label.partition('/')
        # Strip so that `in / out` (spaces around the slash) parses like `in/out`.
        return (source, inp.strip(), out.strip(), target)

    initial_state = None
    states, inputs, outputs = set(), set(), set()
    transition_map, output_map = {}, {}
    for line in lines:
        (s, i, o, t) = parse_transition(line)
        if s and i and o and t:
            states.add(s)
            inputs.add(i)
            outputs.add(o)
            states.add(t)
            # A later duplicate (s, i) entry overwrites an earlier one.
            transition_map[(s, i)] = t
            output_map[(s, i)] = o
            if initial_state is None:
                initial_state = s

    assert initial_state in states
    # Completeness: every (state, input) pair must have a transition.
    assert len(transition_map) == len(states) * len(inputs)
    assert len(output_map) == len(states) * len(inputs)
    return FSM(initial_state, states, inputs, outputs, transition_map, output_map)
###################################
# Utility functions
def print_table(cell, rs, cs):
    """Print a right-aligned text table to stdout.

    :param cell: callable ``cell(r, c) -> str`` giving the text at row ``r``
        and column ``c``.
    :param rs: row keys, printed (via ``str``) in the first column.
    :param cs: column keys, printed (via ``str``) in the header row.

    The first column is as wide as the widest row key. Every other column
    is one wider than the widest header or cell text. Empty ``rs`` or ``cs``
    give a table without rows or columns instead of raising.
    """
    # Materialise once: the keys are iterated several times, which would
    # silently exhaust one-shot iterators. Each cell is also computed only once.
    rows, cols = list(rs), list(cs)
    cells = [[cell(r, c) for c in cols] for r in rows]

    first_col_size = max((len(str(r)) for r in rows), default=0)
    col_size = 1 + max([len(str(c)) for c in cols] + [len(text) for row in cells for text in row], default=0)

    print(''.rjust(first_col_size), end='')
    for c in cols:
        print(str(c).rjust(col_size), end='')
    print('')
    for r, row in zip(rows, cells):
        print(str(r).rjust(first_col_size), end='')
        for text in row:
            print(text.rjust(col_size), end='')
        print('')
class Progress:
    """Minimal in-place percentage indicator for long encoding loops.

    ``guess`` is the expected number of steps in the current phase. The
    percentage is printed, followed by a carriage return, whenever its integer
    value changes.
    """

    def __init__(self, name: str, guess: int):
        self.reset(name, guess, show=False)

    def reset(self, name: str, guess: int, show: bool = True):
        """Start a new phase ``name`` that expects ``guess`` steps.

        ``guess`` may be fractional; it is rounded up. When ``show`` is true
        the phase name is printed on its own line.
        """
        self.name = name
        self.guess = math.ceil(guess)
        self.count = 0
        self.percentage = None
        if show:
            print(name)

    def add(self, n: int = 1):
        """Record ``n`` more completed steps and refresh the displayed percentage."""
        self.count += n
        # A phase with a non-positive guess (e.g. zero output pairs) would
        # otherwise divide by zero; treat it as complete.
        if self.guess > 0:
            percentage = math.floor(100 * self.count / self.guess)
        else:
            percentage = 100
        if percentage != self.percentage:
            self.percentage = percentage
            print(f'{self.percentage}%', end='', flush=True)
            print('\r', end='')
###################################
# Main logic
class Encoder:
    """SAT encoding of an output-remapping decomposition of a Mealy machine.

    For each of ``components`` components, the encoding chooses an equivalence
    relation on the outputs and a relation on the states. The state relation
    must be a bisimulation modulo the output relation, in the direction
    sx ~ sy => related outputs and related successors. Together, the output
    relations must distinguish every pair of distinct outputs. Per component,
    representative states are chosen. A totalizer over these representatives
    bounds the component size, and the bound is minimised by binary search:

    * weak (``weak=True``): every component's size is bounded separately;
    * strict: the sum of all component sizes is bounded.

    Use as a context manager: ``__enter__`` creates the solver and adds the
    whole encoding, ``__exit__`` releases the solver. Call :meth:`solve` in
    between.
    """

    def __init__(self, machine: FSM, components: int = 2, weak: bool = False, add_state_trans: bool = False, record_file: str = None, progress: Progress = None):
        self.machine = machine
        self.c = components
        self.weak = weak
        self.record_file = record_file
        self.progress = progress if progress else Progress('', 1)
        # Optionally add transitivity constraints on the state relations. They
        # are not needed for the decomposition and are cubic in the number of
        # states, so they are off by default.
        self.add_state_trans = add_state_trans
        self.N = len(self.machine.states)
        self.rids = list(range(self.c))
        self.vpool = IDPool()
        # The SAT solver is created in __enter__. Allocating one here as well
        # would only build a native solver that is immediately replaced.
        self.solver = None

    def __enter__(self):
        self.solver = Solver()
        self.solver.__enter__()
        try:
            self.encode()
            if self.weak:
                self.encode_weak_size_constraints()
            else:
                self.encode_strict_size_constraints()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises (this includes the
            # SystemExit raised by the timeout handler), so release the solver
            # here before propagating.
            self.solver.__exit__(None, None, None)
            raise
        assert hasattr(self, 'rhs')
        return self

    def __exit__(self, *args):
        return self.solver.__exit__(*args)

    def add_clause(self, cls, no_return=True):
        self.solver.add_clause(cls, no_return)

    def add_clauses(self, clss, no_return=True):
        self.solver.append_formula(clss, no_return)

    def var_const(self, b) -> int:
        """Auxiliary variable for the constant ``b`` (True/False); simplifies the other variables."""
        return self.vpool.id(('const', b))

    def var_rel_abs(self, id, x, y) -> int:
        """Variable stating that x and y are related in the relation identified by ``id``.

        Reflexivity (x == y maps to the constant True) and symmetry (the pair
        is sorted) are built in. ``id`` allows several relations to be encoded.
        """
        if x == y:
            return self.var_const(True)
        [sx, sy] = sorted([x, y])
        return self.vpool.id(('r', id, sx, sy))

    def var_rel(self, rid, o1, o2) -> int:
        """Variable: outputs o1 and o2 are related in component ``rid``."""
        return self.var_rel_abs(('output', rid), o1, o2)

    def var_state_rel(self, rid, s1, s2) -> int:
        """Variable: states s1 and s2 are related in component ``rid``.

        The output relation induces this state relation, which must also be a
        bisimulation (see :meth:`encode`).
        """
        return self.var_rel_abs(('state', rid), s1, s2)

    def var_state_rep(self, rid, s) -> int:
        """Variable: state ``s`` is a representative of its class in component ``rid``.

        The intent is to pick exactly one representative state per equivalence
        class and per component.
        """
        return self.vpool.id(('state_rep', rid, s))

    def encode(self):
        """Add all clauses except the size constraints to the solver."""
        # Local aliases for readability.
        rids, outputs, states, inputs = self.rids, list(self.machine.outputs), list(self.machine.states), list(self.machine.inputs)

        print('===============')
        print('Start encoding')

        self.add_clause([self.var_const(True)])
        self.add_clause([-self.var_const(False)])

        # Make each output relation an equivalence relation. Only transitivity
        # has to be encoded, because reflexivity and symmetry are built into
        # the variables.
        self.progress.reset('transitivity (o)', guess=len(rids) * len(outputs) ** 3)
        for rid in rids:
            for xo in outputs:
                for yo in outputs:
                    for zo in outputs:
                        # xo R yo and yo R zo imply xo R zo
                        self.add_clause([-self.var_rel(rid, xo, yo), -self.var_rel(rid, yo, zo), self.var_rel(rid, xo, zo)])
                        self.progress.add()

        if self.add_state_trans:
            self.progress.reset('transitivity (s)', guess=len(rids) * len(states) ** 3)
            for rid in rids:
                for sx in states:
                    for sy in states:
                        for sz in states:
                            # sx R sy and sy R sz imply sx R sz
                            self.add_clause([-self.var_state_rel(rid, sx, sy), -self.var_state_rel(rid, sy, sz), self.var_state_rel(rid, sx, sz)])
                            self.progress.add()

        # Together the relations must distinguish all outputs, i.e. the
        # corresponding quotients are jointly injective.
        self.progress.reset('injectivity', guess=len(outputs) * (len(outputs) - 1) / 2)
        for xi, xo in enumerate(outputs):
            for yo in outputs[xi + 1:]:
                # At least one component must separate xo and yo.
                self.add_clause([-self.var_rel(rid, xo, yo) for rid in rids])
                self.progress.add()

        # sx ~ sy => for each input: (1) outputs equivalent AND (2) successors
        # related. The converse implication is not encoded (the original
        # author noted it may not be needed).
        self.progress.reset('bisimulation modulo rel', guess=len(rids) * len(states) * len(states) * len(inputs))
        for rid in rids:
            for sx in states:
                for sy in states:
                    for i in inputs:
                        # sx ~ sy => output(sx, i) ~ output(sy, i)
                        ox = self.machine.output(sx, i)
                        oy = self.machine.output(sy, i)
                        self.add_clause([-self.var_state_rel(rid, sx, sy), self.var_rel(rid, ox, oy)])
                        # sx ~ sy => delta(sx, i) ~ delta(sy, i)
                        tx = self.machine.transition(sx, i)
                        ty = self.machine.transition(sy, i)
                        self.add_clause([-self.var_state_rel(rid, sx, sy), self.var_state_rel(rid, tx, ty)])
                        self.progress.add()

        # Constraints that make the representatives actual representatives.
        self.progress.reset('representatives', guess=len(rids) * len(states))
        for rid in rids:
            for ix, sx in enumerate(states):
                # Key constraint: a state is a representative or is related to
                # an earlier state. This forces the solver to choose
                # representatives towards the front of the list.
                self.add_clause([self.var_state_rep(rid, sx)] + [self.var_state_rel(rid, sx, sy) for sy in states[:ix]])
                for sy in states[:ix]:
                    # sx and sy cannot both be representatives unless they are
                    # unrelated.
                    self.add_clause([-self.var_state_rep(rid, sx), -self.var_state_rep(rid, sy), -self.var_state_rel(rid, sx, sy)])
                self.progress.add()

    # At this point the encoding is complete except for the size constraints,
    # which depend on whether a weak or a strict decomposition is sought.

    def encode_weak_size_constraints(self):
        """Add one incremental totalizer per component over its representatives.

        ``self.rhs[rid]`` is that totalizer's output list. The search assumes
        ``-self.rhs[rid][k]`` to bound component ``rid`` to at most ``k``
        states (pysat ITotalizer semantics).
        """
        self.rhs = []
        self.lower_bound = int(math.floor((self.N - 1) ** (1 / self.c)))
        self.upper_bound = int(self.N)
        print(f'weak size constraints {self.lower_bound} {self.upper_bound}')
        # In the weak decomposition the size of every component is minimised,
        # so each component gets its own cardinality constraint. ITotalizer is
        # used because it is incremental.
        for rid in self.rids:
            with ITotalizer([self.var_state_rep(rid, sx) for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim:
                # Reserve the totalizer's auxiliary variables in the pool.
                self.vpool.occupy(self.vpool.top + 1, cnf_optim.top_id)
                self.vpool.top = cnf_optim.top_id
                self.add_clauses(cnf_optim.cnf.clauses)
                self.rhs.append(cnf_optim.rhs)

    def encode_strict_size_constraints(self):
        """Add a single incremental totalizer over all representatives of all components.

        The search assumes ``-self.rhs[k]`` to bound the total size to at most
        ``k`` (pysat ITotalizer semantics).
        """
        self.lower_bound = int(math.floor(self.c * (self.N - 1) ** (1 / self.c)))
        self.upper_bound = int(self.N + self.c - 1)
        print(f'strict size constraints {self.lower_bound} {self.upper_bound}')
        # In the strict decomposition the sum of the components is minimised.
        # This amounts to choosing k representatives in the list
        # rids * states. ITotalizer is used because it is incremental.
        with ITotalizer([self.var_state_rep(rid, sx) for rid in self.rids for sx in self.machine.states], ubound=self.upper_bound, top_id=self.vpool.top) as cnf_optim:
            # Reserve the totalizer's auxiliary variables in the pool, as the
            # weak encoding does, so later vpool.id() calls cannot reuse them.
            self.vpool.occupy(self.vpool.top + 1, cnf_optim.top_id)
            self.vpool.top = cnf_optim.top_id
            self.add_clauses(cnf_optim.cnf.clauses)
            self.rhs = cnf_optim.rhs

    def optimise_constraints(self):
        """Binary-search the tightest size bound(s) and store them in ``self.bound``."""
        if self.weak:
            self.optimise_weak_constraints()
        else:
            self.optimise_strict_constraints()
        print(f'done searching, found bound(s) = {self.bound}')

    def optimise_weak_constraints(self):
        """Fix the component bounds one at a time (weak decomposition).

        Components are handled in ``self.rids`` order. For the current one, a
        binary search finds the smallest size that is satisfiable together with
        the bounds fixed so far, with all remaining components held to the
        same size. Each found bound is the starting upper bound for the next
        component, so ``self.bound`` (rid -> size) is non-increasing.
        """
        bounds = {}
        smallest = None
        todo = self.rids.copy()
        while todo:
            # Heuristic lower bound from N divided by the bounds already fixed.
            # It is presumably based on the component sizes having to multiply
            # to at least N -- the search assumes it is unsatisfiable.
            lower_bound = self.N
            for rid in bounds:
                lower_bound /= bounds[rid]
            lower_bound = int(math.floor((math.ceil(lower_bound) - 1) ** (1.0 / len(todo))))
            upper_bound = smallest if smallest else self.N

            while upper_bound - lower_bound >= 2:
                mid_size = int((lower_bound + upper_bound) / 2)
                print(f'W Trying {lower_bound} < {mid_size} < {upper_bound}', end='', flush=True)
                assumptions = [-self.rhs[rid][b] for (rid, b) in bounds.items() if b < self.N]
                assumptions += [-self.rhs[rid][mid_size] for rid in todo]
                if self.solver.solve(assumptions=assumptions):
                    print('\tdown')
                    upper_bound = mid_size
                else:
                    print('\tup')
                    lower_bound = mid_size

            print(f'Found bound {upper_bound} for {todo[0]}')
            bounds[todo.pop(0)] = upper_bound
            smallest = upper_bound

        self.bound = bounds
        assert len(bounds) == self.c

    def optimise_strict_constraints(self):
        """Binary-search the smallest total size (strict decomposition).

        Narrows ``self.lower_bound``..``self.upper_bound`` in place and stores
        the final upper bound in ``self.bound``. This assumes the initial
        lower bound is unsatisfiable and the initial upper bound satisfiable.
        """
        while self.upper_bound - self.lower_bound >= 2:
            self.mid_size = int((self.lower_bound + self.upper_bound) / 2)
            print(f'S Trying {self.lower_bound} < {self.mid_size} < {self.upper_bound}', end='', flush=True)
            assumptions = [-self.rhs[self.mid_size]]
            if self.solver.solve(assumptions=assumptions):
                print('\tdown')
                self.upper_bound = self.mid_size
            else:
                print('\tup')
                self.lower_bound = self.mid_size
        self.bound = self.upper_bound

    def solve(self):
        """Find the optimal bound(s), then report the decomposition.

        Prints the component sizes, appends a result line to ``record_file``
        (if set) and prints the output-to-class mapping of every component.
        """
        print('===============')
        print('Start solving')

        self.optimise_constraints()

        # Solve once more under the optimal bound(s) to obtain a matching model.
        if self.weak:
            assumptions = [-self.rhs[rid][self.bound[rid]] for rid in self.rids if self.bound[rid] < self.N]
            sat = self.solver.solve(assumptions=assumptions)
        else:
            assumptions = [-self.rhs[self.bound]] if self.bound < self.N * self.c else []
            sat = self.solver.solve(assumptions=assumptions)
        assert sat

        # Convert the model into a dict: variable id -> truth value.
        m = self.solver.get_model()
        model = {abs(l): l > 0 for l in m}

        # Exact size of each component = its number of representatives.
        counts = [sum(1 for s in self.machine.states if model[self.var_state_rep(rid, s)]) for rid in self.rids]
        print(f'Reduced sizes = {counts} = {sum(counts)}')

        if self.record_file:
            # `filename` is the module-level input path set by main(); guard
            # against it being unset when the Encoder is used on its own.
            last_two_comps = '/'.join(filename.split('/')[-2:]) if filename else '-'
            with open(self.record_file, 'a') as file:
                file.write(f'{last_two_comps}\t{self.N}\t{len(self.machine.inputs)}\t{len(self.machine.outputs)}\t{self.weak}\t{self.c}\t{sum(counts)}\t{sorted(counts, reverse=True)}\n')

        # Name the equivalence classes of every output relation. A class that
        # contains a single output gets the suffix `_u` (unique).
        projections = {}
        for rid in self.rids:
            local_outputs = self.machine.outputs.copy()
            projections[rid] = {}
            count = 0
            while local_outputs:
                rep = local_outputs.pop()
                if rep in projections[rid]:
                    continue
                projections[rid][rep] = f'cls_{rid}_{count}'
                others = False
                for o in local_outputs:
                    if model[self.var_rel(rid, o, rep)]:
                        others = True
                        projections[rid][o] = f'cls_{rid}_{count}'
                if not others:
                    # Mark the output as unique within its class.
                    projections[rid][rep] = f'cls_{rid}_{count}_u'
                count += 1

        print('===============')
        print('Output mapping:')
        print_table(lambda o, rid: projections[rid][o], self.machine.outputs, self.rids)
###################################
# Run script
# Run the command-line interface only when executed as a script, not when this
# module is imported.
if __name__ == '__main__':
    main()