mirror of
https://git.cs.ou.nl/joshua.moerman/mealy-decompose.git
synced 2025-04-29 17:57:44 +02:00
Incremental sat solving to find optimal solution (with fixed c)
This commit is contained in:
parent
7cca04c3df
commit
4c8b8d9f2f
1 changed files with 330 additions and 0 deletions
330
other/decompose_fsm_optimise.py
Normal file
330
other/decompose_fsm_optimise.py
Normal file
|
@ -0,0 +1,330 @@
|
|||
from pysat.solvers import Solver
|
||||
from pysat.card import ITotalizer
|
||||
from pysat.formula import IDPool
|
||||
from pysat.formula import CNF
|
||||
|
||||
import math
|
||||
import argparse
|
||||
|
||||
### Usage:
# Step 1: pip3 install python-sat
# Step 2: python3 decompose_fsm_optimise.py -h

# Command-line interface of the script.
parser = argparse.ArgumentParser(
    description="Decomposes a FSM into smaller components by remapping its outputs. Uses a SAT solver.",
)
parser.add_argument(
    '-c', '--components',
    type=int,
    default=2,
    help='number of components',
)
parser.add_argument(
    '--add-state-trans',
    action="store_true",
    default=False,
    help='adds state transitivity constraints',
)
parser.add_argument(
    '-v', '--verbose',
    action="store_true",
    default=False,
    help='prints more info',
)
parser.add_argument('filename', help='path to .dot file')
args = parser.parse_args()

# if the total_size is too low => UNSAT => takes long
c = args.components

assert c >= 1  # c = 1 is pointless, but is meant to work
|
||||
|
||||
|
||||
###################################
|
||||
# .dot file parser (heuristic)
|
||||
class FSM:
    """Plain container for a finite state machine with outputs.

    The constructor stores its arguments unchanged. ``transition_map`` and
    ``output_map`` are indexed by ``(state, input)`` pairs.
    """

    def __init__(self, initial_state, states, inputs, outputs, transition_map, output_map):
        self.initial_state = initial_state
        self.states, self.inputs, self.outputs = states, inputs, outputs
        self.transition_map, self.output_map = transition_map, output_map

    def __str__(self):
        """Return ``FSM(...)`` listing all six fields in constructor order."""
        return 'FSM({}, {}, {}, {}, {}, {})'.format(
            self.initial_state,
            self.states,
            self.inputs,
            self.outputs,
            self.transition_map,
            self.output_map,
        )

    def transition(self, s, a):
        """Return the successor of state ``s`` on input ``a`` (KeyError if absent)."""
        key = (s, a)
        return self.transition_map[key]

    def output(self, s, a):
        """Return the output produced in state ``s`` on input ``a`` (KeyError if absent)."""
        key = (s, a)
        return self.output_map[key]
|
||||
|
||||
def parse_dot_file(lines):
    """Build an FSM from the lines of a .dot file (heuristic parser).

    Every line is treated as a candidate edge of the form
    ``source -> target [label="input/output"]``. A line for which any of the
    four parts comes out empty is skipped. The source of the first accepted
    edge becomes the initial state.

    Raises AssertionError when no edge was accepted, or when the number of
    (state, input) entries differs from ``len(states) * len(inputs)``.
    """
    def split_edge(text):
        # Peel the pieces off from left to right; a missing separator simply
        # leaves the remaining pieces empty.
        source, _, rest = text.partition('->')
        target, _, rest = rest.partition('[label="')
        label, _, _ = rest.partition('"]')
        symbol_in, _, symbol_out = label.partition('/')
        return (source.strip(), symbol_in, symbol_out, target.strip())

    initial_state = None
    states, inputs, outputs = set(), set(), set()
    transition_map, output_map = {}, {}

    for line in lines:
        src, inp, out, dst = split_edge(line)
        if not (src and inp and out and dst):
            continue

        states.update((src, dst))
        inputs.add(inp)
        outputs.add(out)

        key = (src, inp)
        transition_map[key] = dst
        output_map[key] = out

        if initial_state is None:
            initial_state = src

    # Sanity checks: at least one edge was read and both maps hold one entry
    # per (state, input) combination.
    expected_entries = len(states) * len(inputs)
    assert initial_state in states
    assert len(transition_map) == expected_entries
    assert len(output_map) == expected_entries

    return FSM(initial_state, states, inputs, outputs, transition_map, output_map)
