Symmetry breaking constraint for CSolver encoding
[Benchmarks_CSolver.git] / nqueens / nqueens.cc
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stack>
#include <vector>
#include "inc_solver.h"
#include "solver_interface.h"
#include "csolver.h"
#include "common.h"
14
15 using namespace std;
16
// Appends to `cnf` the clauses stating that exactly one of `literals` is true:
//   * one clause holding every literal (at least one is true), then
//   * one binary clause (-a v -b) per unordered pair (no two are true).
// An empty `literals` therefore contributes a single empty clause.
void EqualOneToCNF(vector<int> literals, vector< vector<int> > & cnf){
	const int count = literals.size();

	// At-least-one part.
	cnf.push_back(literals);

	// At-most-one part: a -> ~b for every pair (a, b).
	for (int first = 0; first < count - 1; ++first) {
		for (int second = first + 1; second < count; ++second) {
			vector<int> pairClause(2);
			pairClause[0] = -literals[first];
			pairClause[1] = -literals[second];
			cnf.push_back(pairClause);
		}
	}
}
32
// For every position p in `col1`, appends the clause
//   (-col1[p] v colN[p+1] v colN[p+2] v ...)
// i.e. "if col1[p] is true, some colN entry at a strictly larger position is
// true". When `colN` has no entry past p the clause is the unit clause
// (-col1[p]).
void verticalSymmetryBreaking(vector<int> col1, vector<int> colN, vector<vector<int> > &cnf){
	const size_t firstCount = col1.size();
	const size_t lastCount = colN.size();
	for (size_t pos = 0; pos < firstCount; ++pos) {
		vector<int> clause;
		clause.push_back(-col1[pos]);
		for (size_t later = pos + 1; later < lastCount; ++later) {
			clause.push_back(colN[later]);
		}
		cnf.push_back(clause);
	}
}
46
// Appends a single clause to `cnf` that is the disjunction of all `literals`
// (an empty `literals` yields an empty clause).
void Or( vector<int> literals, vector<vector<int> > & cnf){
	cnf.push_back(vector<int>(literals.begin(), literals.end()));
}
55
// Appends to `cnf` the pairwise encoding of "at most one of `literals` is
// true": one binary clause (-a v -b) for every unordered pair (a, b).
// Fewer than two literals add nothing.
void LessEqualOneToCNF(vector<int> literals, vector< vector<int> > & cnf){
	const int count = literals.size();
	for (int first = 0; first < count - 1; ++first) {
		const int negFirst = -literals[first];
		for (int second = first + 1; second < count; ++second) {
			// a -> ~b
			vector<int> pairClause;
			pairClause.push_back(negFirst);
			pairClause.push_back(-literals[second]);
			cnf.push_back(pairClause);
		}
	}
}
69
// Checks that no two queens on the board attack each other.
//
// N     - board side length.
// table - row-major board; an entry > 0 marks a queen at (index / N, index % N).
// size  - number of entries of `table` to scan for queens. The row, column
//         and diagonal scans index the full N x N board, so `table` is assumed
//         to hold at least N*N entries.
//
// Returns false as soon as a queen shares a row, column or diagonal with
// another queen, true otherwise. The number of queens is not checked.
bool validateSolution(int N, int *table, int size){
	// Unit steps for the four diagonal rays leaving a square.
	static const int kDiagonalSteps[4][2] = { {-1, -1}, {-1, 1}, {1, -1}, {1, 1} };

	for (int k = 0; k < size; k++) {
		if (table[k] <= 0) {
			continue;
		}
		const int row = k / N;
		const int col = k % N;

		// Any other queen in the same row?
		for (int j = row * N; j < (row + 1) * N; j++) {
			if (j != k && table[j] > 0) {
				return false;
			}
		}
		// Any other queen in the same column?
		for (int r = 0; r < N; r++) {
			const int indx = r * N + col;
			if (indx != k && table[indx] > 0) {
				return false;
			}
		}
		// Walk each diagonal ray all the way to the board edge, including
		// row 0 and column 0 (the previous bounds stopped one square short,
		// so queens at opposite ends of an anti-diagonal went undetected).
		for (int d = 0; d < 4; d++) {
			const int di = kDiagonalSteps[d][0];
			const int dj = kDiagonalSteps[d][1];
			for (int i = row + di, j = col + dj;
			     i >= 0 && i < N && j >= 0 && j < N;
			     i += di, j += dj) {
				if (table[i * N + j] > 0) {
					return false;
				}
			}
		}
	}
	return true;
}
123
124 void printValidationStatus(int N, int *table, int size){
125         if(validateSolution(N, table, size)){
126                 printf("***CORRECT****\n");
127         }else{
128                 printf("***WRONG******\n");
129         }
130
131 }
132
// Prints the first `size` entries of the row-major board to stdout, N squares
// per line: 'Q' for an entry > 0, '*' otherwise, followed by a blank line.
void printSolution(int N, int *table, int size){
	for (int cell = 0; cell < size; ++cell) {
		putchar(table[cell] > 0 ? 'Q' : '*');
		// End of a board row.
		if ((cell + 1) % N == 0) {
			putchar('\n');
		}
	}
	putchar('\n');
}
146
147 void originalNqueensEncoding(int N){
148         int numVars = N*N;
149         int kk=1;
150         int ** VarName = new int * [N];
151         for (int i=0; i<N; i++){
152                 VarName[i] = new int [N];
153                 for (int j=0; j<N; j++){
154                         VarName[i][j] = kk++;
155                 }
156         }
157
158         int numFormula = 0;
159         vector< vector<int> > cnf;
160
161         vector<int> vars;
162         
163         // generator formula per row
164         // v1 + v2 + v3 + v4 + ... = 1r 
165         for (int i=0; i<N; i++){
166                 for (int j=0; j<N; j++){
167                         vars.push_back(VarName[i][j]);
168                 }
169                 EqualOneToCNF(vars, cnf);
170                 vars.clear();
171         }
172
173         // generator formula per col
174         // v1 + v2 + v3 + v4 + ... = 1r 
175         for (int i=0; i<N; i++){
176                 for (int j=0; j<N; j++){
177                         vars.push_back(VarName[j][i]);
178                 }
179                 EqualOneToCNF(vars, cnf);
180                 vars.clear();
181         }
182
183         // diagonal
184         for (int i=0; i<N-1; i++){
185                 for (int j=0; j<N-i; j++){
186                         vars.push_back(VarName[j][i+j]);
187                 }
188                 LessEqualOneToCNF(vars, cnf);
189                 vars.clear();
190         }
191         for (int i=1; i<N-1; i++){
192                 for (int j=0; j<N-i; j++){
193                         vars.push_back(VarName[j+i][j]);
194                 }
195                 LessEqualOneToCNF(vars, cnf);
196                 vars.clear();
197         }
198         for (int i=0; i<N-1; i++){
199                 for (int j=0; j<N-i; j++){
200                         vars.push_back(VarName[j][N-1 - (i+j)]);
201                 }
202                 LessEqualOneToCNF(vars, cnf);
203                 vars.clear();
204         }
205         for (int i=1; i<N-1; i++){
206                 for (int j=0; j<N-i; j++){
207                         vars.push_back(VarName[j+i][N-1-j]);
208                 }
209                 LessEqualOneToCNF(vars, cnf);
210                 vars.clear();
211         }
212
213         //Symmetry breaking constraint
214         for (int i=0; i<N/2; i++){
215                 vars.push_back(VarName[0][i]);
216         }
217         Or(vars, cnf);
218         vector<int> lastCol;
219         for(int i=0; i<N; i++){
220                 lastCol.push_back(VarName[N-1][i]);
221         }
222         verticalSymmetryBreaking(vars, lastCol, cnf);
223         vars.clear();
224         //That's it ... Let's solve the problem ...
225         IncrementalSolver *solver =allocIncrementalSolver();
226         
227         for (int i=0; i<cnf.size(); i++){
228                 addArrayClauseLiteral(solver, cnf[i].size(), cnf[i].data());
229         }
230         finishedClauses(solver);
231         int start_s=clock();
232         int result = solve(solver);
233         int stop_s=clock();
234         cout << "SAT Solving time: " << (stop_s-start_s)/double(CLOCKS_PER_SEC)*1000 << " ms" << endl;
235         switch(result){
236                 case IS_UNSAT:
237                         printf("Problem is unsat\n");
238                         break;
239                 case IS_SAT:{
240                         printSolution(N, &solver->solution[1], solver->solutionsize);
241                         printValidationStatus(N, &solver->solution[1], solver->solutionsize);
242                         break;
243                 }
244                 default:
245                         printf("Unknown results from SAT Solver...\n");
246                         
247         }
248         deleteIncrementalSolver(solver);
249 }
250
251
252 void csolverNQueensSub(int N, bool serialize=false){
253         CSolver *solver = new CSolver();
254         uint64_t domain[N];
255         for(int i=0; i<N; i++){
256                 domain[i] = i;
257         }
258         uint64_t range[2*N-1];
259         for(int i=0; i<2*N-1; i++){
260                 range[i] = i-N+1;
261         }
262         Set *domainSet = solver->createSet(1, domain, N);
263         Set *rangeSet = solver->createSet(1, range, 2*N-1);
264         vector<Element *> Xs;
265         vector<Element *> Ys;
266         for(int i=0; i<N; i++){
267                 Xs.push_back(solver->getElementVar(domainSet));
268                 Ys.push_back(solver->getElementVar(domainSet));
269         }
270         Set *d1[] = {domainSet, domainSet};
271         Function *sub = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_NOOVERFLOW);
272         //X shouldn't be equal
273         for(int i=0; i<N-1; i++){
274                 for(int j=i+1; j<N; j++ ){
275                         Element *e1x = Xs[i];
276                         Element *e2x = Xs[j];
277                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
278                         Element *inputs2 [] = {e1x, e2x};
279                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
280                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
281                 }
282         }
283         //Y shouldn't be equal
284         for(int i=0; i<N-1; i++){
285                 for(int j=i+1; j<N; j++ ){
286                         Element *e1y = Ys[i];
287                         Element *e2y = Ys[j];
288                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
289                         Element *inputs2 [] = {e1y, e2y};
290                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
291                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
292                 }
293         }
294         //vertical difference and horizontal difference shouldn't be equal  shouldn't be equal
295         BooleanEdge overflow = solver->getBooleanVar(2);
296         Set *d2[] = {rangeSet, rangeSet};
297         for(int i=0; i<N-1; i++){
298                 for(int j=i+1; j<N; j++ ){
299                         Element *e1y = Ys[i];
300                         Element *e2y = Ys[j];
301                         Element *e1x = Xs[i];
302                         Element *e2x = Xs[j];           
303                         Function *f1 = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_IGNORE);
304                         Element *in1[] = {e1x, e2x};
305                         Element *subx = solver->applyFunction(f1, in1, 2, overflow);
306                         Element *in2[] = {e1y, e2y};
307                         Element *suby = solver->applyFunction(f1, in2, 2, overflow);
308                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
309                         Element *inputs2 [] = {subx, suby};
310                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
311                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
312                 }
313         }
314         if (serialize){
315                 solver->serialize();
316         }
317         if (solver->solve() != 1){
318                 printf("Problem is Unsolvable ...\n");
319         }else {
320                 int table[N*N];
321                 memset( table, 0, N*N*sizeof(int) );
322                 for(int i=0; i<N; i++){
323                         uint x = solver->getElementValue(Xs[i]);
324                         uint y = solver->getElementValue(Ys[i]);
325 //                      printf("X=%d, Y=%d\n", x, y);
326                         ASSERT(N*x+y < N*N);
327                         table[N*x+y] = 1;
328                 }
329                 printSolution(N, table, N*N);
330                 printValidationStatus(N, table, N*N);
331         }
332         delete solver;
333 }
334
335 void atmostOneConstraint(CSolver *solver, vector<BooleanEdge> &constraints){
336         int size = constraints.size();
337         if(size <1){
338                 return;
339         } else if(size ==1){
340                 solver->addConstraint(constraints[0]);
341         }else{
342 //              solver->addConstraint(solver->applyLogicalOperation(SATC_OR, &constraints[0], size))
343                 for(int i=0; i<size-1; i++){
344                         for(int j=i+1; j<size; j++){
345                                 BooleanEdge const1 = solver->applyLogicalOperation(SATC_NOT, constraints[i]);
346                                 BooleanEdge const2 = solver->applyLogicalOperation(SATC_NOT, constraints[j]);
347                                 BooleanEdge array[] = {const1, const2};
348                                 solver->addConstraint( solver->applyLogicalOperation(SATC_OR, (BooleanEdge *)array, 2));
349                         }
350                 }
351         
352         }
353 }
354
355 void mustHaveValueConstraint(CSolver* solver, vector<Element*> &elems){
356         for(int i=0; i<elems.size(); i++){
357                 solver->mustHaveValue(elems[i]);
358         }
359 }
360
361 void differentInEachRow(CSolver* solver, int N, vector<Element*> &elems){
362         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
363         for(int i=0; i<N-1; i++){
364                 for(int j=i+1; j<N; j++ ){
365                         Element *e1x = elems[i];
366                         Element *e2x = elems[j];
367                         Element *inputs2 [] = {e1x, e2x};
368                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
369                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
370                 }
371         }
372
373
374 }
375
376 void oneQueenInEachRow(CSolver* solver, vector<Element*> &elems){
377         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
378         int N = elems.size();
379         for(int i=0; i<N; i++){
380                 vector<BooleanEdge> rowConstr;
381                 for(int j=0; j<N; j++){
382                         Element* e1 = elems[j];
383                         Element* e2 = solver->getElementConst(3, (uint64_t) i);
384                         Element* in[] = {e1, e2};
385                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
386                         rowConstr.push_back(equals);
387                 }
388                 if(rowConstr.size()>0){
389                         solver->addConstraint(solver->applyLogicalOperation(SATC_OR, &rowConstr[0], rowConstr.size()) );
390                 }
391         }
392 }
393
394 void generateRowConstraints(CSolver* solver, int N, vector<Element*> &elems){
395         oneQueenInEachRow(solver, elems);
396         differentInEachRow(solver, N, elems);
397 }
398
399 void diagonallyDifferentConstraint(CSolver *solver, int N, vector<Element*> &elems){
400         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
401         for(int i=N-1; i>0; i--){
402 //              cout << "i:" << i << "\t";
403                 vector<BooleanEdge> diagonals;
404                 for(int j=i; j>=0; j--){
405                         int index = i-j;
406                         Element* e1 = elems[index];
407 //                      cout << "e" << e1 <<"=" << j << ", ";
408                         Element* e2 = solver->getElementConst(2, (uint64_t) j);
409                         Element* in[] = {e1, e2};
410                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
411                         diagonals.push_back(equals);
412                         
413                 }
414 //              cout << endl;
415                 atmostOneConstraint(solver, diagonals);
416         }
417         for(int i=1; i< N-1; i++){
418 //              cout << "i:" << i << "\t";
419                 vector<BooleanEdge> diagonals;
420                 for(int j=i; j<N; j++){
421                         int index =N-1- (j-i);
422                         Element* e1 = elems[index];
423 //                      cout << "e" << e1 <<"=" << j << ", ";
424                         Element* e2 = solver->getElementConst(2, (uint64_t) j);
425                         Element* in[] = {e1, e2};
426                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
427                         diagonals.push_back(equals);
428                         
429                 }
430 //              cout << endl;
431                 atmostOneConstraint(solver, diagonals);
432
433         }
434         
435 }
436
437 void diagonallyDifferentConstraintBothDir(CSolver *solver, int N, vector<Element*> &elems){
438         diagonallyDifferentConstraint(solver, N, elems);
439         reverse(elems.begin(), elems.end());
440 //      cout << "Other Diagonal:" << endl;
441         diagonallyDifferentConstraint(solver, N, elems);
442
443
444 void symmetryBreakingConstraint(CSolver *solver, int N, vector<Element*>& elems){
445         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
446         vector<BooleanEdge> constr;
447         for(int i=0; i<N/2; i++){
448                 Element *e1x = elems[0];
449                 Element *e2x = solver->getElementConst(2, (uint64_t)i);
450                 Element *inputs2 [] = {e1x, e2x};
451                 BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
452                 constr.push_back( equals);
453         }
454         solver->addConstraint(solver->applyLogicalOperation(SATC_OR, &constr[0], constr.size()) );
455         
456         Predicate *lt = solver->createPredicateOperator(SATC_LT);
457         Element *e1x = elems[0];
458         Element *e2x = elems[N-1];
459         Element *inputs2 [] = {e1x, e2x};
460         BooleanEdge equals = solver->applyPredicate(lt, inputs2, 2);
461         solver->addConstraint(equals);
462
463 }
464
465 void csolverNQueens(int N, bool serialize=false){
466         if(N <=1){
467                 cout<<"Q" << endl;
468                 return;
469         }
470         CSolver *solver = new CSolver();
471         uint64_t domain[N];
472         for(int i=0; i<N; i++){
473                 domain[i] = i;
474         }
475         Set *domainSet = solver->createSet(1, domain, N);
476         vector<Element *> elems;
477         for(int i=0; i<N; i++){
478                 elems.push_back(solver->getElementVar(domainSet));
479         }
480         mustHaveValueConstraint(solver, elems);
481         generateRowConstraints(solver, N, elems);
482         diagonallyDifferentConstraintBothDir(solver, N, elems);
483         symmetryBreakingConstraint(solver, N, elems);
484 //      solver->printConstraints();
485         if(serialize){
486                 solver->serialize();
487         }
488         if (solver->solve() != 1){
489                 printf("Problem is Unsolvable ...\n");
490         }else {
491                 int table[N*N];
492                 memset( table, 0, N*N*sizeof(int) );
493                 for(int i=0; i<N; i++){
494                         uint x = solver->getElementValue(elems[i]);
495 //                      printf("X=%d, Y=%d\n", x, i);
496                         ASSERT(N*x+i < N*N);
497                         table[N*x+i] = 1;
498                 }
499                 printSolution(N, table, N*N);
500                 printValidationStatus(N, table, N*N);
501         }
502         delete solver;
503 }
504
505
506
507 int main(int argc, char * argv[]){
508         if(argc < 2){
509                 printf("Two arguments are needed\n./nqueen <size> [--csolver]\n");
510                 exit(-1);
511         }
512         int N = atoi(argv[1]);
513         if(argc <3){
514                 printf("Running the original encoding ...\n");
515                 originalNqueensEncoding(N);
516         }else if( strcmp( argv[2], "--csolver") == 0){
517                 printf("Running the CSolver encoding ...\n");
518                 csolverNQueens(N);
519         }else if (strcmp( argv[2], "--dump") == 0){
520                 printf("Running the CSolver encoding ...\n");
521                 csolverNQueens(N, true);
522         }else {
523                 printf("Unknown argument %s", argv[2]);
524         }
525         return 0;
526 }