Alloy Support for Killer Sudoku
[Benchmarks_CSolver.git] / killerSudoku / csolversudoku.py
1 import pycsolver as ps
2 from ctypes import *
3 import numpy as np
4 import sys
5 from itertools import combinations, ifilter, chain
6
class Solver:
    """Backend selectors accepted as ``solverOption``.

    SERIALISE makes the encoder call ``csolverlb.serialize`` before solving,
    ALLOY makes it call ``csolverlb.setAlloyEncoder``; CSOLVER triggers
    neither and just solves.
    """
    CSOLVER, SERIALISE, ALLOY = 1, 2, 3
11
def solveProblem(N, killerRules, solverOption):
    """Solve a killer sudoku; public entry point.

    All three arguments are forwarded unchanged to
    ``generateKillerSudokuConstraints`` and its result is returned as-is
    (the solved grid, or ``None`` when no solution is reported).
    """
    solution = generateKillerSudokuConstraints(N, killerRules, solverOption)
    return solution
15
def getDomain(allPossible):
    """Return every distinct value that occurs in any candidate combination.

    :param allPossible: non-empty sequence of tuples, e.g. the digit
        combinations that satisfy one cage sum. Tuples may differ in length.
    :returns: sorted list of the distinct values found in the tuples. Sorting
        gives a deterministic order when the list is turned into a solver set.
    :raises ValueError: if ``allPossible`` is empty.
    """
    # An explicit exception rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an empty input slip through.
    if not allPossible:
        raise ValueError("getDomain() requires at least one combination")
    # Flatten every tuple fully, whatever its length.
    return sorted({value for combo in allPossible for value in combo})
21
def validateElements(elems):
    """Check that every entry of a 2-D grid has been assigned.

    :param elems: iterable of rows (e.g. a list of lists or a 2-D numpy array)
        whose entries are solver element handles.
    :raises ValueError: if any entry is ``None``; the message names its
        (row, column) position.
    """
    for rowIndex, row in enumerate(elems):
        for colIndex, elem in enumerate(row):
            # Raise instead of ``assert`` so the check survives ``python -O``;
            # ``is None`` is the identity test recommended for None.
            if elem is None:
                raise ValueError(
                    "cell (%d, %d) has no element assigned" % (rowIndex, colIndex))
26
def generateSumRange(domain1, domain2, totalSum):
    """Return every pairwise sum of two domains that does not exceed a cap.

    :param domain1: iterable of integers.
    :param domain2: iterable of integers (iterated once per element of
        ``domain1``).
    :param totalSum: inclusive upper bound on the sums kept.
    :returns: set of ``d1 + d2`` for ``d1`` in domain1 and ``d2`` in domain2
        with ``d1 + d2 <= totalSum``.
    """
    # The set comprehension replaces a manual loop whose accumulator shadowed
    # the builtin ``range``.
    return {d1 + d2 for d1 in domain1 for d2 in domain2 if d1 + d2 <= totalSum}
34
def generateEqualityConstraint(csolverlb, solver, e1, e2):
    """Build the boolean ``e1 == e2`` in the solver without asserting it.

    :param csolverlb: loaded CSolver library handle (as returned by
        ``ps.loadCSolver``).
    :param solver: solver instance the predicate belongs to.
    :param e1: first element handle.
    :param e2: second element handle.
    :returns: the boolean handle produced by ``applyPredicate``. The caller
        decides whether to add it as a constraint or combine it further.
    """
    equalsOp = csolverlb.createPredicateOperator(solver, c_uint(ps.CompOp.SATC_EQUALS))
    operands = (c_void_p * 2)(e1, e2)
    return csolverlb.applyPredicate(solver, equalsOp, operands, c_uint(2))
41
def _addOr(csolverlb, solver, literals):
    """Add the disjunction of the boolean handles in ``literals`` as a constraint."""
    clause = (c_void_p * len(literals))(*literals)
    csolverlb.addConstraint(solver, csolverlb.applyLogicalOperation(
        solver, ps.LogicOps.SATC_OR, clause, c_uint(len(literals))))


def _addAllDifferent(csolverlb, solver, cells):
    """Constrain ``cells`` to pairwise-distinct values.

    For every pair (in ``itertools.combinations`` order), the constraint
    ``NOT (a == b)`` is added.

    :param csolverlb: loaded CSolver library handle.
    :param solver: solver instance receiving the constraints.
    :param cells: sequence of element handles (a list or a numpy row/column).
    """
    for first, second in combinations(cells, 2):
        equal = generateEqualityConstraint(csolverlb, solver, first, second)
        csolverlb.addConstraint(
            solver, csolverlb.applyLogicalOperationOne(solver, ps.LogicOps.SATC_NOT, equal))


