Fixing element/set types for the tuner ...
[Benchmarks_CSolver.git] / nqueens / nqueens.cc
#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stack>
#include <vector>
#include "inc_solver.h"
#include "solver_interface.h"
#include "csolver.h"
#include "common.h"
14
15 #define NANOSEC 1000000000.0
16
17 using namespace std;
18
// Appends to `cnf` the clauses stating that exactly one of `literals` holds:
// one clause containing every literal ("at least one"), followed by a binary
// clause (~a v ~b) for every unordered pair ("at most one").
void EqualOneToCNF(std::vector<int> literals, std::vector< std::vector<int> > & cnf){
        // At least one literal is true.
        cnf.push_back(literals);

        // No two literals are true together.
        const int count = literals.size();
        for (int first = 0; first + 1 < count; ++first){
                for (int second = first + 1; second < count; ++second){
                        std::vector<int> exclusion(2);
                        exclusion[0] = -literals[first];
                        exclusion[1] = -literals[second];
                        cnf.push_back(exclusion);
                }
        }
}
34
// For every position p of `col1`, appends the clause
//   ~col1[p] v colN[p+1] v colN[p+2] v ...
// i.e. if col1[p] holds, some entry of `colN` at a strictly larger position
// must hold as well. When `colN` has no entry past p the clause is just
// the unit clause ~col1[p].
void verticalSymmetryBreaking(std::vector<int> col1, std::vector<int> colN, std::vector<std::vector<int> > &cnf){
        const std::size_t firstCount = col1.size();
        const std::size_t lastCount = colN.size();
        for (std::size_t pos = 0; pos < firstCount; ++pos){
                std::vector<int> clause(1, -col1[pos]);
                if (pos + 1 < lastCount){
                        clause.insert(clause.end(), colN.begin() + pos + 1, colN.end());
                }
                cnf.push_back(clause);
        }
}
48
// Appends a single clause to `cnf` that is the disjunction of all `literals`
// (an empty literal list yields an empty clause).
void Or( std::vector<int> literals, std::vector<std::vector<int> > & cnf){
        cnf.push_back(std::vector<int>(literals.begin(), literals.end()));
}
57
// Appends to `cnf` the clauses stating that at most one of `literals` holds:
// a binary clause (~a v ~b) for every unordered pair, in lexicographic
// order of the pair's positions. Fewer than two literals add nothing.
void LessEqualOneToCNF(std::vector<int> literals, std::vector< std::vector<int> > & cnf){
        typedef std::vector<int>::const_iterator LitIter;
        const LitIter stop = literals.end();
        for (LitIter a = literals.begin(); a != stop; ++a){
                for (LitIter b = a + 1; b != stop; ++b){
                        const int negatedPair[2] = { -*a, -*b };
                        cnf.push_back(std::vector<int>(negatedPair, negatedPair + 2));
                }
        }
}
71
// Checks that no two queens on an N x N board attack each other.
//
// `table` is the board in row-major order (cell (r, c) lives at r*N + c); a
// value > 0 marks a queen. `size` is the number of readable entries in
// `table`; cells outside the board or beyond `size` are never read.
// Returns false as soon as two queens share a row, column, diagonal or
// anti-diagonal, true otherwise. The number of queens is not checked.
// A null table or a non-positive N has no cells and is reported as valid.
bool validateSolution(int N, int *table, int size){
        if (N <= 0 || table == nullptr){
                return true;
        }
        const long long boardCells = static_cast<long long>(N) * N;
        const int cells = boardCells < size ? static_cast<int>(boardCells) : size;

        // Every attacking pair is found from the queen with the smaller
        // row-major index, so scanning "forward" (right, down, down-right,
        // down-left) all the way to the board edge covers every line.
        static const int kForward[4][2] = { {0, 1}, {1, 0}, {1, 1}, {1, -1} };

        for (int k = 0; k < cells; k++){
                if (table[k] <= 0){
                        continue;
                }
                const int row = k / N;
                const int col = k % N;
                for (const auto &step : kForward){
                        int r = row + step[0];
                        int c = col + step[1];
                        while (r < N && c >= 0 && c < N){
                                const int indx = r * N + c;
                                if (indx < cells && table[indx] > 0){
                                        return false;
                                }
                                r += step[0];
                                c += step[1];
                        }
                }
        }
        return true;
}
125
126 void printValidationStatus(int N, int *table, int size){
127         if(validateSolution(N, table, size)){
128                 printf("***CORRECT****\n");
129         }else{
130                 printf("***WRONG******\n");
131         }
132
133 }
134
// Prints the first `size` cells of `table` to stdout, N cells per line:
// 'Q' for a cell holding a value > 0, '*' otherwise. A blank line follows
// the board.
void printSolution(int N, int *table, int size){
        for (int cell = 0; cell < size; ++cell){
                putchar(table[cell] > 0 ? 'Q' : '*');
                const bool endOfRow = ((cell + 1) % N == 0);
                if (endOfRow){
                        putchar('\n');
                }
        }
        putchar('\n');
}
148
149
150 void originalNqueensEncoding(int N){
151         int numVars = N*N;
152         int kk=1;
153         int ** VarName = new int * [N];
154         for (int i=0; i<N; i++){
155                 VarName[i] = new int [N];
156                 for (int j=0; j<N; j++){
157                         VarName[i][j] = kk++;
158                 }
159         }
160
161         int numFormula = 0;
162         vector< vector<int> > cnf;
163
164         vector<int> vars;
165         
166         // generator formula per row
167         // v1 + v2 + v3 + v4 + ... = 1r 
168         for (int i=0; i<N; i++){
169                 for (int j=0; j<N; j++){
170                         vars.push_back(VarName[i][j]);
171                 }
172                 EqualOneToCNF(vars, cnf);
173                 vars.clear();
174         }
175
176         // generator formula per col
177         // v1 + v2 + v3 + v4 + ... = 1r 
178         for (int i=0; i<N; i++){
179                 for (int j=0; j<N; j++){
180                         vars.push_back(VarName[j][i]);
181                 }
182                 EqualOneToCNF(vars, cnf);
183                 vars.clear();
184         }
185
186         // diagonal
187         for (int i=0; i<N-1; i++){
188                 for (int j=0; j<N-i; j++){
189                         vars.push_back(VarName[j][i+j]);
190                 }
191                 LessEqualOneToCNF(vars, cnf);
192                 vars.clear();
193         }
194         for (int i=1; i<N-1; i++){
195                 for (int j=0; j<N-i; j++){
196                         vars.push_back(VarName[j+i][j]);
197                 }
198                 LessEqualOneToCNF(vars, cnf);
199                 vars.clear();
200         }
201         for (int i=0; i<N-1; i++){
202                 for (int j=0; j<N-i; j++){
203                         vars.push_back(VarName[j][N-1 - (i+j)]);
204                 }
205                 LessEqualOneToCNF(vars, cnf);
206                 vars.clear();
207         }
208         for (int i=1; i<N-1; i++){
209                 for (int j=0; j<N-i; j++){
210                         vars.push_back(VarName[j+i][N-1-j]);
211                 }
212                 LessEqualOneToCNF(vars, cnf);
213                 vars.clear();
214         }
215
216         //Symmetry breaking constraint
217         int mid = 1+ ((N-1)/2);
218         for (int i=0; i<mid; i++){
219                 vars.push_back(VarName[0][i]);
220         }
221         Or(vars, cnf);
222         vector<int> lastCol;
223         for(int i=0; i<N; i++){
224                 lastCol.push_back(VarName[N-1][i]);
225         }
226         verticalSymmetryBreaking(vars, lastCol, cnf);
227         vars.clear();
228         //That's it ... Let's solve the problem ...
229         IncrementalSolver *solver =allocIncrementalSolver();
230         
231         for (int i=0; i<cnf.size(); i++){
232                 addArrayClauseLiteral(solver, cnf[i].size(), cnf[i].data());
233         }
234         finishedClauses(solver);
235         long long start = getTimeNano();
236         int result = solve(solver);
237         long long stop = getTimeNano();
238         cout << "SAT Solving time: " << (stop-start)/NANOSEC << endl;
239         switch(result){
240                 case IS_UNSAT:
241                         printf("Problem is unsat\n");
242                         break;
243                 case IS_SAT:{
244                         printSolution(N, &solver->solution[1], solver->solutionsize);
245                         printValidationStatus(N, &solver->solution[1], solver->solutionsize);
246                         break;
247                 }
248                 default:
249                         printf("Unknown results from SAT Solver...\n");
250                         
251         }
252         deleteIncrementalSolver(solver);
253 }
254
255
256 void csolverNQueensSub(int N, bool serialize=false){
257         CSolver *solver = new CSolver();
258         uint64_t domain[N];
259         for(int i=0; i<N; i++){
260                 domain[i] = i;
261         }
262         uint64_t range[2*N-1];
263         for(int i=0; i<2*N-1; i++){
264                 range[i] = i-N+1;
265         }
266         Set *domainSet = solver->createSet(1, domain, N);
267         Set *rangeSet = solver->createSet(1, range, 2*N-1);
268         vector<Element *> Xs;
269         vector<Element *> Ys;
270         for(int i=0; i<N; i++){
271                 Xs.push_back(solver->getElementVar(domainSet));
272                 Ys.push_back(solver->getElementVar(domainSet));
273         }
274         Set *d1[] = {domainSet, domainSet};
275         Function *sub = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_NOOVERFLOW);
276         //X shouldn't be equal
277         for(int i=0; i<N-1; i++){
278                 for(int j=i+1; j<N; j++ ){
279                         Element *e1x = Xs[i];
280                         Element *e2x = Xs[j];
281                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
282                         Element *inputs2 [] = {e1x, e2x};
283                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
284                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
285                 }
286         }
287         //Y shouldn't be equal
288         for(int i=0; i<N-1; i++){
289                 for(int j=i+1; j<N; j++ ){
290                         Element *e1y = Ys[i];
291                         Element *e2y = Ys[j];
292                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
293                         Element *inputs2 [] = {e1y, e2y};
294                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
295                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
296                 }
297         }
298         //vertical difference and horizontal difference shouldn't be equal  shouldn't be equal
299         BooleanEdge overflow = solver->getBooleanVar(2);
300         Set *d2[] = {rangeSet, rangeSet};
301         for(int i=0; i<N-1; i++){
302                 for(int j=i+1; j<N; j++ ){
303                         Element *e1y = Ys[i];
304                         Element *e2y = Ys[j];
305                         Element *e1x = Xs[i];
306                         Element *e2x = Xs[j];           
307                         Function *f1 = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_IGNORE);
308                         Element *in1[] = {e1x, e2x};
309                         Element *subx = solver->applyFunction(f1, in1, 2, overflow);
310                         Element *in2[] = {e1y, e2y};
311                         Element *suby = solver->applyFunction(f1, in2, 2, overflow);
312                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
313                         Element *inputs2 [] = {subx, suby};
314                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
315                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
316                 }
317         }
318         if (serialize){
319                 solver->serialize();
320         }
321         if (solver->solve() != 1){
322                 printf("Problem is Unsolvable ...\n");
323         }else {
324                 int table[N*N];
325                 memset( table, 0, N*N*sizeof(int) );
326                 for(int i=0; i<N; i++){
327                         uint x = solver->getElementValue(Xs[i]);
328                         uint y = solver->getElementValue(Ys[i]);
329 //                      printf("X=%d, Y=%d\n", x, y);
330                         ASSERT(N*x+y < N*N);
331                         table[N*x+y] = 1;
332                 }
333                 printSolution(N, table, N*N);
334                 printValidationStatus(N, table, N*N);
335         }
336         delete solver;
337 }
338
339 void atmostOneConstraint(CSolver *solver, vector<BooleanEdge> &constraints){
340         int size = constraints.size();
341         if(size <1){
342                 return;
343         } else if(size ==1){
344                 solver->addConstraint(constraints[0]);
345         }else{
346 //              solver->addConstraint(solver->applyLogicalOperation(SATC_OR, &constraints[0], size))
347                 for(int i=0; i<size-1; i++){
348                         for(int j=i+1; j<size; j++){
349                                 BooleanEdge const1 = solver->applyLogicalOperation(SATC_NOT, constraints[i]);
350                                 BooleanEdge const2 = solver->applyLogicalOperation(SATC_NOT, constraints[j]);
351                                 BooleanEdge array[] = {const1, const2};
352                                 solver->addConstraint( solver->applyLogicalOperation(SATC_OR, (BooleanEdge *)array, 2));
353                         }
354                 }
355         
356         }
357 }
358
359 void mustHaveValueConstraint(CSolver* solver, vector<Element*> &elems){
360         for(int i=0; i<elems.size(); i++){
361                 solver->mustHaveValue(elems[i]);
362         }
363 }
364
365 void differentInEachRow(CSolver* solver, int N, vector<Element*> &elems){
366         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
367         for(int i=0; i<N-1; i++){
368                 for(int j=i+1; j<N; j++ ){
369                         Element *e1x = elems[i];
370                         Element *e2x = elems[j];
371                         Element *inputs2 [] = {e1x, e2x};
372                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
373                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
374                 }
375         }
376
377
378 }
379
380 void oneQueenInEachRow(CSolver* solver, vector<Element*> &elems){
381         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
382         int N = elems.size();
383         for(int i=0; i<N; i++){
384                 vector<BooleanEdge> rowConstr;
385                 for(int j=0; j<N; j++){
386                         Element* e1 = elems[j];
387                         Element* e2 = solver->getElementConst(1, (uint64_t) i);
388                         Element* in[] = {e1, e2};
389                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
390                         rowConstr.push_back(equals);
391                 }
392                 if(rowConstr.size()>0){
393                         solver->addConstraint(solver->applyLogicalOperation(SATC_OR, &rowConstr[0], rowConstr.size()) );
394                 }
395         }
396 }
397
398 void generateRowConstraints(CSolver* solver, int N, vector<Element*> &elems){
399         oneQueenInEachRow(solver, elems);
400         differentInEachRow(solver, N, elems);
401 }
402
403 void diagonallyDifferentConstraint(CSolver *solver, int N, vector<Element*> &elems){
404         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
405         for(int i=N-1; i>0; i--){
406 //              cout << "i:" << i << "\t";
407                 vector<BooleanEdge> diagonals;
408                 for(int j=i; j>=0; j--){
409                         int index = i-j;
410                         Element* e1 = elems[index];
411 //                      cout << "e" << e1 <<"=" << j << ", ";
412                         Element* e2 = solver->getElementConst(1, (uint64_t) j);
413                         Element* in[] = {e1, e2};
414                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
415                         diagonals.push_back(equals);
416                         
417                 }
418 //              cout << endl;
419                 atmostOneConstraint(solver, diagonals);
420         }
421         for(int i=1; i< N-1; i++){
422 //              cout << "i:" << i << "\t";
423                 vector<BooleanEdge> diagonals;
424                 for(int j=i; j<N; j++){
425                         int index =N-1- (j-i);
426                         Element* e1 = elems[index];
427 //                      cout << "e" << e1 <<"=" << j << ", ";
428                         Element* e2 = solver->getElementConst(1, (uint64_t) j);
429                         Element* in[] = {e1, e2};
430                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
431                         diagonals.push_back(equals);
432                         
433                 }
434 //              cout << endl;
435                 atmostOneConstraint(solver, diagonals);
436
437         }
438         
439 }
440
441 void diagonallyDifferentConstraintBothDir(CSolver *solver, int N, vector<Element*> &elems){
442         diagonallyDifferentConstraint(solver, N, elems);
443         reverse(elems.begin(), elems.end());
444 //      cout << "Other Diagonal:" << endl;
445         diagonallyDifferentConstraint(solver, N, elems);
446
447
448 void symmetryBreakingConstraint(CSolver *solver, int N, vector<Element*>& elems){
449         Predicate *lt = solver->createPredicateOperator(SATC_LT);
450         int mid = N/2 + N%2;
451         Element *e1x = elems[0];
452         Element *e2x = solver->getElementConst(1, mid);
453         Element *inputs [] = {e1x, e2x};
454         solver->addConstraint(solver->applyPredicate(lt, inputs, 2));
455         
456         e2x = elems[N-1];
457         Element *inputs2 [] = {e1x, e2x};
458         BooleanEdge equals = solver->applyPredicate(lt, inputs2, 2);
459         solver->addConstraint(equals);
460
461 }
462
463 void csolverNQueens(int N, bool serialize=false){
464         if(N <=1){
465                 cout<<"Q" << endl;
466                 return;
467         }
468         CSolver *solver = new CSolver();
469         uint64_t domain[N];
470         for(int i=0; i<N; i++){
471                 domain[i] = i;
472         }
473         Set *domainSet = solver->createSet(1, domain, N);
474         vector<Element *> elems;
475         for(int i=0; i<N; i++){
476                 elems.push_back(solver->getElementVar(domainSet));
477         }
478         mustHaveValueConstraint(solver, elems);
479         generateRowConstraints(solver, N, elems);
480         diagonallyDifferentConstraintBothDir(solver, N, elems);
481         symmetryBreakingConstraint(solver, N, elems);
482 //      solver->printConstraints();
483         if(serialize){
484                 solver->serialize();
485         }
486         if (solver->solve() != 1){
487                 printf("Problem is Unsolvable ...\n");
488         }else {
489                 int table[N*N];
490                 memset( table, 0, N*N*sizeof(int) );
491                 for(int i=0; i<N; i++){
492                         uint x = solver->getElementValue(elems[i]);
493 //                      printf("X=%d, Y=%d\n", x, i);
494                         ASSERT(N*x+i < N*N);
495                         table[N*x+i] = 1;
496                 }
497                 printSolution(N, table, N*N);
498                 printValidationStatus(N, table, N*N);
499         }
500         delete solver;
501 }
502
503
504
505 int main(int argc, char * argv[]){
506         if(argc < 2){
507                 printf("Two arguments are needed\n./nqueen <size> [--csolver]\n");
508                 exit(-1);
509         }
510         int N = atoi(argv[1]);
511         if(argc <3){
512                 printf("Running the original encoding ...\n");
513                 originalNqueensEncoding(N);
514         }else if( strcmp( argv[2], "--csolver") == 0){
515                 printf("Running the CSolver encoding ...\n");
516                 csolverNQueens(N);
517         }else if (strcmp( argv[2], "--dump") == 0){
518                 printf("Running the CSolver encoding ...\n");
519                 csolverNQueens(N, true);
520         }else {
521                 printf("Unknown argument %s", argv[2]);
522         }
523         return 0;
524 }