Bug fix for onehot encoding
[Benchmarks_CSolver.git] / nqueens / nqueens.cc
1 #include <vector>
2 #include <stack>
3 #include <sstream>
4 #include <iostream>
5 #include <fstream>
6 #include <cstdio>
7 #include <cstdlib>
8 #include "inc_solver.h"
9 #include "solver_interface.h"
10 #include "csolver.h"
11 #include "common.h"
12 #include <algorithm>
13
14 using namespace std;
15
// Appends to `cnf` the clauses encoding "exactly one of `literals` is true":
//   * one clause containing every literal (at least one is true), and
//   * a binary clause (~a | ~b) for every pair, so no two are true together.
// Existing clauses in `cnf` are kept; new ones are appended.
// An empty `literals` appends a single empty clause (unsatisfiable), since
// "exactly one of nothing" cannot hold.
// `literals` is taken by const reference to avoid copying it on every call.
void EqualOneToCNF(const std::vector<int> &literals, std::vector<std::vector<int>> &cnf){
    cnf.push_back(literals);  // at-least-one clause

    const std::size_t n = literals.size();
    // Pairwise at-most-one: a -> ~b, a -> ~c, ...
    for (std::size_t j = 0; j + 1 < n; j++){
        for (std::size_t k = j + 1; k < n; k++){
            cnf.push_back(std::vector<int>{-literals[j], -literals[k]});
        }
    }
}
31
// Appends to `cnf` the pairwise clauses encoding "at most one of `literals`
// is true": a binary clause (~a | ~b) for every pair of literals.
// Zero or one literal produces no clauses. Existing clauses in `cnf` are kept.
// `literals` is taken by const reference to avoid copying it on every call.
void LessEqualOneToCNF(const std::vector<int> &literals, std::vector<std::vector<int>> &cnf){
    const std::size_t n = literals.size();
    // a -> ~b, a -> ~c, ...
    for (std::size_t j = 0; j + 1 < n; j++){
        for (std::size_t k = j + 1; k < n; k++){
            cnf.push_back(std::vector<int>{-literals[j], -literals[k]});
        }
    }
}
45
// Checks that no two queens on an N x N board attack each other.
// `table` is a row-major board of `size` cells: cell k is row k / N,
// column k % N, and holds a queen when its value is > 0.
// Assumes size == N * N (callers pass N*N or a solver solution size — TODO confirm
// for the solver path).
// Returns false as soon as some queen shares a row, column, or diagonal
// (in either direction) with another queen; true otherwise.
bool validateSolution(int N, int *table, int size){
    // Walks from (row, col) in direction (dr, dc), excluding the start cell,
    // and reports whether an occupied cell is reached before leaving the board.
    // Bounds are inclusive of row/column 0, which the previous scans skipped.
    auto attackedAlong = [&](int row, int col, int dr, int dc){
        for (int r = row + dr, c = col + dc;
             r >= 0 && r < N && c >= 0 && c < N;
             r += dr, c += dc){
            if (table[r * N + c] > 0){
                return true;
            }
        }
        return false;
    };

    static const int directions[8][2] = {
        {0, -1}, {0, 1},    // same row
        {-1, 0}, {1, 0},    // same column
        {-1, -1}, {1, 1},   // main diagonal
        {-1, 1}, {1, -1},   // anti-diagonal
    };

    for (int k = 0; k < size; k++){
        if (table[k] <= 0){
            continue;
        }
        const int row = k / N;
        const int col = k % N;
        for (const auto &d : directions){
            if (attackedAlong(row, col, d[0], d[1])){
                return false;
            }
        }
    }
    return true;
}
99
100 void printValidationStatus(int N, int *table, int size){
101         if(validateSolution(N, table, size)){
102                 printf("***CORRECT****\n");
103         }else{
104                 printf("***WRONG******\n");
105         }
106
107 }
108
// Prints the board held in `table` (row-major, `size` cells): 'Q' for a cell
// whose value is > 0, '*' otherwise. A newline follows every N cells, and one
// extra newline ends the board.
void printSolution(int N, int *table, int size){
    for (int cell = 0; cell < size; ++cell){
        fputs(table[cell] > 0 ? "Q" : "*", stdout);
        const bool endOfRow = (cell + 1) % N == 0;
        if (endOfRow){
            fputs("\n", stdout);
        }
    }
    fputs("\n", stdout);
}
122
123 void originalNqueensEncoding(int N){
124         int numVars = N*N;
125         int kk=1;
126         int ** VarName = new int * [N];
127         for (int i=0; i<N; i++){
128                 VarName[i] = new int [N];
129                 for (int j=0; j<N; j++){
130                         VarName[i][j] = kk++;
131                 }
132         }
133
134         int numFormula = 0;
135         vector< vector<int> > cnf;
136
137         vector<int> vars;
138         
139         // generator formula per row
140         // v1 + v2 + v3 + v4 + ... = 1r 
141         for (int i=0; i<N; i++){
142                 for (int j=0; j<N; j++){
143                         vars.push_back(VarName[i][j]);
144                 }
145                 EqualOneToCNF(vars, cnf);
146                 vars.clear();
147         }
148
149         // generator formula per col
150         // v1 + v2 + v3 + v4 + ... = 1r 
151         for (int i=0; i<N; i++){
152                 for (int j=0; j<N; j++){
153                         vars.push_back(VarName[j][i]);
154                 }
155                 EqualOneToCNF(vars, cnf);
156                 vars.clear();
157         }
158
159         // diagonal
160         for (int i=0; i<N-1; i++){
161                 for (int j=0; j<N-i; j++){
162                         vars.push_back(VarName[j][i+j]);
163                 }
164                 LessEqualOneToCNF(vars, cnf);
165                 vars.clear();
166         }
167         for (int i=1; i<N-1; i++){
168                 for (int j=0; j<N-i; j++){
169                         vars.push_back(VarName[j+i][j]);
170                 }
171                 LessEqualOneToCNF(vars, cnf);
172                 vars.clear();
173         }
174         for (int i=0; i<N-1; i++){
175                 for (int j=0; j<N-i; j++){
176                         vars.push_back(VarName[j][N-1 - (i+j)]);
177                 }
178                 LessEqualOneToCNF(vars, cnf);
179                 vars.clear();
180         }
181         for (int i=1; i<N-1; i++){
182                 for (int j=0; j<N-i; j++){
183                         vars.push_back(VarName[j+i][N-1-j]);
184                 }
185                 LessEqualOneToCNF(vars, cnf);
186                 vars.clear();
187         }
188         
189         IncrementalSolver *solver =allocIncrementalSolver();
190         
191         for (int i=0; i<cnf.size(); i++){
192                 addArrayClauseLiteral(solver, cnf[i].size(), cnf[i].data());
193         }
194         finishedClauses(solver);
195         int result = solve(solver);
196         switch(result){
197                 case IS_UNSAT:
198                         printf("Problem is unsat\n");
199                         break;
200                 case IS_SAT:{
201                         printSolution(N, &solver->solution[1], solver->solutionsize);
202                         printValidationStatus(N, &solver->solution[1], solver->solutionsize);
203                         break;
204                 }
205                 default:
206                         printf("Unknown results from SAT Solver...\n");
207                         
208         }
209         deleteIncrementalSolver(solver);
210 }
211
212
213 void csolverNQueensSub(int N){
214         CSolver *solver = new CSolver();
215         uint64_t domain[N];
216         for(int i=0; i<N; i++){
217                 domain[i] = i;
218         }
219         uint64_t range[2*N-1];
220         for(int i=0; i<2*N-1; i++){
221                 range[i] = i-N+1;
222         }
223         Set *domainSet = solver->createSet(1, domain, N);
224         Set *rangeSet = solver->createSet(1, range, 2*N-1);
225         vector<Element *> Xs;
226         vector<Element *> Ys;
227         for(int i=0; i<N; i++){
228                 Xs.push_back(solver->getElementVar(domainSet));
229                 Ys.push_back(solver->getElementVar(domainSet));
230         }
231         Set *d1[] = {domainSet, domainSet};
232         Function *sub = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_NOOVERFLOW);
233         //X shouldn't be equal
234         for(int i=0; i<N-1; i++){
235                 for(int j=i+1; j<N; j++ ){
236                         Element *e1x = Xs[i];
237                         Element *e2x = Xs[j];
238                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
239                         Element *inputs2 [] = {e1x, e2x};
240                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
241                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
242                 }
243         }
244         //Y shouldn't be equal
245         for(int i=0; i<N-1; i++){
246                 for(int j=i+1; j<N; j++ ){
247                         Element *e1y = Ys[i];
248                         Element *e2y = Ys[j];
249                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
250                         Element *inputs2 [] = {e1y, e2y};
251                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
252                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
253                 }
254         }
255         //vertical difference and horizontal difference shouldn't be equal  shouldn't be equal
256         BooleanEdge overflow = solver->getBooleanVar(2);
257         Set *d2[] = {rangeSet, rangeSet};
258         for(int i=0; i<N-1; i++){
259                 for(int j=i+1; j<N; j++ ){
260                         Element *e1y = Ys[i];
261                         Element *e2y = Ys[j];
262                         Element *e1x = Xs[i];
263                         Element *e2x = Xs[j];           
264                         Function *f1 = solver->createFunctionOperator(SATC_SUB, rangeSet, SATC_IGNORE);
265                         Element *in1[] = {e1x, e2x};
266                         Element *subx = solver->applyFunction(f1, in1, 2, overflow);
267                         Element *in2[] = {e1y, e2y};
268                         Element *suby = solver->applyFunction(f1, in2, 2, overflow);
269                         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
270                         Element *inputs2 [] = {subx, suby};
271                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
272                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
273                 }
274         }
275
276 //      solver->serialize();
277         if (solver->solve() != 1){
278                 printf("Problem is Unsolvable ...\n");
279         }else {
280                 int table[N*N];
281                 memset( table, 0, N*N*sizeof(int) );
282                 for(int i=0; i<N; i++){
283                         uint x = solver->getElementValue(Xs[i]);
284                         uint y = solver->getElementValue(Ys[i]);
285 //                      printf("X=%d, Y=%d\n", x, y);
286                         ASSERT(N*x+y < N*N);
287                         table[N*x+y] = 1;
288                 }
289                 printSolution(N, table, N*N);
290                 printValidationStatus(N, table, N*N);
291         }
292         delete solver;
293 }
294
295 void atmostOneConstraint(CSolver *solver, vector<BooleanEdge> &constraints){
296         int size = constraints.size();
297         if(size <1){
298                 return;
299         } else if(size ==1){
300                 solver->addConstraint(constraints[0]);
301         }else{
302 //              solver->addConstraint(solver->applyLogicalOperation(SATC_OR, &constraints[0], size))
303                 for(int i=0; i<size-1; i++){
304                         for(int j=i+1; j<size; j++){
305                                 BooleanEdge const1 = solver->applyLogicalOperation(SATC_NOT, constraints[i]);
306                                 BooleanEdge const2 = solver->applyLogicalOperation(SATC_NOT, constraints[j]);
307                                 BooleanEdge array[] = {const1, const2};
308                                 solver->addConstraint( solver->applyLogicalOperation(SATC_OR, (BooleanEdge *)array, 2));
309                         }
310                 }
311         
312         }
313 }
314
315 void mustHaveValueConstraint(CSolver* solver, vector<Element*> &elems){
316         for(int i=0; i<elems.size(); i++){
317                 solver->mustHaveValue(elems[i]);
318         }
319 }
320
321 void differentInEachRow(CSolver* solver, int N, vector<Element*> &elems){
322         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
323         for(int i=0; i<N-1; i++){
324                 for(int j=i+1; j<N; j++ ){
325                         Element *e1x = elems[i];
326                         Element *e2x = elems[j];
327                         Element *inputs2 [] = {e1x, e2x};
328                         BooleanEdge equals = solver->applyPredicate(eq, inputs2, 2);
329                         solver->addConstraint(solver->applyLogicalOperation(SATC_NOT, equals));
330                 }
331         }
332
333
334 }
335
336 void diagonallyDifferentConstraint(CSolver *solver, int N, vector<Element*> &elems){
337         Predicate *eq = solver->createPredicateOperator(SATC_EQUALS);
338         for(int i=N-1; i>0; i--){
339 //              cout << "i:" << i << "\t";
340                 vector<BooleanEdge> diagonals;
341                 for(int j=i; j>=0; j--){
342                         int index = i-j;
343                         Element* e1 = elems[index];
344 //                      cout << "e" << e1 <<"=" << j << ", ";
345                         Element* e2 = solver->getElementConst(2, (uint64_t) j);
346                         Element* in[] = {e1, e2};
347                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
348                         diagonals.push_back(equals);
349                         
350                 }
351 //              cout << endl;
352                 atmostOneConstraint(solver, diagonals);
353         }
354         for(int i=1; i< N-1; i++){
355 //              cout << "i:" << i << "\t";
356                 vector<BooleanEdge> diagonals;
357                 for(int j=i; j<N; j++){
358                         int index =N-1- (j-i);
359                         Element* e1 = elems[index];
360 //                      cout << "e" << e1 <<"=" << j << ", ";
361                         Element* e2 = solver->getElementConst(2, (uint64_t) j);
362                         Element* in[] = {e1, e2};
363                         BooleanEdge equals = solver->applyPredicate(eq, in, 2);
364                         diagonals.push_back(equals);
365                         
366                 }
367 //              cout << endl;
368                 atmostOneConstraint(solver, diagonals);
369
370         }
371         
372 }
373
374 void diagonallyDifferentConstraintBothDir(CSolver *solver, int N, vector<Element*> &elems){
375         diagonallyDifferentConstraint(solver, N, elems);
376         reverse(elems.begin(), elems.end());
377 //      cout << "Other Diagonal:" << endl;
378         diagonallyDifferentConstraint(solver, N, elems);
379
380
381
382 void csolverNQueens(int N){
383         if(N <=1){
384                 cout<<"Q" << endl;
385                 return;
386         }
387         CSolver *solver = new CSolver();
388         uint64_t domain[N];
389         for(int i=0; i<N; i++){
390                 domain[i] = i;
391         }
392         Set *domainSet = solver->createSet(1, domain, N);
393         vector<Element *> elems;
394         for(int i=0; i<N; i++){
395                 elems.push_back(solver->getElementVar(domainSet));
396         }
397         mustHaveValueConstraint(solver, elems);
398         differentInEachRow(solver, N, elems);
399         diagonallyDifferentConstraintBothDir(solver, N, elems);
400 //      solver->printConstraints();
401 //      solver->serialize();
402         if (solver->solve() != 1){
403                 printf("Problem is Unsolvable ...\n");
404         }else {
405                 int table[N*N];
406                 memset( table, 0, N*N*sizeof(int) );
407                 for(int i=0; i<N; i++){
408                         uint x = solver->getElementValue(elems[i]);
409 //                      printf("X=%d, Y=%d\n", x, i);
410                         ASSERT(N*x+i < N*N);
411                         table[N*x+i] = 1;
412                 }
413                 printSolution(N, table, N*N);
414                 printValidationStatus(N, table, N*N);
415         }
416         delete solver;
417 }
418
419
420
421 int main(int argc, char * argv[]){
422         if(argc < 2){
423                 printf("Two arguments are needed\n./nqueen <size> [--csolver]\n");
424                 exit(-1);
425         }
426         int N = atoi(argv[1]);
427         if(argc <3){
428                 printf("Running the original encoding ...\n");
429                 originalNqueensEncoding(N);
430         }else if( strcmp( argv[2], "--csolver") == 0 ){
431                 printf("Running the CSolver encoding ...\n");
432                 csolverNQueens(N);
433         }
434         return 0;
435 }