|
||||
|
||||
# Read the machine from the .dot file given on the command line.
with open(args.filename) as dot_file:
    machine = parse_dot_file(dot_file)

if args.verbose:
    print(machine)

# Number of states of the undecomposed machine.
N = len(machine.states)
print(f'Initial size: {N}')
|
||||
|
||||
|
||||
###################################
|
||||
# Utility functions
|
||||
|
||||
def print_table(cell, rs, cs):
    """Print a right-aligned text table to stdout.

    Args:
        cell: callable ``cell(r, c)`` returning the string to show at row
            ``r`` and column ``c``. It is called exactly once per cell.
        rs: iterable of row labels (converted with ``str``).
        cs: iterable of column labels (converted with ``str``).

    The first printed line holds the column labels; each following line
    starts with a row label. All data columns share one width: the longest
    column label or cell text plus one separating space. Empty ``rs`` or
    ``cs`` are accepted and yield a table without rows or columns.
    """
    # Materialise the labels so that one-shot iterators work as well.
    rs, cs = list(rs), list(cs)

    # Evaluate every cell once; the texts are needed for sizing and printing.
    headers = [str(c) for c in cs]
    rows = [(str(r), [cell(r, c) for c in cs]) for r in rs]

    # default=0 keeps max() from raising on an empty table.
    first_col_size = max((len(label) for label, _ in rows), default=0)
    widths = [len(h) for h in headers]
    widths += [len(text) for _, texts in rows for text in texts]
    col_size = 1 + max(widths, default=0)

    print(''.rjust(first_col_size) + ''.join(h.rjust(col_size) for h in headers))
    for label, texts in rows:
        print(label.rjust(first_col_size) + ''.join(t.rjust(col_size) for t in texts))
|
||||
|
||||
def print_eqrel(rel, xs):
    """Print the binary relation ``rel`` over ``xs`` as a square table.

    A cell shows 'Y' when ``rel(row, column)`` is truthy and '·' otherwise.
    """
    def mark(row, col):
        if rel(row, col):
            return 'Y'
        return '·'

    print_table(mark, xs, xs)
|
||||
|
||||
class Progress:
    """Minimal console progress indicator for one named phase at a time.

    ``guess`` is the estimated number of steps of the phase; ``add`` reports
    the completed share as a whole percentage on stdout.
    """

    def __init__(self, name, guess):
        # Same as reset(), but without announcing the phase name.
        self.reset(name, guess, show=False)

    def reset(self, name, guess, show=True):
        """Start tracking a new phase.

        Args:
            name: label of the phase; printed on its own line when ``show``
                is true.
            guess: estimated number of steps; rounded up to an integer.
            show: whether to print ``name``.
        """
        self.name = name
        self.guess = math.ceil(guess)
        self.count = 0
        self.percentage = None

        if show:
            print(name)

    def add(self, n=1):
        """Record ``n`` finished steps and refresh the display when it changes.

        The percentage is printed only when its integer value differs from
        the one printed last, followed by a carriage return so that the next
        update overwrites it.
        """
        self.count += n
        if self.guess > 0:
            percentage = math.floor(100 * self.count / self.guess)
        else:
            # An estimate of zero steps would divide by zero; treat the
            # phase as complete instead.
            percentage = 100

        if percentage != self.percentage:
            self.percentage = percentage
            print(f'{self.percentage}%', end='', flush=True)
            print('\r', end='')
|
||||
|
||||
progress = Progress('', 1)

########################
# Encoding into logic
print('Start encoding')
os = list(machine.outputs)  # outputs
rids = list(range(c))  # components
vpool = IDPool()
cnf = CNF()
|
||||
|
||||
# A helper variable for False and one for True; this keeps the other
# variables simpler.
def var_const(b):
    """Return the variable id registered in the pool for the constant ``b``."""
    key = ('const', b)
    return vpool.id(key)
|
||||
|
||||
# Unit clauses: the True variable must hold, the False variable must not.
for literal in (var_const(True), -var_const(False)):
    cnf.append([literal])