def _addCageConstraints(csolverlb, solver, problem, elemConsts, N, cage):
    """Encode one killer cage ``(cageSum, cells)`` as boolean constraints.

    One indicator ``present[v]`` per value v in 1..N is tied to the cage by
    ``present[v] <-> (some cage cell == v)``. The candidate combinations are
    the sets of ``len(cells)`` distinct values summing to ``cageSum``:

    * values shared by all candidates are forced present;
    * for the remaining values, at least one candidate must have all of them
      present.

    A candidate has as many distinct values as the cage has cells, so this
    pins the cage to exactly one valid combination.

    :param csolverlb: loaded CSolver library handle.
    :param solver: solver instance receiving the constraints.
    :param problem: N x N grid of element handles, indexed ``problem[r][c]``.
    :param elemConsts: ``elemConsts[v - 1]`` is the constant element for v.
    :param N: grid size; values range over 1..N.
    :param cage: pair ``(cageSum, cells)``; each cell is a ``(row, col)`` pair.
    """
    cageSum, cells = cage[0], cage[1]
    values = range(1, N + 1)
    allPossible = [combo for combo in combinations(values, len(cells))
                   if sum(combo) == cageSum]

    if not allPossible:
        # No set of distinct values fills this cage, so the puzzle has no
        # solution. Assert ``b AND NOT b`` so solve() reports failure. The
        # generic encoding below would treat every value as "common" (all()
        # over nothing is True), which is satisfiable for cages of N or more
        # cells.
        contradiction = csolverlb.getBooleanVar(solver, c_uint(1))
        csolverlb.addConstraint(solver, contradiction)
        csolverlb.addConstraint(solver, csolverlb.applyLogicalOperationOne(
            solver, ps.LogicOps.SATC_NOT, contradiction))
        return

    common = [v for v in values if all(v in combo for combo in allPossible)]
    # Per candidate, the values not already forced by `common`.
    different = [rest for rest in ([v for v in combo if v not in common]
                                   for combo in allPossible) if rest]

    present = {}
    for v in values:
        indicator = csolverlb.getBooleanVar(solver, c_uint(1))
        present[v] = indicator
        clause = []
        for cell in cells:
            equal = generateEqualityConstraint(
                csolverlb, solver, problem[cell[0]][cell[1]], elemConsts[v - 1])
            # (cell == v) -> present[v]
            csolverlb.addConstraint(solver, csolverlb.applyLogicalOperationTwo(
                solver, ps.LogicOps.SATC_IMPLIES, equal, indicator))
            clause.append(equal)
        # present[v] -> OR(cell == v)
        clause.append(csolverlb.applyLogicalOperationOne(solver, ps.LogicOps.SATC_NOT, indicator))
        _addOr(csolverlb, solver, clause)

    for v in common:
        csolverlb.addConstraint(solver, present[v])

    # DNF -> CNF: one selector per candidate.
    # The selector is true exactly when all of the candidate's values are present.
    choices = []
    for rest in different:
        choice = csolverlb.getBooleanVar(solver, c_uint(1))
        # choice -> every value of this candidate is present
        for v in rest:
            csolverlb.addConstraint(solver, csolverlb.applyLogicalOperationTwo(
                solver, ps.LogicOps.SATC_IMPLIES, choice, present[v]))
        # (all values of this candidate present) -> choice
        clause = [csolverlb.applyLogicalOperationOne(solver, ps.LogicOps.SATC_NOT, present[v])
                  for v in rest]
        clause.append(choice)
        _addOr(csolverlb, solver, clause)
        choices.append(choice)
    if choices:
        _addOr(csolverlb, solver, choices)


def generateKillerSudokuConstraints(N, killerRules, solverOption):
    """Encode an N x N killer sudoku in CSolver, solve it and read back the grid.

    Each cell is an element variable over 1..N. Cage, row, column and block
    constraints are added, then the solver is run.

    :param N: side length of the grid; must be a perfect square (e.g. 9) so
        the grid splits into sqrt(N) x sqrt(N) blocks.
    :param killerRules: iterable of cages ``(cageSum, cells)``, where ``cells``
        is a sequence of ``(row, col)`` index pairs.
    :param solverOption: a ``Solver`` constant. ``SERIALISE`` calls
        ``serialize`` and ``ALLOY`` calls ``setAlloyEncoder`` before solving;
        any other value just solves.
    :returns: N x N list of lists holding each cell's value, or ``None`` when
        ``solve`` does not return 1.
    :raises ValueError: if ``N`` is not a perfect square.
    """
    root = int(round(N ** 0.5))
    if root * root != N:
        # The block constraints below only tile the grid when N is square.
        raise ValueError("N must be a perfect square, got %r" % (N,))

    csolverlb = ps.loadCSolver()
    solver = csolverlb.createCCSolver()

    values = list(range(1, N + 1))
    valueArray = (c_long * N)(*values)
    valueSet = csolverlb.createSet(solver, c_uint(1), valueArray, c_uint(N))
    problem = np.array([[csolverlb.getElementVar(solver, valueSet) for _ in range(N)]
                        for _ in range(N)])
    # elemConsts[v - 1] is the constant element for value v.
    elemConsts = [csolverlb.getElementConst(solver, c_uint(1), v) for v in values]

    for cage in killerRules:
        _addCageConstraints(csolverlb, solver, problem, elemConsts, N, cage)

    # Ensure no cell is left without an element.
    validateElements(problem)

    # Every cell must take some value.
    for row in problem:
        for elem in row:
            csolverlb.mustHaveValue(solver, elem)

    # problem[r][c] is row r, column c: problem[:, i] is column i.
    for i in range(N):
        _addAllDifferent(csolverlb, solver, problem[:, i])
    for i in range(N):
        _addAllDifferent(csolverlb, solver, problem[i, :])
    # Each root x root block, cells listed column by column.
    blockStarts = range(0, N, root)
    for top in blockStarts:
        for left in blockStarts:
            _addAllDifferent(csolverlb, solver,
                             [problem[top + k % root, left + k // root] for k in range(N)])

    if solverOption == Solver.SERIALISE:
        csolverlb.serialize(solver)
    if solverOption == Solver.ALLOY:
        csolverlb.setAlloyEncoder(solver)
    if csolverlb.solve(solver) != 1:
        return None
    return [[csolverlb.getElementValue(solver, elem) for elem in row] for row in problem]
246