Making benchmarks as similar as possible to the original encoding
[Benchmarks_CSolver.git] / killerSudoku / csolversudoku.py
1 import pycsolver as ps
2 from ctypes import *
3 import numpy as np
4 import sys
5 from itertools import combinations, ifilter, chain
6
def solveProblem(N, killerRules, serialize=False):
        """Solve a Killer Sudoku instance through the CSolver encoding.

        This is a thin entry point: all three arguments are forwarded
        unchanged to generateKillerSudokuConstraints and its result is
        returned as-is.

        Args:
            N: grid side length, passed through.
            killerRules: cage descriptions, passed through.
            serialize: forwarded flag (defaults to False).

        Returns:
            Whatever generateKillerSudokuConstraints returns.
        """
        solution = generateKillerSudokuConstraints(N, killerRules, serialize)
        return solution
10
def getDomain(allPossible):
        """Return the distinct values occurring in a list of combinations.

        Args:
            allPossible: non-empty sequence of indexable combinations.  Only
                the first len(allPossible[0]) positions of every combination
                are read.

        Returns:
            A list holding each value once (order is that of the underlying
            set, i.e. not sorted).

        Raises:
            AssertionError: if allPossible is empty.
        """
        assert len(allPossible) > 0
        # Width is taken from the first combination and applied to all of them.
        width = len(allPossible[0])
        seen = set()
        for combo in allPossible:
                for pos in range(width):
                        seen.add(combo[pos])
        return list(seen)
16
def validateElements(elems):
        """Check that every cell of a 2-D grid has been assigned (is not None).

        Args:
            elems: iterable of rows, each row an iterable of cell entries.

        Returns:
            None when every entry is set.

        Raises:
            AssertionError: at the first entry that is None; the message
                carries the (row, column) position of that entry.
        """
        for rowIdx, row in enumerate(elems):
                for colIdx, elem in enumerate(row):
                        # Identity test: an entry counts as missing only when it
                        # is the None singleton, regardless of how the stored
                        # object implements ==/!=.
                        if elem is None:
                                # Raised explicitly (not via `assert`) so the
                                # check is still performed under `python -O`.
                                raise AssertionError(
                                        "cell (%d, %d) has no element assigned"
                                        % (rowIdx, colIdx))
21
def generateSumRange(domain1, domain2, totalSum):
        """Return every pairwise sum of the two domains that is <= totalSum.

        Args:
            domain1: iterable of values for the left operand.
            domain2: iterable of values for the right operand.
            totalSum: inclusive upper bound on the sums that are kept.

        Returns:
            A set with each d1 + d2 (d1 from domain1, d2 from domain2) that
            does not exceed totalSum; empty when no pair qualifies.
        """
        return {
                left + right
                for left in domain1
                for right in domain2
                if (left + right) <= totalSum
        }