|
||||
|
||||
# For every relation and every two elements o1 and o2 there is a variable
# that tells whether o1 and o2 are related. xRy and yRx share one variable, so
# symmetry is built in. Reflexivity is built in as well.
def var_rel(rid, o1, o2):
    """Return the variable for "o1 and o2 are related in relation rid"."""
    if o1 == o2:
        return var_const(True)

    # Order the pair so that both argument orders map to the same key.
    low, high = (o2, o1) if o2 < o1 else (o1, o2)
    return vpool.id(('rel', rid, low, high))
|
||||
|
||||
# The relation on outputs induces a relation on states. That relation must
# also be a bisimulation.
def var_state_rel(rid, s1, s2):
    """Return the variable for "states s1 and s2 are related in relation rid"."""
    if s1 == s2:
        return var_const(True)

    # Order the pair so that both argument orders map to the same key.
    low, high = (s2, s1) if s2 < s1 else (s1, s2)
    return vpool.id(('state_rel', rid, low, high))
|
||||
|
||||
# For every relation and every equivalence class we pick exactly one state as
# representative. This variable tells which element that is.
def var_state_rep(rid, s):
    """Return the variable for "state s is a representative in relation rid"."""
    key = ('state_rep', rid, s)
    return vpool.id(key)
|
||||
|
||||
# Constraints that make each relation an equivalence relation. Only
# transitivity has to be encoded, because reflexivity and symmetry are built
# into the variables.
progress.reset('transitivity (o)', guess=len(rids) * len(os) ** 3)
for rid in rids:
    for xo in os:
        for yo in os:
            # Does not depend on zo, so look it up once per (xo, yo).
            x_rel_y = var_rel(rid, xo, yo)
            for zo in os:
                # if xo R yo and yo R zo then xo R zo
                y_rel_z = var_rel(rid, yo, zo)
                x_rel_z = var_rel(rid, xo, zo)
                cnf.append([-x_rel_y, -y_rel_z, x_rel_z])
                progress.add()
|
||||
|
||||
# Optional: transitivity of the state relations as well.
if args.add_state_trans:
    progress.reset('transitivity (s)', guess=len(rids) * len(machine.states) ** 3)
    for rid in rids:
        for sx in machine.states:
            for sy in machine.states:
                # Does not depend on sz, so look it up once per (sx, sy).
                x_rel_y = var_state_rel(rid, sx, sy)
                for sz in machine.states:
                    # if sx R sy and sy R sz then sx R sz
                    y_rel_z = var_state_rel(rid, sy, sz)
                    x_rel_z = var_state_rel(rid, sx, sz)
                    cnf.append([-x_rel_y, -y_rel_z, x_rel_z])
                    progress.add()
|
||||
|
||||
# Constraint ensuring that the relations together can tell all elements apart.
# (Aka: the corresponding quotients are jointly injective.)
progress.reset('injectivity', guess=len(os) * (len(os) - 1) / 2)
for xi, xo in enumerate(os, start=1):
    # Only pairs with yo after xo in the list: each unordered pair once.
    for yo in os[xi:]:
        # At least one rid has to separate the two outputs
        separating = [-var_rel(rid, xo, yo) for rid in rids]
        cnf.append(separating)
        progress.add()
|
||||
|
||||
# sx ~ sy => for each input: (1) outputs equivalent AND (2) successors related
# For now the converse implication is not encoded; it may not be needed.
progress.reset('bisimulation modulo rel', guess=len(rids) * len(machine.states) ** 2 * len(machine.inputs))
for rid in rids:
    for sx in machine.states:
        for sy in machine.states:
            for i in machine.inputs:
                not_related = -var_state_rel(rid, sx, sy)

                # sx ~ sy => output(sx, i) ~ output(sy, i)
                outputs_related = var_rel(rid, machine.output(sx, i), machine.output(sy, i))
                cnf.append([not_related, outputs_related])

                # sx ~ sy => delta(sx, i) ~ delta(sy, i)
                successors_related = var_state_rel(rid, machine.transition(sx, i), machine.transition(sy, i))
                cnf.append([not_related, successors_related])

                progress.add()
