ac1fef6359463ce7adc1cf1d78edc313df140231
[Benchmarks_CSolver.git] / sudoku-csolver / csolversudoku.py
1 import pycsolver as ps
2 from ctypes import *
3 import numpy as np
4 import sys
5
def generateProblem(N):
    """Solve an empty N x N sudoku (no givens) and return the solved grid.

    Thin wrapper around generateSudokuConstraints with no initial cells and
    without serializing the problem.
    """
    return generateSudokuConstraints(N, sudoku=None, serialize=False)
8
def solveProblem(N, problem, serialize):
    """Solve the N x N sudoku ``problem`` and return the completed grid.

    ``problem`` is forwarded as the grid of givens (0 = empty cell);
    ``serialize`` is forwarded unchanged to generateSudokuConstraints.
    """
    givens = problem
    return generateSudokuConstraints(N, sudoku=givens, serialize=serialize)
11
def replaceWithElemConstOptimization(elemConsts, problem, sudoku):
    """Overwrite the variable of every given cell with its constant, in place.

    For each non-zero ``sudoku[r][c] == v`` the entry ``problem[r][c]`` is
    replaced by ``elemConsts[v - 1]``; cells holding 0 are left alone.
    Returns None (``problem`` is mutated).
    """
    for r, givens in enumerate(sudoku):
        for c, value in enumerate(givens):
            if value == 0:
                continue
            problem[r][c] = elemConsts[value - 1]
17                                 
def constantCellConstraint(csolverlb, solver, elemConsts, problem, sudoku):
    """Assert ``problem[r][c] == elemConsts[v - 1]`` for every given cell.

    A cell is "given" when ``sudoku[r][c] == v`` with v != 0. Each equality
    is built with generateEqualityConstraint and added to ``solver`` as a
    hard constraint.
    """
    for r, givens in enumerate(sudoku):
        for c, value in enumerate(givens):
            if value == 0:
                continue
            pin = generateEqualityConstraint(csolverlb, solver, problem[r][c], elemConsts[value - 1])
            csolverlb.addConstraint(solver, pin)
23
def generateEqualityConstraint(csolverlb, solver, e1, e2):
    """Build and return the csolver boolean for ``e1 == e2``.

    The predicate is only constructed here; it is not added to ``solver``.
    """
    eqOperator = csolverlb.createPredicateOperator(solver, c_uint(ps.CompOp.SATC_EQUALS))
    operands = (c_void_p * 2)(e1, e2)
    return csolverlb.applyPredicate(solver, eqOperator, operands, c_uint(2))