29
def generateKillerSudokuConstraints(N, killerRules, serialize):
        """Encode a Killer Sudoku instance in CSolver, solve it, return the grid.

        Args:
            N: grid side length.  Constants 1..N are created as the candidate
                cell values, and int(N**0.5) is used as the block side length
                (so N is presumably a perfect square -- not checked here).
            killerRules: iterable of cages.  Each cage is indexed as cage[0]
                (the target sum) and cage[1] (a sequence of cell coordinates,
                each read as coord[0], coord[1] into the grid array).
            serialize: when truthy, csolverlb.serialize(solver) is called
                before solving.

        Returns:
            None if csolverlb.solve(solver) returns anything other than 1;
            otherwise an N x N list of lists holding
            csolverlb.getElementValue(solver, cell) for every cell.

        Raises:
            AssertionError: from validateElements when some cell is not
                covered by any cage, or from getDomain when no combination of
                distinct values sums to a cage's target.

        NOTE(review): the statement order and the seemingly unused solver
        calls are left untouched; the meaning of the numeric c_uint(...) tags
        passed to pycsolver is defined by that library and not visible here.
        """
        csolverlb = ps.loadCSolver()
        solver = csolverlb.createCCSolver()
        # N x N grid of solver element handles; every entry starts as None and
        # is filled in while the cages are processed below.
        problem = np.array([[None for i in range(N)] for i in range(N)])
        # One solver constant per candidate value 1..N.
        elemConsts = [csolverlb.getElementConst(solver, c_uint(1), i) for i in range(1, N+1)]


        def valid(cells):
                """Add a NOT(ei == ej) constraint for every pair i < j of cells."""
                for i, ei in enumerate(cells):
                        for j, ej in enumerate(cells):
                                if i < j:
                                        # si, sj, d and domain are computed but not
                                        # used by the calls that follow.
                                        si = csolverlb.getElementRange(solver, ei)
                                        sj = csolverlb.getElementRange(solver,ej)
                                        d = [si,sj]
                                        domain = (c_void_p *len(d))(*d)
                                        equals = csolverlb.createPredicateOperator(solver, c_uint(ps.CompOp.SATC_EQUALS))
                                        inp = [ei,ej]
                                        # ctypes array of the two element handles.
                                        inputs = (c_void_p*len(inp))(*inp)
                                        b = csolverlb.applyPredicate(solver,equals, inputs, c_uint(2))
                                        # Negate the equality and assert it.
                                        b = csolverlb.applyLogicalOperationOne(solver, ps.LogicOps.SATC_NOT, b)
                                        csolverlb.addConstraint(solver,b);

        def getElement(cage):
                """Create the element variables for one multi-cell cage.

                Returns a tuple (elems, d1): the element handles of the cage's
                cells (also stored into `problem`) and the list of values that
                appear in at least one valid combination for the cage.
                """
                cageSum = cage[0]
                cageSize = len(cage[1])
                cageCells = cage[1]
                # All cageSize-sized combinations of distinct values from 1..N ...
                cb = combinations([ii for ii in range(1, N+1)], cageSize)
                f = lambda x : sum(x) == cageSum
                # ... restricted to those whose sum equals the cage target.
                # NOTE: ifilter is the Python 2 itertools name.
                comb = ifilter(f, cb) # all valid combinations
                allPossible = list(chain(comb))
                # Union of the values used by any valid combination.
                d1 = getDomain(allPossible)
                set1 = (c_long* len(d1))(*d1)
                s1 = csolverlb.createSet(solver, c_uint(1), set1, c_uint(len(d1)))
                # Every cell of the cage gets its own element variable over s1.
                for i in range(len(cage[1])):
                        problem[cage[1][i][0]][cage[1][i][1]] = csolverlb.getElementVar(solver, s1);
                elems = [ problem[cage[1][i][0]][cage[1][i][1]] for i in range(len(cage[1])) ]
                # Cells within one cage must be pairwise different.
                valid(elems)
                return elems, d1

        def generateSumConstraint(sumCage, elements, domain):
                """Constrain the chained sum of `elements` to equal sumCage.

                The sum is built left to right: each step applies an ADD
                function to the next element and the running partial sum.  All
                steps share one overflow boolean, which is constrained to be
                false at the end.
                """
                assert len(elements) >1
                # Possible values of the running partial sum; starts as the
                # per-cell domain and is narrowed by generateSumRange each step.
                parDomain = domain
                parElem = elements[0]
                overflow = csolverlb.getBooleanVar(solver, c_uint(2));
                for i in range(len(elements)):
                        # Skips the last index: each step consumes elements[i+1].
                        if i< len(elements) -1:
                                elem = elements[i+1]
                                set1 = (c_long* len(domain))(*domain)
                                s1 = csolverlb.createSet(solver, c_uint(1), set1, c_uint(len(domain)))
                                pS = (c_long* len(parDomain))(*parDomain)
                                parSet = csolverlb.createSet(solver, c_uint(1), pS, c_uint(len(parDomain)))
                                # d / domains are built but not passed to any call below.
                                d = [s1, parSet]
                                domains = (c_void_p *len(d))(*d)
                                # Range of the ADD: pairwise sums not exceeding sumCage.
                                parDomain = generateSumRange(domain, parDomain, sumCage)
                                r = (c_long* len(parDomain))(*parDomain)
                                sumRange = csolverlb.createSet(solver, c_uint(1), r, c_uint(len(parDomain)))
                                f1 = csolverlb.createFunctionOperator(solver, ps.ArithOp.SATC_ADD, sumRange, ps.OverFlowBehavior.SATC_OVERFLOWSETSFLAG);
                                inp = [elem, parElem]
                                inputs = (c_void_p*len(inp))(*inp)
                                # The result becomes the new running partial sum.
                                parElem = csolverlb.applyFunction(solver, f1, inputs, len(inp), overflow);
                esum = csolverlb.getElementConst(solver, c_uint(3), c_long(sumCage))
                # setSum is fetched but not used afterwards.
                setSum = csolverlb.getElementRange(solver, esum)
                equals = csolverlb.createPredicateOperator(solver, c_uint(ps.CompOp.SATC_EQUALS))
                inp = [parElem,esum]
                inputs = (c_void_p*len(inp))(*inp)
                # Final partial sum must equal the cage target ...
                b = csolverlb.applyPredicate(solver,equals, inputs, c_uint(len(inp)))
                csolverlb.addConstraint(solver,b);
                # ... and the shared overflow flag must be false.
                csolverlb.addConstraint(solver, csolverlb.applyLogicalOperationOne(solver, ps.LogicOps.SATC_NOT, overflow));

        # Generating constraints for each cage
        for cage in killerRules:
                sumCage = cage[0]
                # A single-cell cage fixes its cell: store a constant element
                # with the cage sum and skip the variable/sum encoding.
                if len(cage[1])==1:
                        problem[cage[1][0][0]][cage[1][0][1]] = csolverlb.getElementConst(solver, c_uint(7), c_long(sumCage))
                        continue
                elements, domain = getElement(cage)
                generateSumConstraint(sumCage, elements, domain)
        # Ensure there's no undefined element (for each cell we have a rule)
        validateElements(problem)
        # Ensure each cell takes one of the values 1..N: for every cell, add the
        # disjunction over all constants k of (cell == k).
        for i,row in enumerate(problem):
                for j, elem in enumerate(row):
                        constr = []
                        for econst in elemConsts:
                                # s1, sconst, d and domain are computed but not
                                # used by the calls that follow.
                                s1 = csolverlb.getElementRange(solver, elem)
                                sconst = csolverlb.getElementRange(solver,econst)
                                d = [s1,sconst]
                                domain = (c_void_p *len(d))(*d)
                                equals = csolverlb.createPredicateOperator(solver, c_uint(ps.CompOp.SATC_EQUALS))
                                inp = [elem,econst]
                                inputs = (c_void_p*len(inp))(*inp)
                                constr.append( csolverlb.applyPredicate(solver,equals, inputs, c_uint(2)))
                        b = (c_void_p*len(constr))(*constr)
                        b = csolverlb.applyLogicalOperation(solver, ps.LogicOps.SATC_OR, b, len(constr))
                        csolverlb.addConstraint(solver,b);


        # Disabled alternative to the loop above (kept from the original source):
        #ensure each cell at least has one value
