Supporting at-most-one constraint for One Hot encoding
[satune.git] / src / Backend / satelemencoder.cc
index ed04630770b73b494faa9bee0d297e3ff2691d49..9a698e0e1f56b7230ac5a4e860b33dc214a5a31a 100644 (file)
@@ -258,14 +258,112 @@ void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
        }
 }
 
+void SATEncoder::generateOneHotAtMostOne(ElementEncoding *encoding) {
+       if(encoding->numVars <= 1){
+               return;
+       }
+       AMOOneHot enc = (AMOOneHot)solver->getTuner()->getVarTunable(encoding->element->getRange()->getType(), ONEHOTATMOSTONE, &OneHotAtMostOneDesc);
+       switch (enc)
+       {
+       case ONEHOT_BINOMIAL:
+               model_print("Encode using bionomial encoding\n");
+               model_print("size = %u\n", encoding->numVars);
+               generateOneHotBinomialAtMostOne(encoding->variables, encoding->numVars);
+               break;
+       case ONEHOT_COMMANDER:
+               generateOneHotCommanderEncodingVars(encoding);
+               break;
+       case ONEHOT_SEQ_COUNTER:
+               generateOneHotSequentialAtMostOne(encoding);
+               break;
+       default:
+               ASSERT(0);
+               break;
+       }
+}
+
+void SATEncoder::generateOneHotBinomialAtMostOne(Edge *array, uint size, uint offset) {
+       for (uint i = offset; i < offset + size; i++) {
+               for (uint j = i + 1; j < offset + size; j++) {
+                       addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, array[i], array[j])));
+               }
+       }
+}
+
+void SATEncoder::generateOneHotCommanderAtMostOneRecursive(Edge *array, uint size) {
+       ASSERT(size > 1);
+       if (size <= SEQ_COUNTER_GROUP_SIZE) {
+               //Using binomial encoding
+               generateOneHotBinomialAtMostOne(array, size);
+       } else {
+               Edge commanders[size/SEQ_COUNTER_GROUP_SIZE + 1];
+               uint commanderSize = 0;
+               for(uint index = 0; index < size; index += SEQ_COUNTER_GROUP_SIZE) {
+                       uint groupSize = 0;
+                       if( (index + SEQ_COUNTER_GROUP_SIZE) < size) {
+                               groupSize = SEQ_COUNTER_GROUP_SIZE;
+                       } else {// The last group
+                               groupSize = size - index;
+                       }
+
+                       if(groupSize == 1) {
+                               commanders[commanderSize++] = array[index];
+                       } else {
+                               // 1. binomial encoding for items in the group
+                               generateOneHotBinomialAtMostOne(array, groupSize, index);
+                               // 2. if commander is true at least one item in the group is true
+                               Edge c = getNewVarSATEncoder();
+                               Edge carray[groupSize + 1];
+                               uint carraySize = 0;
+                               carray[carraySize++] = constraintNegate(c);
+                               for(uint i =index; i < index + groupSize; i++ ){
+                                       carray[carraySize++] = array[i];
+                               }
+                               addConstraintCNF(cnf, constraintOR(cnf, carraySize, carray));
+                               // 3. if commander is false, non of the items in the group can be true
+                               for(uint i =index; i < index + groupSize; i++ ){
+                                       addConstraintCNF(cnf, constraintOR2(cnf, constraintNegate(array[i]), c));
+                               }
+                               commanders[commanderSize++] = c;
+                       } 
+               }
+               ASSERT(commanderSize <= (size/SEQ_COUNTER_GROUP_SIZE + 1));
+               if(commanderSize > 1) {
+                       generateOneHotCommanderAtMostOneRecursive(commanders, commanderSize);
+               }
+       }
+}
+
+void SATEncoder::generateOneHotSequentialAtMostOne(ElementEncoding *encoding) {
+       model_print("At-Most-One constraint using sequential counter\n");
+       model_print("size = %u\n", encoding->numVars);
+       // for more information, look at "Towards an Optimal CNF Encoding of Boolean Cardinality Constraints" paper
+       ASSERT(encoding->numVars > 1);
+       Edge *array = encoding->variables;
+       uint size = encoding->numVars;
+       Edge s [size -1 ];
+       getArrayNewVarsSATEncoder(size-1, s);
+       addConstraintCNF(cnf, constraintOR2(cnf, constraintNegate(array[0]), s[0]));
+       addConstraintCNF(cnf, constraintOR2(cnf, constraintNegate(s[size-2]), constraintNegate(array[size-1])));
+       for (uint i=1; i < size -1 ; i++){
+               addConstraintCNF(cnf, constraintOR2(cnf, constraintNegate(array[i]), s[i]));
+               addConstraintCNF(cnf, constraintOR2(cnf, constraintNegate(s[i-1]), s[i]));
+               addConstraintCNF(cnf, constraintOR2(cnf, constraintNegate(array[i]), constraintNegate(s[i-1])));
+       }
+}
+
+void SATEncoder::generateOneHotCommanderEncodingVars(ElementEncoding *encoding) {
+       //For more detail look at paper "Efficient CNF Encoding for Selecting 1 from N Objects"
+       model_print("At-Most-One constraint using commander\n");
+       model_print("size = %u\n", encoding->numVars);
+       ASSERT(encoding->numVars > 1);
+       generateOneHotCommanderAtMostOneRecursive(encoding->variables, encoding->numVars);
+}
+
 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
        allocElementConstraintVariables(encoding, encoding->encArraySize);
        getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
-       for (uint i = 0; i < encoding->numVars; i++) {
-               for (uint j = i + 1; j < encoding->numVars; j++) {
-                       addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
-               }
-       }
+       generateOneHotAtMostOne(encoding);
        if (encoding->element->anyValue)
                addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
 }