|
||||
|
||||
# The constraints that make sure representatives really are representatives.
states = list(machine.states)
progress.reset('representatives', guess=len(rids) * len(states))
for rid in rids:
    for ix, sx in enumerate(states):
        earlier = states[:ix]
        rep_x = var_state_rep(rid, sx)

        # Most important: an element is a representative, or it is equivalent
        # to an earlier element. This forces the solver to pick
        # representatives (towards the front of the list).
        cnf.append([rep_x] + [var_state_rel(rid, sx, sy) for sy in earlier])

        for sy in earlier:
            # sx and sy cannot both be representatives, unless they are not
            # related.
            cnf.append([-rep_x, -var_state_rep(rid, sy), -var_state_rel(rid, sx, sy)])

        progress.add()
|
||||
|
||||
# Finally we want few representatives. An incremental totalizer over all
# representative variables is built for bounds up to upper_bound; its clauses
# are added to the formula and its outputs are used as assumptions later on.
lower_bound = math.floor(c * (N - 1) ** (1 / c))
upper_bound = N + c - 1
print(f'size constraints {lower_bound} {upper_bound}')
rep_vars = [var_state_rep(rid, sx) for rid in rids for sx in machine.states]
cnf_optim = ITotalizer(lits=rep_vars, ubound=upper_bound, top_id=vpool.top)
cnf.extend(cnf_optim.cnf.clauses)
|
||||
|
||||
##################################
# Solve the problem with the solver :-)
print('Start solving')
print('- copying formula')
with Solver(bootstrap_with=cnf) as solver:
    print('===============')
    # Binary search on the size: each probe passes the negated totalizer
    # output for mid_size as an assumption. A satisfiable probe lowers the
    # upper bound, an unsatisfiable one raises the lower bound.
    while upper_bound - lower_bound >= 2:
        mid_size = (lower_bound + upper_bound) // 2
        print(f'Trying {lower_bound} < {mid_size} < {upper_bound}')
        sat = solver.solve(assumptions=[-cnf_optim.rhs[mid_size]])
        if sat:
            upper_bound = mid_size
        else:
            lower_bound = mid_size

    total_size = upper_bound
    print(f'done searching, found size = {total_size}')
    # Solve once more at the final size so the model belongs to that bound.
    sat = solver.solve(assumptions=[-cnf_optim.rhs[total_size]])
    assert sat

    # Convert into an easier data structure: variable id -> truth value
    print('- get model')
    m = solver.get_model()
    model = {abs(l): l > 0 for l in m}
|
||||
|
||||
if args.verbose:
    # First all output relations, then all state relations.
    for rid in rids:
        print(f'Relation {rid}:')
        print_eqrel(lambda x, y, rid=rid: model[var_rel(rid, x, y)], os)

    for rid in rids:
        print(f'State relation {rid}:')
        print_eqrel(lambda x, y, rid=rid: model[var_state_rel(rid, x, y)], machine.states)

# print equivalence classes
count = 0
for rid in rids:
    if args.verbose:
        print(f'component {rid}')
    # First we collect the representatives
    for s in machine.states:
        if not model[var_state_rep(rid, s)]:
            continue
        count += 1
        if args.verbose:
            print(f'- representative state {s}')

# count has to be equal to the cost (or smaller)
print(f'Reduced size = {count}')
|
||||
|
||||
# For every component, map each output either to the label of its class or,
# when it is alone in its class, to itself.
projections = {}
for rid in rids:
    mapping = projections[rid] = {}
    remaining = machine.outputs.copy()
    class_index = 0

    while remaining:
        representative = remaining.pop()
        if representative in mapping:
            # Already labelled as a member of an earlier pick's class.
            continue

        label = f'cls_{rid}_{class_index}'
        members = [o for o in remaining if model[var_rel(rid, o, representative)]]
        if members:
            mapping[representative] = label
            for member in members:
                mapping[member] = label
        else:
            mapping[representative] = f'{representative}'

        # The index advances for singleton classes too.
        class_index += 1

print('===============')
print('Output mapping:')
print_table(lambda o, rid: projections[rid][o], machine.outputs, rids)
|
Loading…
Add table
Reference in a new issue