#       for i,row in enumerate(problem):
#               for j, elem in enumerate(row):
#                       csolverlb.mustHaveValue(solver, elem)

        # ensure columns have distinct values (problem[:,i] fixes the second
        # index, i.e. selects column i of the array)
        for i in range( N):
                valid(problem[:,i])

        # ensure rows have distinct values (problem[i,:] fixes the first
        # index, i.e. selects row i of the array)
        for i in range( N):
                valid(problem[i,:])

        # ensure each block has distinct values
        # NOTE(review): assumes N is a perfect square; int() truncates otherwise.
        root = int(N**(0.5))
        # First index of every block along one axis: 0, root, 2*root, ...
        collections = [ root*i for i in range(root)]
        for i in collections:
                for j in collections:
                        # k enumerates the cells of the root x root block whose
                        # first cell is (i, j).
                        valid([problem[i + k % root, j + k // root] for k in range(N)])

        #Serializing the problem before solving it ....
        if serialize:
                csolverlb.serialize(solver)
        # Any solve() result other than 1 is reported to the caller as None.
        if csolverlb.solve(solver) != 1:
                return None
        # Read the value chosen for every cell back into a plain list of lists.
        result = [[0 for i in range(N)] for i in range(N)]
        for i,row in enumerate(problem):
                for j, elem in enumerate(row):
                        result[i][j] = csolverlb.getElementValue(solver, elem)
        return result
158