30         
def extractItemInSetOptimization(csolverlb, solver, sudoku, N):
    """Create one element variable per cell, each over a pre-pruned domain.

    For every given (non-zero) cell of ``sudoku`` its value is removed from
    the candidate domains of the other cells in the same row, column and
    root x root block (root = sqrt(N)). The given cell itself is pinned to
    the singleton domain ``{value}``. Every domain is then turned into a
    csolver set (set type ``c_uint(1)``) and an element variable is taken
    from it.

    Args:
        csolverlb: csolver library handle (as returned by ps.loadCSolver()).
        solver: solver instance created from ``csolverlb``.
        sudoku: N x N grid; 0 marks an empty cell, 1..N a given value.
        N: side length of the grid; the block layout assumes a perfect square.

    Returns:
        An N x N numpy array of the element variables, row-major.
    """
    root = int(round(N ** 0.5))
    # Candidate values per cell; Python sets give O(1) removal.
    candidates = [[set(range(1, N + 1)) for _ in range(N)] for _ in range(N)]
    givens = []
    for i, row in enumerate(sudoku):
        for j, item in enumerate(row):
            if item == 0:
                continue
            givens.append((i, j, item))
            # Floor division: with true division (Python 3) these would be
            # floats and unusable as list indices.
            bi = (i // root) * root
            bj = (j // root) * root
            for k in range(N):
                candidates[i][k].discard(item)
                candidates[k][j].discard(item)
                candidates[bi + k % root][bj + k // root].discard(item)
    # The pruning above also strips each given's own value from its cell,
    # which would make that cell's domain exclude its actual value. Pin
    # given cells to exactly their value instead.
    for i, j, item in givens:
        candidates[i][j] = set([int(item)])

    # Create all sets first (row-major), then take element variables from them.
    handles = []
    for i in range(N):
        rowHandles = []
        for j in range(N):
            domain = sorted(candidates[i][j])
            values = (c_long * len(domain))(*domain)
            rowHandles.append(csolverlb.createSet(solver, c_uint(1), values, c_uint(len(domain))))
        handles.append(rowHandles)

    return np.array([[csolverlb.getElementVar(solver, h) for h in hrow] for hrow in handles])
55
def generateSudokuConstraints(N, sudoku=None, serialize=False):
    """Encode an N x N sudoku as csolver constraints, solve it, return the grid.

    Every cell is an element variable over the set {1, ..., N}. Each row,
    column and root x root block (root = sqrt(N)) is constrained to hold
    pairwise-distinct values. When ``sudoku`` is given, its non-zero cells
    are pinned to their values.

    Args:
        N: side length of the grid; the block layout assumes N is a perfect
            square.
        sudoku: optional N x N grid of givens, 0 for an empty cell and
            1..N for a fixed value. ``None`` solves an empty grid.
        serialize: if true, ``csolverlb.serialize(solver)`` is called before
            solving.

    Returns:
        An N x N list of lists with the value the solver assigned to each
        cell (from ``csolverlb.getElementValue``).

    Note:
        If ``csolverlb.solve`` does not return 1, this prints
        "Problem is unsolvable!" and terminates the process via
        ``sys.exit(1)`` instead of raising.
    """
    csolverlb = ps.loadCSolver()
    solver = csolverlb.createCCSolver()

    # Every cell draws from the same domain {1, ..., N}.
    digits = list(range(1, N + 1))
    digitArray = (c_long * len(digits))(*digits)
    digitSet = csolverlb.createSet(solver, c_uint(1), digitArray, c_uint(N))
    # Alternative (not enabled): extractItemInSetOptimization builds
    # per-cell pruned domains from the givens instead of a shared set.
    problem = np.array([[csolverlb.getElementVar(solver, digitSet) for _ in range(N)] for _ in range(N)])
    # elemConsts[v - 1] is the solver constant for digit v.
    elemConsts = [csolverlb.getElementConst(solver, c_uint(1), v) for v in digits]

    def allDifferent(cells):
        """Constrain every unordered pair of ``cells`` to be unequal."""
        for idx, first in enumerate(cells):
            for second in cells[idx + 1:]:
                same = generateEqualityConstraint(csolverlb, solver, first, second)
                differ = csolverlb.applyLogicalOperationOne(solver, ps.LogicOps.SATC_NOT, same)
                csolverlb.addConstraint(solver, differ)

    # Ensure each cell has at least one value.
    for row in problem:
        for cell in row:
            csolverlb.mustHaveValue(solver, cell)

    # Columns and rows hold distinct values (column i, then row i).
    for i in range(N):
        allDifferent(problem[:, i])
        allDifferent(problem[i, :])

    # Each root x root block, identified by its top-left corner, holds
    # distinct values. round() guards against float error in the sqrt.
    root = int(round(N ** 0.5))
    corners = [root * b for b in range(root)]
    for r0 in corners:
        for c0 in corners:
            allDifferent([problem[r0 + k % root, c0 + k // root] for k in range(N)])

    # Pin the givens when solving a concrete puzzle.
    if sudoku is not None:
        constantCellConstraint(csolverlb, solver, elemConsts, problem, sudoku)

    if serialize:
        csolverlb.serialize(solver)
    if csolverlb.solve(solver) != 1:
        # print() form is valid on both Python 2 and Python 3.
        print("Problem is unsolvable!")
        sys.exit(1)
    return [[csolverlb.getElementValue(solver, cell) for cell in row] for row in problem]
135