Changing the API for Sudoku
[Benchmarks_CSolver.git] / sudoku-csolver / Sudoku.py
1 import pycosat
2 import sys, getopt 
3 import time
4 import numpy as np
5 import re
6 import random
7 import csolversudoku as cs
8
9 def main(argv): 
10         argument = '' 
11         try:
12                 opts, args = getopt.getopt(argv,"emhvb",["easy","medium","hard","evil","blank","help", "file", "gen", "csolver"])
13         except getopt.GetoptError:
14                 print('Argument error, check -h | --help')
15                 sys.exit(2)
16         for indx,(opt, arg) in enumerate(opts): 
17                 if opt in ("--help"):
18                         help()
19                 elif opt in ("-e", "--easy"): 
20                         solve_problem(easy) 
21                 elif opt in ("-m", "--medium"): 
22                         solve_problem(medium) 
23                 elif opt in ("-h", "--hard"):
24                         solve_problem(hard) 
25                 elif opt in ("-v", "--evil"):
26                         solve_problem(evil) 
27                 elif opt in ("-b", "--blank"):
28                         solve_problem(blank) 
29                 elif opt in ( "--file"):
30                         print opts, args
31                         if "--csolver" in args:
32                                 print "Solving the problem using csolver ..."
33                                 read_problem_from_file(args[indx], True)
34                         else:
35                                 read_problem_from_file(args[indx])
36                 elif opt in ( "--gen"):
37                         N, K = extractProblemSpecs(args)
38                         if "--csolver" in args:
39                                 print "Generating problem using csolver ..."
40                                 generate_problem_csolver(N,K)
41                         else:
42                                 generate_problem(N, K)
43                 else:
44                         help()
45
46             
47 def help():
48         print('Usage:')
49         print('Sudoku.py -e [or] --easy')
50         print('Sudoku.py -m [or] --medium')
51         print('Sudoku.py -h [or] --hard')
52         print('Sudoku.py -v [or] --evil')
53         print('Sudoku.py -b [or] --blank')
54         print('Sudoku.py --file file.problem [--csolver]')
55         print('Sudoku.py --gen 9 20 [--csolver]')
56         print('All problems generated by websudoku.com')
57         sys.exit()
58
59
60 def removeKDigits(mat, N, K):
61         count = K;
62         while (count != 0):
63                 cellId = random.randint(1,N*N)
64                 i = cellId//N
65                 j = cellId%N
66                 if j != 0:
67                         j = j - 1
68                 if i == N:
69                         i = i-1
70                 #print 'cellId=' + str(cellId) + ' i='+ str(i) + ' j=' + str(j) + ' count='+ str(count)
71                 if mat[i][j] != 0:
72                         count = count -1
73                         mat[i][j] = 0
74
75 def extractProblemSpecs(args):
76         assert len(args) >= 2
77         global N
78         N = int(args[0])
79         K = int(args[1])
80         print N
81         return N, K
82
83 def generate_problem_csolver(N,K):
84         problem = cs.generateProblem(N)
85         pprint(problem)
86         np.savetxt('solved/'+str(N) + 'x' + str(N) + '.sol',problem)
87         removeKDigits(problem, N, K)
88 #       np.savetxt('problems/'+str(N) + 'x' + str(N) + '-' + str(K) + '.problem',problem)
89
90 def generate_problem(N, K):
91         problem = [[0 for i in range(N)] for i in range(N)]
92         solve(problem)
93         np.savetxt('solved/'+str(N) + 'x' + str(N) + '.sol',problem)    
94         removeKDigits(problem, N, K)
95         pprint(problem)
96         np.savetxt('problems/'+str(N) + 'x' + str(N) + '-' + str(K) + '.problem',problem)
97
98 def read_problem_from_file(filename, useCsolver=False):
99         problem = np.loadtxt(filename)
100         global N
101         N=int(re.findall('\d+', filename)[0])
102         problem = problem.astype(int)
103         solve_problem(problem, useCsolver)
104
105 def solve_problem(problemset, useCsolver):
106         print('Problem:') 
107         pprint(problemset)
108         if useCsolver:
109                 problemset=cs.solveProblem(N, problemset)
110                 np.savetxt('solved/'+str(N) + 'x' + str(N) + '.problem',problemset)
111         else: 
112                 solve(problemset) 
113         print('Answer:')
114         pprint(problemset)  
115     
116 def v(i, j, d): 
117         return N**2 * (i - 1) + N * (j - 1) + d
118
119 #Reduces Sudoku problem to a SAT clauses 
120 def sudoku_clauses(): 
121         res = []
122         # for all cells, ensure that the each cell:
123         for i in range(1, N+1):
124                 for j in range(1, N+1):
125                         # denotes (at least) one of the 9 digits (1 clause)
126                         res.append([v(i, j, d) for d in range(1, N+1)])
127                         # does not denote two different digits at once (36 clauses)
128                         for d in range(1, N+1):
129                                 for dp in range(d + 1, N+1):
130                                         res.append([-v(i, j, d), -v(i, j, dp)])
131         print "First one :" + str( len(res))
132         
133         def valid(cells): 
134                 for i, xi in enumerate(cells):
135                         for j, xj in enumerate(cells):
136                                 if i < j:
137                                         for d in range(1, N+1):
138                                                 res.append([-v(xi[0], xi[1], d), -v(xj[0], xj[1], d)])
139
140         # ensure rows and columns have distinct values
141         for i in range(1, N+1):
142                 valid([(i, j) for j in range(1, N+1)])
143                 valid([(j, i) for j in range(1, N+1)])
144         print "Second one :" + str(len(res))
145         # ensure rootxroot (e.g. 3*3) sub-grids "regions" have distinct values
146         root = int(N**(0.5))
147         collections = [ root*i+1 for i in range(root)]
148         for i in collections:
149                 for j in collections:
150                         valid([(i + k % root, j + k // root) for k in range(N)])
151         print "Third one :" + str( len(res))
152 #       assert len(res) == 81 * (1 + 36) + 27 * 324
153         return res
154
155 def solve(grid):
156         #solve a Sudoku problem
157         clauses = sudoku_clauses()
158         for i in range(1, N+1):
159                 for j in range(1, N+1):
160                         d = grid[i - 1][j - 1]
161                         # For each digit already known, a clause (with one literal). 
162                         if d:
163                                 clauses.append([v(i, j, d)])
164
165         # Print number SAT clause 
166         numclause = len(clauses)
167         print "P CNF " + str(numclause) +"(number of clauses)"
168
169         # solve the SAT problem
170         start = time.time()
171         sol = set(pycosat.solve(clauses))
172         end = time.time()
173         print("Time: "+str(end - start))
174     
175         def read_cell(i, j):
176                 # return the digit of cell i, j according to the solution
177                 for d in range(1, N+1):
178                         if v(i, j, d) in sol:
179                                 return d
180
181         for i in range(1, N+1):
182                 for j in range(1, N+1):
183                         grid[i - 1][j - 1] = read_cell(i, j)
184
185
186 if __name__ == '__main__':
187         from pprint import pprint
188         N = 9
189         # Sudoku problem generated by websudoku.com
190         easy = [[0, 0, 0, 1, 0, 9, 4, 2, 7],
191                 [1, 0, 9, 8, 0, 0, 0, 0, 6],
192                 [0, 0, 7, 0, 5, 0, 1, 0, 8],
193                 [0, 5, 6, 0, 0, 0, 0, 8, 2],
194                 [0, 0, 0, 0, 2, 0, 0, 0, 0],
195                 [9, 4, 0, 0, 0, 0, 6, 1, 0],
196                 [7, 0, 4, 0, 6, 0, 9, 0, 0],
197                 [6, 0, 0, 0, 0, 8, 2, 0, 5],
198                 [2, 9, 5, 3, 0, 1, 0, 0, 0]]
199
200         medium = [[5, 8, 0, 0, 0, 1, 0, 0, 0],
201                 [0, 3, 0, 0, 6, 0, 0, 7, 0],
202                 [9, 0, 0, 3, 2, 0, 1, 0, 6],
203                 [0, 0, 0, 0, 0, 0, 0, 5, 0],
204                 [3, 0, 9, 0, 0, 0, 2, 0, 1],
205                 [0, 5, 0, 0, 0, 0, 0, 0, 0],
206                 [6, 0, 2, 0, 5, 7, 0, 0, 8],
207                 [0, 4, 0, 0, 8, 0, 0, 1, 0],
208                 [0, 0, 0, 1, 0, 0, 0, 6, 5]]
209
210         evil = [[0, 2, 0, 0, 0, 0, 0, 0, 0],
211                 [0, 0, 0, 6, 0, 0, 0, 0, 3],
212                 [0, 7, 4, 0, 8, 0, 0, 0, 0],
213                 [0, 0, 0, 0, 0, 3, 0, 0, 2],
214                 [0, 8, 0, 0, 4, 0, 0, 1, 0],
215                 [6, 0, 0, 5, 0, 0, 0, 0, 0],
216                 [0, 0, 0, 0, 1, 0, 7, 8, 0],
217                 [5, 0, 0, 0, 0, 9, 0, 0, 0],
218                 [0, 0, 0, 0, 0, 0, 0, 4, 0]]
219
220         hard = [[0, 2, 0, 0, 0, 0, 0, 3, 0],
221                 [0, 0, 0, 6, 0, 1, 0, 0, 0],
222                 [0, 6, 8, 2, 0, 0, 0, 0, 5],
223                 [0, 0, 9, 0, 0, 8, 3, 0, 0],
224                 [0, 4, 6, 0, 0, 0, 7, 5, 0],
225                 [0, 0, 1, 3, 0, 0, 4, 0, 0],
226                 [9, 0, 0, 0, 0, 7, 5, 1, 0],
227                 [0, 0, 0, 1, 0, 4, 0, 0, 0],
228                 [0, 1, 0, 0, 0, 0, 0, 9, 0]]
229     
230         blank = [[0, 0, 0, 0, 0, 0, 0, 0, 0],
231                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
232                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
233                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
234                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
235                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
236                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
237                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
238                 [0, 0, 0, 0, 0, 0, 0, 0, 0]]
239     
240         if(len(sys.argv[1:]) == 0):
241                 print('Argument error, check --help')
242         else:
243                 main(sys.argv[1:])