bug fix for nqueens
[Benchmarks_CSolver.git] / nqueens / nqueens.cc
1 #include <vector>
2 #include <stack>
3 #include <sstream>
4 #include <iostream>
5 #include <fstream>
6 #include <cstdio>
7 #include <cstdlib>
8 #include "inc_solver.h"
9 #include "solver_interface.h"
10 #include "csolver.h"
11 #include "common.h"
12 #include <algorithm>
13
14 using namespace std;
15
// Append to `cnf` the clauses stating that exactly one of `literals` is true:
// first a single clause containing every literal (at least one is true), then
// one binary clause (-a v -b) for every unordered pair a, b (no two are true).
void EqualOneToCNF(vector<int> literals, vector< vector<int> > & cnf){
        const int count = literals.size();

        // At-least-one clause: the disjunction of all literals.
        cnf.push_back(literals);

        // At-most-one clauses: a -> ~b for every pair, in (first, second) order.
        for (int first = 0; first + 1 < count; ++first){
                for (int second = first + 1; second < count; ++second){
                        vector<int> pairClause(2);
                        pairClause[0] = -literals[first];
                        pairClause[1] = -literals[second];
                        cnf.push_back(pairClause);
                }
        }
}
31
// Append to `cnf` the clauses stating that at most one of `literals` is true:
// one binary clause (-a v -b) for every unordered pair a, b. Nothing is added
// when there are fewer than two literals.
void LessEqualOneToCNF(vector<int> literals, vector< vector<int> > & cnf){
        const int count = literals.size();

        // a -> ~b for every pair, emitted in (first, second) order.
        for (int first = 0; first + 1 < count; ++first){
                for (int second = first + 1; second < count; ++second){
                        vector<int> pairClause(2);
                        pairClause[0] = -literals[first];
                        pairClause[1] = -literals[second];
                        cnf.push_back(pairClause);
                }
        }
}
45
// Check that no two queens marked in `table` attack each other.
//
// `table` is read as a row-major N x N board: cell (row, col) is stored at
// index row*N + col and holds a queen when its value is > 0. Only indices
// below both `size` and N*N are read, so entries outside the board are ignored.
//
// Returns false as soon as two queens share a row, a column or a diagonal and
// true otherwise. The number of queens on the board is not checked.
bool validateSolution(int N, int *table, int size){
        if (N <= 0){
                // No board, hence no pair of queens that could be in conflict.
                return true;
        }

        // Unit steps for the eight lines of attack: along the row, along the
        // column and along both diagonals, in each direction.
        static const int kSteps[8][2] = {
                { 0,  1}, { 0, -1}, { 1,  0}, {-1,  0},
                { 1,  1}, { 1, -1}, {-1,  1}, {-1, -1}
        };

        const int cells = (size < N*N) ? size : N*N;
        for (int k = 0; k < cells; k++){
                if (table[k] <= 0){
                        continue;
                }
                const int row = k / N;
                const int col = k % N;
                for (int d = 0; d < 8; d++){
                        // Start one step away from the queen and walk up to and
                        // including the board edge, so row 0 and column 0 are
                        // inspected as well.
                        int i = row + kSteps[d][0];
                        int j = col + kSteps[d][1];
                        while (i >= 0 && i < N && j >= 0 && j < N){
                                const int indx = i*N + j;
                                if (indx < cells && table[indx] > 0){
                                        return false;
                                }
                                i += kSteps[d][0];
                                j += kSteps[d][1];
                        }
                }
        }
        return true;
}
99
100 void printValidationStatus(int N, int *table, int size){
101         if(validateSolution(N, table, size)){
102                 printf("***CORRECT****\n");
103         }else{
104                 printf("***WRONG******\n");
105         }
106
107 }
108
// Print the first `size` entries of `table` to stdout as a board: 'Q' for an
// entry > 0 and '*' otherwise, with a line break after every N entries and one
// extra line break at the end.
void printSolution(int N, int *table, int size){
        for (int cell = 0; cell < size; ++cell){
                const char *mark = (table[cell] > 0) ? "Q" : "*";
                fputs(mark, stdout);

                const bool lastInRow = ((cell + 1) % N == 0);
                if (lastInRow){
                        fputs("\n", stdout);
                }
        }
        fputs("\n", stdout);
}
122
123 void originalNqueensEncoding(int N){
124         int numVars = N*N;
125         int kk=1;
126         int ** VarName = new int * [N];
127         for (int i=0; i<N; i++){
128                 VarName[i] = new int [N];
129                 for (int j=0; j<N; j++){
130                         VarName[i][j] = kk++;
131                 }
132         }
133
134         int numFormula = 0;
135         vector< vector<int> > cnf;
136
137         vector<int> vars;
138         
139         // generator formula per row
140         // v1 + v2 + v3 + v4 + ... = 1r 
141         for (int i=0; i<N; i++){
142                 for (int j=0; j<N; j++){
143                         vars.push_back(VarName[i][j]);
144                 }
145                 EqualOneToCNF(vars, cnf);
146                 vars.clear();
147         }
148
149         // generator formula per col
150         // v1 + v2 + v3 + v4 + ... = 1r 
151         for (int i=0; i<N; i++){
152                 for (int j=0; j<N; j++){
153                         vars.push_back(VarName[j][i]);
154                 }
155                 EqualOneToCNF(vars, cnf);
156                 vars.clear();
157         }
158
159         // diagonal
160         for (int i=0; i<N-1; i++){
161                 for (int j=0; j<N-i; j++){
162                         vars.push_back(VarName[j][i+j]);
163                 }
164                 LessEqualOneToCNF(vars, cnf);
165                 vars.clear();
166         }
167         for (int i=1; i<N-1; i++){
168                 for (int j=0; j<N-i; j++){
169                         vars.push_back(VarName[j+i][j]);
170                 }
171                 LessEqualOneToCNF(vars, cnf);
172                 vars.clear();
173         }
174         for (int i=0; i<N-1; i++){
175                 for (int j=0; j<N-i; j++){
176                         vars.push_back(VarName[j][N-1 - (i+j)]);
177                 }
178                 LessEqualOneToCNF(vars, cnf);
179                 vars.clear();
180         }
181         for (int i=1; i<N-1; i++){
182                 for (int j=0; j<N-i; j++){
183                         vars.push_back(VarName[j+i][N-1-j]);
184                 }
185                 LessEqualOneToCNF(vars, cnf);
186                 vars.clear();
187         }
188         
189         IncrementalSolver *solver =allocIncrementalSolver();
190         
191         for (int i=0; i<cnf.size(); i++){
192                 addArrayClauseLiteral(solver, cnf[i].size(), cnf[i].data());
193         }
194         finishedClauses(solver);
195         int result = solve(solver);
196         switch(result){
197                 case IS_UNSAT:
198                         printf("Problem is unsat\n");
199                         break;
200                 case IS_SAT:{
201                         printSolution(N, &solver->solution[1], solver->solutionsize);
202                         printValidationStatus(N, &solver->solution[1], solver->solutionsize);
203                         break;
204                 }
205                 default:
206                         printf("Unknown results from SAT Solver...\n");
207                         
208         }
209         deleteIncrementalSolver(solver);
210 }
211
212
213 void csolverNQueensSub(int N){
214         CSolver *solver = new CSolver();
215         uint64_t domain[N];
216         for(int i=0; i<N; i++){
217                 domain[i] = i;
218         }
219         uint64_t range[2*N-1];
220         for(int i=0; i<2*N-1; i++){
221                 range[i] = i-N+1;
222         }
223         Set *domainSet = solver->createSet(1, domain, N);
224         Set *rangeSet = solver->createSet(1, range, 2*N-1);
225         vector<Element *> Xs;
226         vector<Element *> Ys;
227         for(int i=0; i<N; i++){
228                 Xs.push_back(solver->getElementVar(domainSet));
229                 Ys.push_back(solver->getElementVar(domainSet));
230         }
231         Set *d1[] = {domainSet, domainSet};
232         Function *sub = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_NOOVERFLOW);
233         //X shouldn't be equal
234         for(int i=0; i<N-1; i++){
235                 for(int j=i+1; j<N; j++ ){
236                         Element *e1x = Xs[i];
237                         Element *e2x = Xs[j];
238                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
239                         Element *inputs2 [] = {e1x, e2x};
240                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
241                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
242                 }
243         }
244         //Y shouldn't be equal
245         for(int i=0; i<N-1; i++){
246                 for(int j=i+1; j<N; j++ ){
247                         Element *e1y = Ys[i];
248                         Element *e2y = Ys[j];
249                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
250                         Element *inputs2 [] = {e1y, e2y};
251                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
252                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
253                 }
254         }
255         //vertical difference and horizontal difference shouldn't be equal  shouldn't be equal
256         BooleanEdge overflow = solver->getBooleanVar(2);
257         Set *d2[] = {rangeSet, rangeSet};
258         for(int i=0; i<N-1; i++){
259                 for(int j=i+1; j<N; j++ ){
260                         Element *e1y = Ys[i];
261                         Element *e2y = Ys[j];
262                         Element *e1x = Xs[i];
263                         Element *e2x = Xs[j];           
264                         Function *f1 = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_IGNORE);
265                         Element *in1[] = {e1x, e2x};
266                         Element *subx = solver->applyFunction(f1, in1, 2, overflow);
267                         Element *in2[] = {e1y, e2y};
268                         Element *suby = solver->applyFunction(f1, in2, 2, overflow);
269                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
270                         Element *inputs2 [] = {subx, suby};
271                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
272                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
273                 }
274         }
275
276 //      solver->serialize();
277         if (solver->solve() != 1){
278                 printf("Problem is Unsolvable ...\n");
279         }else {
280                 int table[N*N];
281                 memset( table, 0, N*N*sizeof(int) );
282                 for(int i=0; i<N; i++){
283                         uint x = solver->getElementValue(Xs[i]);
284                         uint y = solver->getElementValue(Ys[i]);
285 //                      printf("X=%d, Y=%d\n", x, y);
286                         ASSERT(N*x+y < N*N);
287                         table[N*x+y] = 1;
288                 }
289                 printSolution(N, table, N*N);
290                 printValidationStatus(N, table, N*N);
291         }
292         delete solver;
293 }
294
295 void atmostOneConstraint(CSolver *solver, vector<BooleanEdge> &constraints){
296         int size = constraints.size();
297         if(size <1){
298                 return;
299         } else if(size ==1){
300                 solver->addConstraint(constraints[0]);
301         }else{
302 //              solver->addConstraint(solver->applyLogicalOperation(SATC_OR, &constraints[0], size))
303                 for(int i=0; i<size-1; i++){
304                         for(int j=i+1; j<size; j++){
305                                 BooleanEdge const1 = solver->applyLogicalOperation(SATC_NOT, constraints[i]);
306                                 BooleanEdge const2 = solver->applyLogicalOperation(SATC_NOT, constraints[j]);
307                                 BooleanEdge array[] = {const1, const2};
308                                 solver->addConstraint( solver->applyLogicalOperation(SATC_OR, (BooleanEdge *)array, 2));
309                         }
310                 }
311         
312         }
313 }
314
315 void differentInEachRow(CSolver* solver, int N, vector<Element*> &elems){
316         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
317         for(int i=0; i<N-1; i++){
318                 for(int j=i+1; j<N; j++ ){
319                         Element *e1x = elems[i];
320                         Element *e2x = elems[j];
321                         Element *inputs2 [] = {e1x, e2x};
322                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
323                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
324                 }
325         }
326
327
328 }
329
330 void diagonallyDifferentConstraint(CSolver *solver, int N, vector<Element*> &elems){
331         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
332         for(int i=N-1; i>0; i--){
333 //              cout << "i:" << i << "\t";
334                 vector<BooleanEdge> diagonals;
335                 for(int j=i; j>=0; j--){
336                         int index = i-j;
337                         Element* e1 = elems[index];
338 //                      cout << "e" << e1 <<"=" << j << ", ";
339                         Element* e2 = solver->getElementConst(2, (uint64_t) j);
340                         Element* in[] = {e1, e2};
341                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
342                         diagonals.push_back(equals);
343                         
344                 }
345 //              cout << endl;
346                 atmostOneConstraint(solver, diagonals);
347         }
348         for(int i=1; i< N-1; i++){
349 //              cout << "i:" << i << "\t";
350                 vector<BooleanEdge> diagonals;
351                 for(int j=i; j<N; j++){
352                         int index =N-1- (j-i);
353                         Element* e1 = elems[index];
354 //                      cout << "e" << e1 <<"=" << j << ", ";
355                         Element* e2 = solver->getElementConst(2, (uint64_t) j);
356                         Element* in[] = {e1, e2};
357                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
358                         diagonals.push_back(equals);
359                         
360                 }
361 //              cout << endl;
362                 atmostOneConstraint(solver, diagonals);
363
364         }
365         
366 }
367
368 void diagonallyDifferentConstraintBothDir(CSolver *solver, int N, vector<Element*> &elems){
369         diagonallyDifferentConstraint(solver, N, elems);
370         reverse(elems.begin(), elems.end());
371 //      cout << "Other Diagonal:" << endl;
372         diagonallyDifferentConstraint(solver, N, elems);
373
374
375
376 void csolverNQueens(int N){
377         if(N <=1){
378                 cout<<"Q" << endl;
379                 return;
380         }
381         CSolver *solver = new CSolver();
382         uint64_t domain[N];
383         for(int i=0; i<N; i++){
384                 domain[i] = i;
385         }
386         Set *domainSet = solver->createSet(1, domain, N);
387         vector<Element *> elems;
388         for(int i=0; i<N; i++){
389                 elems.push_back(solver->getElementVar(domainSet));
390         }
391         
392         differentInEachRow(solver, N, elems);
393         diagonallyDifferentConstraintBothDir(solver, N, elems);
394 //      solver->printConstraints();
395 //      solver->serialize();
396         if (solver->solve() != 1){
397                 printf("Problem is Unsolvable ...\n");
398         }else {
399                 int table[N*N];
400                 memset( table, 0, N*N*sizeof(int) );
401                 for(int i=0; i<N; i++){
402                         uint x = solver->getElementValue(elems[i]);
403 //                      printf("X=%d, Y=%d\n", x, i);
404                         ASSERT(N*x+i < N*N);
405                         table[N*x+i] = 1;
406                 }
407                 printSolution(N, table, N*N);
408                 printValidationStatus(N, table, N*N);
409         }
410         delete solver;
411 }
412
413
414
415 int main(int argc, char * argv[]){
416         if(argc < 2){
417                 printf("Two arguments are needed\n./nqueen <size> [--csolver]\n");
418                 exit(-1);
419         }
420         int N = atoi(argv[1]);
421         if(argc <3){
422                 printf("Running the original encoding ...\n");
423                 originalNqueensEncoding(N);
424         }else if( strcmp( argv[2], "--csolver") == 0 ){
425                 printf("Running the CSolver encoding ...\n");
426                 csolverNQueens(N);
427         }
428         return 0;
429 }