Alloy Support for Killer Sudoku
[Benchmarks_CSolver.git] / killerSudoku / killerSolver.py
1 #Robert White 
2 #Supervisor Dr. Konstantin Korovin
3 #http://pythonsudoku.sourceforge.net/
4 #Oct 2013
5 import os, sys
6 import pyparsing as pp
7 import numpy
8 import time
9 import glucose
10 import re
11 from os.path import basename
12 from itertools import combinations, ifilter, chain
13 from argprocessor import KSudokuArgParser
14 import csolversudoku as cs
15
16 argparser = KSudokuArgParser()
17 # some global values:
18 N = 9
19 seed = N**3+1
20 indexBoard = []
21
22 countCNF_ari = 0
23 countCNF_all = 0
24
25
26 def getNewIndex():
27     global seed
28     seed = seed + 1
29     return seed
30
31
32 # note that the k is a number from 1 to 9
33 def getIndex(i,j,k):
34     global N
35     return (i * N**2 + j * N + k)
36
37 # note that the k is a number from 1 to 9
38 def deIndex (index):
39     global N
40     i = (index-1) / N**2
41     j = ((index -1)/ N) % N
42     k = (index-1) % N +1 
43     return (i, j, k)  # k is 1 - 9
44
45
46 def decode_to_matrix(result_list):
47     global N
48     out_matrix = numpy.zeros(shape = (N,N))
49     for l in result_list:
50         if l > 0 and l < N**3+1:
51             #print l ,'means', (l-1)/81,((l-1)/9)%9,' is ', (l-1)%9 +1 
52             (i,j,k) = deIndex(l)
53             out_matrix[i][j] = k
54         #elif l > 10000:
55         #    print 'WTF:'+str(l)
56
57     return out_matrix
58             
59 def print_matrix(matrix):
60     line = 1
61     for row in matrix:
62         for ele in row:
63             print int(ele), 
64         line = line + 1
65         print '\n'
66        
67
68             
69 def write_to_cnf_file(cnf, name): # out is the writting channel
70     out = open(name, 'w+')
71     
72     for clause in cnf:
73         for literal in clause:
74             out.write(str(literal)+' ')
75     out.write('\n')
76     out.close() 
77
78
79 def remove_cell_values(i, j, v):
80     if v in cellPossibleValues[i][j]:
81         cellPossibleValues[i][j].remove (v)
82
83 def add_cell_values(i, j, v):
84     if not (v in cellPossibleValues[i][j]):
85             cellPossibleValues[i][j].append(v)
86
87 def exactly_one(literals):
88     # print 'exactly one of ', literals
89     new = []
90     cnf = []
91     previous= None
92     if len(literals) == 1:
93         cnf.append([literals[0]])
94     elif len(literals) == 2:
95         # cnf.append([literals[0], literals[1]])
96         cnf.append([literals[0] * -1, literals[1] * -1])
97     else:
98         for x in literals[:-2]:
99             # print 'x is ', x
100             y = getNewIndex()
101             # print 'y is ', y
102             # cnf.append([x, y])
103             cnf.append([-1*x, -1*y])
104             # print 'x only one y'
105             if len(new) != 0:
106                 # print 'link to the previous x and y'
107                 cnf.append([x * -1, new[-1]])
108                 # print [x * -1, new[-1]]
109                 cnf.append([y * -1, new[-1]])
110                 # print [y * -1, new[-1]]
111             new.append(y)
112             previous = x
113         # cnf.append([literals[-1], literals[-2]])
114         cnf.append([-1*literals[-1], -1*literals[-2]])
115         cnf.append([-1*literals[-1], new[-1]])
116         cnf.append([-1*literals[-2], new[-1]])
117     cnf.append(literals)
118     return cnf
119
120
121 def encode_to_cnf(killerRules): #encode a problem (stored in matrix) as cnf
122     global N
123     global indexBoard
124     global countCNF_all
125     global countCNF_ari
126     for i in range (0, N):
127         tmp = []
128         for j in range (0, N):
129             tmp2 = []
130             for num in range (1, N+1): # here, it represent num of 1 to 9
131                 tmp2.append(getIndex(i, j, num))
132             tmp.append(tmp2)
133         indexBoard.append(tmp)
134
135     cnf = []
136     
137     # find all possible combination of each cell
138     for k in killerRules:
139         cageSum = k[0]
140         cageSize = len(k[1])
141         cageCells = k[1]
142         cb = combinations([ii for ii in range(1, N+1)], cageSize)
143         f = lambda x : sum(x) == cageSum
144         comb = ifilter(f, cb) # all valid combinations
145         allPossible = list(chain(comb))
146 #         print '\nall possible: ', allPossible
147         common = []
148         for i in range (1,N+1):
149             flag = True # means it is a common one
150             for j in allPossible:
151 #                 print 'test on', list(j)
152                 if not(i in list(j)):
153 #                     print i, ' is not in ', list(j) 
154                     flag = False
155             if flag == True:
156 #                 print '************this is a common one: ', i
157                 common.append(i)
158
159         different = []
160         for p in allPossible:
161             pl = list(p)
162             for r in common:
163                 if r in pl:
164                     pl.remove(r)
165             if (pl != []):
166                 different.append(pl)
167         
168 #         print 'In this iteration, we have the sum of cage: ', cageSum, '; the size of cage', cageSize
169 #         print 'these cells are: ', cageCells
170 #         print 'possible combinations are: ', allPossible
171 #         print 'common values, ' , common
172 #         print 'different values', different
173
174         # next we start to encode. 
175
176         # encode the common values first. 
177         # for every common value, [1,2] for example, introduce a new index representing the existence of the value
178         # among the cells of the cage. 
179
180         dic = {}
181         for num in range(1, N+1): 
182             dic[num] = getNewIndex()
183             tmp =[]
184             for cc in cageCells:
185                 cnf.append([indexBoard[cc[0]][cc[1]][num-1] * -1,  dic[num]]) # right arrow
186                 tmp.append(indexBoard[cc[0]][cc[1]][num-1])                # left arrow
187             tmp.append(dic[num] * -1)
188             cnf.append(tmp)
189         for num in common:
190 #             print num, '  is a common number'
191             cnf.append([dic[num]])
192
193         # next, we encode the differnt ones 
194         # we need to introduce new values as above
195         # we need to obtain all the numbers possibly in the different cases
196         # lst = reduce ((lambda x, y: x + y), different)
197         # lst = list(set(lst)) # remove duplicated elements
198
199         x_list = []
200         for dif in different: # for example, [[3,4,8], [3,5,7], [4,5,6]
201             # for cc in cageCells:
202             # we need to convert from DNF to CNF
203             # first, we need to convert that x1 = 3 /\ 4 /\ 8
204             # again, we need to introduce our x1, for 348, x2 for 357 , etc
205             x = getNewIndex()
206             # x -> 3 4 8
207             for d in dif:
208                 cnf.append([-1* x , dic[d]])
209                 # print ' for ', d, ' -- ', dic[d]
210             # ~(3, 4, 8) \/ x
211             # i.e. -3 \/ -4 \/ -8 \/ x
212             tmp = map ((lambda x : -1 * dic[x]), dif)
213             tmp.append(x)  
214             # print ' == ', tmp 
215             cnf.append(tmp)
216             x_list.append(x)
217         if x_list != []:
218             # print '***************', x_list
219             cnf.append(x_list)
220     # END of killer sudoku ------------------------
221     countCNF_ari = len(cnf)
222
223     # Exactly one in each cell
224     for i in range(N): #column
225         for j in range(N): #row
226             # print 'for cell ', i ,' and ', j, '\n'
227             #at least one of k should be true
228             temp =[]
229             for k in range(1,N+1):
230                 temp.append(getIndex(i,j,k))
231             cnf = cnf + exactly_one(temp)
232             
233     #exactly once in each row     
234     for k in range(1,N+1): #each number
235         # appear exactly once in each row
236         for j in range(N):
237             #appear at least once
238             #print 'In row ', j, ' \n'
239             temp = []
240             for i in range(N):
241                 temp.append(getIndex(i,j,k))
242             cnf = cnf + exactly_one(temp)
243             
244         #exactly once in each coloumn
245         for i in range(N):
246             temp = []
247             for j in range(N):
248                 temp.append(getIndex(i,j,k))
249             cnf = cnf + exactly_one(temp)
250             
251         #exactly once in each block
252         sqrRootN = int(N**(0.5))
253         for block_i in range(sqrRootN):
254             for block_j in range(sqrRootN):
255                 #print 'for block', block_i, ' and ', block_j, '::::\n'
256                 #at least once
257                 temp  = []
258                 for i in range(block_i*sqrRootN, block_i*sqrRootN + sqrRootN):
259                     for j in range(block_j*sqrRootN, block_j*sqrRootN + sqrRootN):
260                         temp.append(getIndex(i,j,k))
261                 cnf = cnf + exactly_one(temp)
262     countCNF_all = len(cnf)
263     return cnf
264
265                 
266 def readSudoku(filename):
267     global N
268     global seed
269     # print 'the constraints from file ', filename, ' are:'
270     name = re.findall("\d+", basename(filename[0]))
271     if len(name)>1:
272         N = int(name[1])
273         seed = N**3+1
274     else:
275         N = 9
276     file_reader = open(filename[0], 'r')
277     lines = file_reader.readlines()
278     killerRules = []
279     
280     f = lambda x: [int(re.findall("(\d+)", x)[0]), int(re.findall("(\d+)", x)[1])] 
281
282     for l in lines:
283         # print l
284         (s, t) = l.split('=')
285         tl = t.split('+')
286         lst = map(f, tl)
287         killerRules.append((int(s), lst))
288     return killerRules
289
290 def verify_killer_sudoku(killerRules, result_matrix):
291     # print 'start checking the answer!'
292     for r in killerRules:
293         ans = r[0]
294         cage = r[1]
295         cageSum = 0
296         for c in cage:
297             cageSum = cageSum + result_matrix[c[0]][c[1]]
298         if cageSum != ans:
299             # print 'the rule is not validated: ', r
300             return False
301     return True
302
303 def solveOriginalEncoding(killerRules):
304     global seed
305     cnf =  encode_to_cnf(killerRules)       
306     # #solve the encoded CNF     
307     start = time.time()
308     result_list = glucose.solve(cnf, seed)
309     end = time.time()
310     #output the result
311     # print result_list
312     if result_list == 'UNSAT':
313         return None
314     elif result_list !=[]:
315         print '*************\nTime in Sat Solver:\t%04.5f\ncountCNF_ari:\t%d\ncountCNF_ALL:\t%d\n'% ((end-start), countCNF_ari, countCNF_all)
316         return decode_to_matrix(result_list)
317     else:
318         print 'SYSTEM ERROR'
319         sys.exit(1)
320
321 def solveKillerSudoku(killerRules):
322     global N
323     global argparser
324     result_matrix = None
325     
326     if argparser.getCSolverOption() > 0:
327         result_matrix = cs.solveProblem(N, killerRules, argparser.getCSolverOption())
328     else:
329         result_matrix = solveOriginalEncoding(killerRules)
330     
331     if result_matrix is None:
332         print 'UNSAT'
333     else:
334         if (verify_killer_sudoku(killerRules, result_matrix)):
335             print 'CORRECT'
336         else:
337             print 'ERROR'
338 def main ():
339     global argparser
340     
341     killerRules = readSudoku(argparser.getProblemName())
342     solveKillerSudoku(killerRules)
343     
344               
345 if __name__ == '__main__':
346     main()