56fa2f6e45b2f299ba104f4458c8402eaaa873c2
[Benchmarks_CSolver.git] / nqueens / nqueens.cc
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stack>
#include <vector>
#include "inc_solver.h"
#include "solver_interface.h"
#include "csolver.h"
#include "common.h"

using namespace std;
16
/*
 * Appends to `cnf` the clauses meaning "exactly one of `literals` is true".
 * The first clause added is (l1 v l2 v ... v lN), which means at least one.
 * Then, for every pair i < j in input order, a binary clause (~li v ~lj) is
 * added, which means at most one.
 * Existing clauses in `cnf` are left untouched.
 */
void EqualOneToCNF(vector<int> literals, vector< vector<int> > & cnf){
	cnf.push_back(literals);

	const size_t count = literals.size();
	for (size_t a = 0; a + 1 < count; ++a){
		for (size_t b = a + 1; b < count; ++b){
			vector<int> clause(2);
			clause[0] = -literals[a];
			clause[1] = -literals[b];
			cnf.push_back(clause);
		}
	}
}
32
/*
 * Appends to `cnf` the pairwise encoding of "at most one of `literals` is true".
 * For every pair i < j in input order it adds the binary clause (~li v ~lj).
 * Zero or one literal produces no clauses.
 */
void LessEqualOneToCNF(vector<int> literals, vector< vector<int> > & cnf){
	typedef vector<int>::const_iterator LitIter;
	for (LitIter first = literals.begin(); first != literals.end(); ++first){
		for (LitIter second = first + 1; second != literals.end(); ++second){
			vector<int> clause;
			clause.push_back(-*first);
			clause.push_back(-*second);
			cnf.push_back(clause);
		}
	}
}
46
/*
 * Checks that no two queens on an N x N board attack each other.
 *
 * N     - board width.
 * table - row-major board; cell k is (row k/N, column k%N) and holds a queen
 *         when its value is > 0.
 * size  - number of cells to scan (every caller in this file passes N*N).
 *
 * Returns false as soon as two queens share a row, a column or a diagonal,
 * and true otherwise. It does not check that exactly N queens are placed.
 */
bool validateSolution(int N, int *table, int size){
	// Unit steps (row, column) along the four diagonal rays leaving a square.
	static const int kDiagonalSteps[4][2] = {{-1, -1}, {-1, 1}, {1, -1}, {1, 1}};

	for(int k=0; k<size; k++){
		if(table[k] <= 0){
			continue;
		}
		int row = k/N;
		int col = k%N;

		// Another queen in the same row.
		for (int j= row*N; j<(row+1)*N; j++){
			if(j!=k && table[j] >0){
				return false;
			}
		}
		// Another queen in the same column.
		for(int j=0; j<N; j++){
			int indx = j*N+col;
			if(indx !=k && table[indx]>0){
				return false;
			}
		}
		// Another queen on either diagonal: walk each ray until it leaves the
		// board. Row 0 and column 0 are inside the bounds, so queens on the
		// board edge are found too (e.g. (0,3) vs (3,0) on a 4x4 board).
		for(int d=0; d<4; d++){
			const int dr = kDiagonalSteps[d][0];
			const int dc = kDiagonalSteps[d][1];
			for(int i=row+dr, j=col+dc; i>=0 && i<N && j>=0 && j<N; i+=dr, j+=dc){
				if(table[i*N+j]>0){
					return false;
				}
			}
		}
	}
	return true;
}
100
101 void printValidationStatus(int N, int *table, int size){
102         if(validateSolution(N, table, size)){
103                 printf("***CORRECT****\n");
104         }else{
105                 printf("***WRONG******\n");
106         }
107
108 }
109
/*
 * Prints the first `size` cells of a row-major board.
 * Each cell is printed as 'Q' when its value is > 0 and as '*' otherwise.
 * A newline follows every N-th cell, and one extra newline ends the output.
 */
void printSolution(int N, int *table, int size){
	for(int cell=0; cell<size; ++cell){
		const char mark = table[cell] > 0 ? 'Q' : '*';
		printf("%c", mark);
		const bool endOfRow = (cell+1) % N == 0;
		if(endOfRow){
			printf("\n");
		}
	}
	printf("\n");
}
123
124 void originalNqueensEncoding(int N){
125         int numVars = N*N;
126         int kk=1;
127         int ** VarName = new int * [N];
128         for (int i=0; i<N; i++){
129                 VarName[i] = new int [N];
130                 for (int j=0; j<N; j++){
131                         VarName[i][j] = kk++;
132                 }
133         }
134
135         int numFormula = 0;
136         vector< vector<int> > cnf;
137
138         vector<int> vars;
139         
140         // generator formula per row
141         // v1 + v2 + v3 + v4 + ... = 1r 
142         for (int i=0; i<N; i++){
143                 for (int j=0; j<N; j++){
144                         vars.push_back(VarName[i][j]);
145                 }
146                 EqualOneToCNF(vars, cnf);
147                 vars.clear();
148         }
149
150         // generator formula per col
151         // v1 + v2 + v3 + v4 + ... = 1r 
152         for (int i=0; i<N; i++){
153                 for (int j=0; j<N; j++){
154                         vars.push_back(VarName[j][i]);
155                 }
156                 EqualOneToCNF(vars, cnf);
157                 vars.clear();
158         }
159
160         // diagonal
161         for (int i=0; i<N-1; i++){
162                 for (int j=0; j<N-i; j++){
163                         vars.push_back(VarName[j][i+j]);
164                 }
165                 LessEqualOneToCNF(vars, cnf);
166                 vars.clear();
167         }
168         for (int i=1; i<N-1; i++){
169                 for (int j=0; j<N-i; j++){
170                         vars.push_back(VarName[j+i][j]);
171                 }
172                 LessEqualOneToCNF(vars, cnf);
173                 vars.clear();
174         }
175         for (int i=0; i<N-1; i++){
176                 for (int j=0; j<N-i; j++){
177                         vars.push_back(VarName[j][N-1 - (i+j)]);
178                 }
179                 LessEqualOneToCNF(vars, cnf);
180                 vars.clear();
181         }
182         for (int i=1; i<N-1; i++){
183                 for (int j=0; j<N-i; j++){
184                         vars.push_back(VarName[j+i][N-1-j]);
185                 }
186                 LessEqualOneToCNF(vars, cnf);
187                 vars.clear();
188         }
189         
190         IncrementalSolver *solver =allocIncrementalSolver();
191         
192         for (int i=0; i<cnf.size(); i++){
193                 addArrayClauseLiteral(solver, cnf[i].size(), cnf[i].data());
194         }
195         finishedClauses(solver);
196         int start_s=clock();
197         int result = solve(solver);
198         int stop_s=clock();
199         cout << "SAT Solving time: " << (stop_s-start_s)/double(CLOCKS_PER_SEC)*1000 << " ms" << endl;
200         switch(result){
201                 case IS_UNSAT:
202                         printf("Problem is unsat\n");
203                         break;
204                 case IS_SAT:{
205                         printSolution(N, &solver->solution[1], solver->solutionsize);
206                         printValidationStatus(N, &solver->solution[1], solver->solutionsize);
207                         break;
208                 }
209                 default:
210                         printf("Unknown results from SAT Solver...\n");
211                         
212         }
213         deleteIncrementalSolver(solver);
214 }
215
216
217 void csolverNQueensSub(int N){
218         CSolver *solver = new CSolver();
219         uint64_t domain[N];
220         for(int i=0; i<N; i++){
221                 domain[i] = i;
222         }
223         uint64_t range[2*N-1];
224         for(int i=0; i<2*N-1; i++){
225                 range[i] = i-N+1;
226         }
227         Set *domainSet = solver->createSet(1, domain, N);
228         Set *rangeSet = solver->createSet(1, range, 2*N-1);
229         vector<Element *> Xs;
230         vector<Element *> Ys;
231         for(int i=0; i<N; i++){
232                 Xs.push_back(solver->getElementVar(domainSet));
233                 Ys.push_back(solver->getElementVar(domainSet));
234         }
235         Set *d1[] = {domainSet, domainSet};
236         Function *sub = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_NOOVERFLOW);
237         //X shouldn't be equal
238         for(int i=0; i<N-1; i++){
239                 for(int j=i+1; j<N; j++ ){
240                         Element *e1x = Xs[i];
241                         Element *e2x = Xs[j];
242                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
243                         Element *inputs2 [] = {e1x, e2x};
244                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
245                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
246                 }
247         }
248         //Y shouldn't be equal
249         for(int i=0; i<N-1; i++){
250                 for(int j=i+1; j<N; j++ ){
251                         Element *e1y = Ys[i];
252                         Element *e2y = Ys[j];
253                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
254                         Element *inputs2 [] = {e1y, e2y};
255                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
256                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
257                 }
258         }
259         //vertical difference and horizontal difference shouldn't be equal  shouldn't be equal
260         BooleanEdge overflow = solver->getBooleanVar(2);
261         Set *d2[] = {rangeSet, rangeSet};
262         for(int i=0; i<N-1; i++){
263                 for(int j=i+1; j<N; j++ ){
264                         Element *e1y = Ys[i];
265                         Element *e2y = Ys[j];
266                         Element *e1x = Xs[i];
267                         Element *e2x = Xs[j];           
268                         Function *f1 = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_IGNORE);
269                         Element *in1[] = {e1x, e2x};
270                         Element *subx = solver->applyFunction(f1, in1, 2, overflow);
271                         Element *in2[] = {e1y, e2y};
272                         Element *suby = solver->applyFunction(f1, in2, 2, overflow);
273                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
274                         Element *inputs2 [] = {subx, suby};
275                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
276                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
277                 }
278         }
279
280 //      solver->serialize();
281         if (solver->solve() != 1){
282                 printf("Problem is Unsolvable ...\n");
283         }else {
284                 int table[N*N];
285                 memset( table, 0, N*N*sizeof(int) );
286                 for(int i=0; i<N; i++){
287                         uint x = solver->getElementValue(Xs[i]);
288                         uint y = solver->getElementValue(Ys[i]);
289 //                      printf("X=%d, Y=%d\n", x, y);
290                         ASSERT(N*x+y < N*N);
291                         table[N*x+y] = 1;
292                 }
293                 printSolution(N, table, N*N);
294                 printValidationStatus(N, table, N*N);
295         }
296         delete solver;
297 }
298
299 void atmostOneConstraint(CSolver *solver, vector<BooleanEdge> &constraints){
300         int size = constraints.size();
301         if(size <1){
302                 return;
303         } else if(size ==1){
304                 solver->addConstraint(constraints[0]);
305         }else{
306 //              solver->addConstraint(solver->applyLogicalOperation(SATC_OR, &constraints[0], size))
307                 for(int i=0; i<size-1; i++){
308                         for(int j=i+1; j<size; j++){
309                                 BooleanEdge const1 = solver->applyLogicalOperation(SATC_NOT, constraints[i]);
310                                 BooleanEdge const2 = solver->applyLogicalOperation(SATC_NOT, constraints[j]);
311                                 BooleanEdge array[] = {const1, const2};
312                                 solver->addConstraint( solver->applyLogicalOperation(SATC_OR, (BooleanEdge *)array, 2));
313                         }
314                 }
315         
316         }
317 }
318
319 void mustHaveValueConstraint(CSolver* solver, vector<Element*> &elems){
320         for(int i=0; i<elems.size(); i++){
321                 solver->mustHaveValue(elems[i]);
322         }
323 }
324
325 void differentInEachRow(CSolver* solver, int N, vector<Element*> &elems){
326         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
327         for(int i=0; i<N-1; i++){
328                 for(int j=i+1; j<N; j++ ){
329                         Element *e1x = elems[i];
330                         Element *e2x = elems[j];
331                         Element *inputs2 [] = {e1x, e2x};
332                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
333                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
334                 }
335         }
336
337
338 }
339
340 void diagonallyDifferentConstraint(CSolver *solver, int N, vector<Element*> &elems){
341         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
342         for(int i=N-1; i>0; i--){
343 //              cout << "i:" << i << "\t";
344                 vector<BooleanEdge> diagonals;
345                 for(int j=i; j>=0; j--){
346                         int index = i-j;
347                         Element* e1 = elems[index];
348 //                      cout << "e" << e1 <<"=" << j << ", ";
349                         Element* e2 = solver->getElementConst(2, (uint64_t) j);
350                         Element* in[] = {e1, e2};
351                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
352                         diagonals.push_back(equals);
353                         
354                 }
355 //              cout << endl;
356                 atmostOneConstraint(solver, diagonals);
357         }
358         for(int i=1; i< N-1; i++){
359 //              cout << "i:" << i << "\t";
360                 vector<BooleanEdge> diagonals;
361                 for(int j=i; j<N; j++){
362                         int index =N-1- (j-i);
363                         Element* e1 = elems[index];
364 //                      cout << "e" << e1 <<"=" << j << ", ";
365                         Element* e2 = solver->getElementConst(2, (uint64_t) j);
366                         Element* in[] = {e1, e2};
367                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
368                         diagonals.push_back(equals);
369                         
370                 }
371 //              cout << endl;
372                 atmostOneConstraint(solver, diagonals);
373
374         }
375         
376 }
377
378 void diagonallyDifferentConstraintBothDir(CSolver *solver, int N, vector<Element*> &elems){
379         diagonallyDifferentConstraint(solver, N, elems);
380         reverse(elems.begin(), elems.end());
381 //      cout << "Other Diagonal:" << endl;
382         diagonallyDifferentConstraint(solver, N, elems);
383
384
385
386 void csolverNQueens(int N){
387         if(N <=1){
388                 cout<<"Q" << endl;
389                 return;
390         }
391         CSolver *solver = new CSolver();
392         uint64_t domain[N];
393         for(int i=0; i<N; i++){
394                 domain[i] = i;
395         }
396         Set *domainSet = solver->createSet(1, domain, N);
397         vector<Element *> elems;
398         for(int i=0; i<N; i++){
399                 elems.push_back(solver->getElementVar(domainSet));
400         }
401         mustHaveValueConstraint(solver, elems);
402         differentInEachRow(solver, N, elems);
403         diagonallyDifferentConstraintBothDir(solver, N, elems);
404 //      solver->printConstraints();
405 //      solver->serialize();
406         if (solver->solve() != 1){
407                 printf("Problem is Unsolvable ...\n");
408         }else {
409                 int table[N*N];
410                 memset( table, 0, N*N*sizeof(int) );
411                 for(int i=0; i<N; i++){
412                         uint x = solver->getElementValue(elems[i]);
413 //                      printf("X=%d, Y=%d\n", x, i);
414                         ASSERT(N*x+i < N*N);
415                         table[N*x+i] = 1;
416                 }
417                 printSolution(N, table, N*N);
418                 printValidationStatus(N, table, N*N);
419         }
420         delete solver;
421 }
422
423
424
425 int main(int argc, char * argv[]){
426         if(argc < 2){
427                 printf("Two arguments are needed\n./nqueen <size> [--csolver]\n");
428                 exit(-1);
429         }
430         int N = atoi(argv[1]);
431         if(argc <3){
432                 printf("Running the original encoding ...\n");
433                 originalNqueensEncoding(N);
434         }else if( strcmp( argv[2], "--csolver") == 0 ){
435                 printf("Running the CSolver encoding ...\n");
436                 csolverNQueens(N);
437         }
438         return 0;
439 }