Select copytoreg and copyfromreg nodes that have flag operands correctly.
[oota-llvm.git] / utils / TableGen / DAGISelEmitter.cpp
index 07d35d4b7d3042121f1b1a88c13bcc93e3493b43..b8077b361a50fff9b20a0030a9947f26b609fd26 100644 (file)
@@ -296,6 +296,7 @@ bool TreePatternNode::UpdateNodeType(unsigned char VT, TreePattern &TP) {
 
   if (isLeaf()) {
     dump();
+    std::cerr << " ";
     TP.error("Type inference contradiction found in node!");
   } else {
     TP.error("Type inference contradiction found in node " + 
@@ -951,14 +952,17 @@ void DAGISelEmitter::ParsePatternFragments(std::ostream &OS) {
 /// HandleUse - Given "Pat" a leaf in the pattern, check to see if it is an
 /// instruction input.  Return true if this is a real use.
 static bool HandleUse(TreePattern *I, TreePatternNode *Pat,
-                      std::map<std::string, TreePatternNode*> &InstInputs) {
+                      std::map<std::string, TreePatternNode*> &InstInputs,
+                      std::vector<Record*> &InstImpInputs) {
   // No name -> not interesting.
   if (Pat->getName().empty()) {
     if (Pat->isLeaf()) {
       DefInit *DI = dynamic_cast<DefInit*>(Pat->getLeafValue());
       if (DI && DI->getDef()->isSubClassOf("RegisterClass"))
         I->error("Input " + DI->getDef()->getName() + " must be named!");
-
+      else if (DI && DI->getDef()->isSubClassOf("Register")) {
+        InstImpInputs.push_back(DI->getDef());
+      }
     }
     return false;
   }
@@ -1004,9 +1008,11 @@ static bool HandleUse(TreePattern *I, TreePatternNode *Pat,
 void DAGISelEmitter::
 FindPatternInputsAndOutputs(TreePattern *I, TreePatternNode *Pat,
                             std::map<std::string, TreePatternNode*> &InstInputs,
-                            std::map<std::string, Record*> &InstResults) {
+                            std::map<std::string, Record*> &InstResults,
+                            std::vector<Record*> &InstImpInputs,
+                            std::vector<Record*> &InstImpResults) {
   if (Pat->isLeaf()) {
-    bool isUse = HandleUse(I, Pat, InstInputs);
+    bool isUse = HandleUse(I, Pat, InstInputs, InstImpInputs);
     if (!isUse && Pat->getTransformFn())
       I->error("Cannot specify a transform function for a non-input value!");
     return;
@@ -1016,14 +1022,15 @@ FindPatternInputsAndOutputs(TreePattern *I, TreePatternNode *Pat,
     for (unsigned i = 0, e = Pat->getNumChildren(); i != e; ++i) {
       if (Pat->getChild(i)->getExtType() == MVT::isVoid)
         I->error("Cannot have void nodes inside of patterns!");
-      FindPatternInputsAndOutputs(I, Pat->getChild(i), InstInputs, InstResults);
+      FindPatternInputsAndOutputs(I, Pat->getChild(i), InstInputs, InstResults,
+                                  InstImpInputs, InstImpResults);
     }
     
     // If this is a non-leaf node with no children, treat it basically as if
     // it were a leaf.  This handles nodes like (imm).
     bool isUse = false;
     if (Pat->getNumChildren() == 0)
-      isUse = HandleUse(I, Pat, InstInputs);
+      isUse = HandleUse(I, Pat, InstInputs, InstImpInputs);
     
     if (!isUse && Pat->getTransformFn())
       I->error("Cannot specify a transform function for a non-input value!");
@@ -1049,19 +1056,22 @@ FindPatternInputsAndOutputs(TreePattern *I, TreePatternNode *Pat,
     DefInit *Val = dynamic_cast<DefInit*>(Dest->getLeafValue());
     if (!Val)
       I->error("set destination should be a register!");
-    
-    if (!Val->getDef()->isSubClassOf("RegisterClass") &&
-        !Val->getDef()->isSubClassOf("Register"))
-      I->error("set destination should be a register!");
-    if (Dest->getName().empty())
-      I->error("set destination must have a name!");
-    if (InstResults.count(Dest->getName()))
-      I->error("cannot set '" + Dest->getName() +"' multiple times");
-    InstResults[Dest->getName()] = Val->getDef();
 
+    if (Val->getDef()->isSubClassOf("RegisterClass")) {
+      if (Dest->getName().empty())
+        I->error("set destination must have a name!");
+      if (InstResults.count(Dest->getName()))
+        I->error("cannot set '" + Dest->getName() +"' multiple times");
+      InstResults[Dest->getName()] = Val->getDef();
+    } else if (Val->getDef()->isSubClassOf("Register")) {
+      InstImpResults.push_back(Val->getDef());
+    } else {
+      I->error("set destination should be a register!");
+    }
+    
     // Verify and collect info from the computation.
     FindPatternInputsAndOutputs(I, Pat->getChild(i+NumValues),
-                                InstInputs, InstResults);
+                                InstInputs, InstResults, InstImpInputs, InstImpResults);
   }
 }
 
@@ -1135,8 +1145,11 @@ void DAGISelEmitter::ParseInstructions() {
       }
       
       // Create and insert the instruction.
+      std::vector<Record*> ImpResults;
+      std::vector<Record*> ImpOperands;
       Instructions.insert(std::make_pair(Instrs[i], 
-                            DAGInstruction(0, Results, Operands)));
+                          DAGInstruction(0, Results, Operands,
+                                         ImpResults, ImpOperands)));
       continue;  // no pattern.
     }
     
@@ -1157,6 +1170,9 @@ void DAGISelEmitter::ParseInstructions() {
     // InstResults - Keep track of all the virtual registers that are 'set'
     // in the instruction, including what reg class they are.
     std::map<std::string, Record*> InstResults;
+
+    std::vector<Record*> InstImpInputs;
+    std::vector<Record*> InstImpResults;
     
     // Verify that the top-level forms in the instruction are of void type, and
     // fill in the InstResults map.
@@ -1167,7 +1183,8 @@ void DAGISelEmitter::ParseInstructions() {
                  " void types");
 
       // Find inputs and outputs, and verify the structure of the uses/defs.
-      FindPatternInputsAndOutputs(I, Pat, InstInputs, InstResults);
+      FindPatternInputsAndOutputs(I, Pat, InstInputs, InstResults,
+                                  InstImpInputs, InstImpResults);
     }
 
     // Now that we have inputs and outputs of the pattern, inspect the operands
@@ -1256,7 +1273,7 @@ void DAGISelEmitter::ParseInstructions() {
       new TreePatternNode(I->getRecord(), ResultNodeOperands);
 
     // Create and insert the instruction.
-    DAGInstruction TheInst(I, Results, Operands);
+    DAGInstruction TheInst(I, Results, Operands, InstImpResults, InstImpInputs);
     Instructions.insert(std::make_pair(I->getRecord(), TheInst));
 
     // Use a temporary tree pattern to infer all types and make sure that the
@@ -1284,16 +1301,14 @@ void DAGISelEmitter::ParseInstructions() {
     }
     TreePatternNode *Pattern = I->getTree(0);
     TreePatternNode *SrcPattern;
-    if (TheInst.getNumResults() == 0) {
-      SrcPattern = Pattern;
-    } else {
-      if (Pattern->getOperator()->getName() != "set")
-        continue;  // Not a set (store or something?)
-    
+    if (Pattern->getOperator()->getName() == "set") {
       if (Pattern->getNumChildren() != 2)
         continue;  // Not a set of a single value (not handled so far)
 
       SrcPattern = Pattern->getChild(1)->clone();    
+    } else{
+      // Not a set (store or something?)
+      SrcPattern = Pattern;
     }
     
     std::string Reason;
@@ -1332,8 +1347,11 @@ void DAGISelEmitter::ParsePatterns() {
     {
       std::map<std::string, TreePatternNode*> InstInputs;
       std::map<std::string, Record*> InstResults;
+      std::vector<Record*> InstImpInputs;
+      std::vector<Record*> InstImpResults;
       FindPatternInputsAndOutputs(Pattern, Pattern->getOnlyTree(),
-                                  InstInputs, InstResults);
+                                  InstInputs, InstResults,
+                                  InstImpInputs, InstImpResults);
     }
     
     ListInit *LI = Patterns[i]->getValueAsListInit("ResultInstrs");
@@ -1643,7 +1661,8 @@ static const ComplexPattern *NodeGetComplexPattern(TreePatternNode *N,
 static unsigned getPatternSize(TreePatternNode *P, DAGISelEmitter &ISE) {
   assert(isExtIntegerVT(P->getExtType()) || 
          isExtFloatingPointVT(P->getExtType()) ||
-         P->getExtType() == MVT::isVoid && "Not a valid pattern node to size!");
+         P->getExtType() == MVT::isVoid ||
+         P->getExtType() == MVT::Flag && "Not a valid pattern node to size!");
   unsigned Size = 1;  // The node itself.
 
   // FIXME: This is a hack to statically increase the priority of patterns
@@ -1955,7 +1974,7 @@ public:
         OS << ";\n";
         OS << "      if (!" << Fn << "(" << Val;
         for (unsigned i = 0; i < NumRes; i++)
-          OS << " , Tmp" << i + ResNo;
+          OS << ", Tmp" << i + ResNo;
         OS << ")) goto P" << PatternNo << "Fail;\n";
         TmpNo = ResNo + NumRes;
       } else {
@@ -2025,8 +2044,8 @@ public:
           Ops.push_back(NumTemps[i].second + j);
       }
 
-      CodeGenInstruction &II =
-        ISE.getTargetInfo().getInstruction(Op->getName());
+      const CodeGenTarget &CGT = ISE.getTargetInfo();
+      CodeGenInstruction &II = CGT.getInstruction(Op->getName());
 
       // Emit all the chain and CopyToReg stuff.
       if (II.hasCtrlDep)
@@ -2034,12 +2053,20 @@ public:
       EmitCopyToRegs(Pattern, "N", II.hasCtrlDep);
 
       const DAGInstruction &Inst = ISE.getInstruction(Op);
+      unsigned NumImpResults =  Inst.getNumImpResults();
       unsigned NumResults = Inst.getNumResults();    
       unsigned ResNo = TmpNo++;
       if (!isRoot) {
         OS << "      SDOperand Tmp" << ResNo << " = CurDAG->getTargetNode("
-           << II.Namespace << "::" << II.TheDef->getName() << ", MVT::"
-           << getEnumName(N->getType());
+           << II.Namespace << "::" << II.TheDef->getName();
+        if (N->getType() != MVT::isVoid)
+          OS << ", MVT::" << getEnumName(N->getType());
+        for (unsigned i = 0; i < NumImpResults; i++) {
+          Record *ImpResult = Inst.getImpResult(i);
+          MVT::ValueType RVT = getRegisterValueType(ImpResult, CGT);
+          OS << ", MVT::" << getEnumName(RVT);
+        }
+
         unsigned LastOp = 0;
         for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
           LastOp = Ops[i];
@@ -2055,9 +2082,17 @@ public:
         OS << "      SDOperand Result = ";
         OS << "CurDAG->getTargetNode("
            << II.Namespace << "::" << II.TheDef->getName();
-        if (NumResults > 0) 
-          OS << ", MVT::" << getEnumName(N->getType()); // TODO: multiple results?
+        if (NumResults > 0) { 
+          // TODO: multiple results?
+          if (N->getType() != MVT::isVoid)
+            OS << ", MVT::" << getEnumName(N->getType());
+        }
         OS << ", MVT::Other";
+        for (unsigned i = 0; i < NumImpResults; i++) {
+          Record *ImpResult = Inst.getImpResult(i);
+          MVT::ValueType RVT = getRegisterValueType(ImpResult, CGT);
+          OS << ", MVT::" << getEnumName(RVT);
+        }
         for (unsigned i = 0, e = Ops.size(); i != e; ++i)
           OS << ", Tmp" << Ops[i];
         OS << ", Chain";
@@ -2074,7 +2109,7 @@ public:
           OS << "= CodeGenMap[" << FoldedChains[j] << ".getValue("
              << NumResults << ")] ";
         OS << "= Result.getValue(" << NumResults << ");\n";
-        if (NumResults == 0)
+        if (NumResults == 0 && NumImpResults == 0)
           OS << "      return Chain;\n";
         else
           OS << "      return (N.ResNo) ? Chain : Result.getValue(0);\n";
@@ -2083,8 +2118,14 @@ public:
         // use SelectNodeTo instead of getTargetNode to avoid an allocation.
         OS << "      if (N.Val->hasOneUse()) {\n";
         OS << "        return CurDAG->SelectNodeTo(N.Val, "
-           << II.Namespace << "::" << II.TheDef->getName() << ", MVT::"
-           << getEnumName(N->getType());
+           << II.Namespace << "::" << II.TheDef->getName();
+        if (N->getType() != MVT::isVoid)
+          OS << ", MVT::" << getEnumName(N->getType());
+        for (unsigned i = 0; i < NumImpResults; i++) {
+          Record *ImpResult = Inst.getImpResult(i);
+          MVT::ValueType RVT = getRegisterValueType(ImpResult, CGT);
+          OS << ", MVT::" << getEnumName(RVT);
+        }
         for (unsigned i = 0, e = Ops.size(); i != e; ++i)
           OS << ", Tmp" << Ops[i];
         if (InFlag)
@@ -2092,8 +2133,14 @@ public:
         OS << ");\n";
         OS << "      } else {\n";
         OS << "        return CodeGenMap[N] = CurDAG->getTargetNode("
-           << II.Namespace << "::" << II.TheDef->getName() << ", MVT::"
-           << getEnumName(N->getType());
+           << II.Namespace << "::" << II.TheDef->getName();
+        if (N->getType() != MVT::isVoid)
+          OS << ", MVT::" << getEnumName(N->getType());
+        for (unsigned i = 0; i < NumImpResults; i++) {
+          Record *ImpResult = Inst.getImpResult(i);
+          MVT::ValueType RVT = getRegisterValueType(ImpResult, CGT);
+          OS << ", MVT::" << getEnumName(RVT);
+        }
         for (unsigned i = 0, e = Ops.size(); i != e; ++i)
           OS << ", Tmp" << Ops[i];
         if (InFlag)
@@ -2159,7 +2206,9 @@ private:
           Record *RR = DI->getDef();
           if (RR->isSubClassOf("Register")) {
             MVT::ValueType RVT = getRegisterValueType(RR, T);
-            if (HasCtrlDep) {
+            if (RVT == MVT::Flag) {
+              OS << "      InFlag = Select(" << RootName << OpNo << ");\n";
+            } else if (HasCtrlDep) {
               OS << "      SDOperand " << RootName << "CR" << i << ";\n";
               OS << "      " << RootName << "CR" << i
                  << "  = CurDAG->getCopyToReg(Chain, CurDAG->getRegister("
@@ -2275,7 +2324,7 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
      << "      N.getOpcode() < (ISD::BUILTIN_OP_END+" << InstNS
      << "INSTRUCTION_LIST_END))\n"
      << "    return N;   // Already selected.\n\n"
-  << "  std::map<SDOperand, SDOperand>::iterator CGMI = CodeGenMap.find(N);\n"
+    << "  std::map<SDOperand, SDOperand>::iterator CGMI = CodeGenMap.find(N);\n"
      << "  if (CGMI != CodeGenMap.end()) return CGMI->second;\n"
      << "  switch (N.getOpcode()) {\n"
      << "  default: break;\n"
@@ -2302,19 +2351,47 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
      << "    }\n"
      << "  case ISD::CopyFromReg: {\n"
      << "    SDOperand Chain = Select(N.getOperand(0));\n"
-     << "    if (Chain == N.getOperand(0)) return N; // No change\n"
-     << "    SDOperand New = CurDAG->getCopyFromReg(Chain,\n"
-     << "                    cast<RegisterSDNode>(N.getOperand(1))->getReg(),\n"
-     << "                                         N.Val->getValueType(0));\n"
-     << "    return New.getValue(N.ResNo);\n"
+     << "    unsigned Reg = cast<RegisterSDNode>(N.getOperand(1))->getReg();\n"
+     << "    MVT::ValueType VT = N.Val->getValueType(0);\n"
+     << "    if (N.getNumOperands() == 2) {\n"
+     << "      if (Chain == N.getOperand(0)) return N; // No change\n"
+     << "      SDOperand New = CurDAG->getCopyFromReg(Chain, Reg, VT);\n"
+     << "      CodeGenMap[N.getValue(0)] = New;\n"
+     << "      CodeGenMap[N.getValue(1)] = New.getValue(1);\n"
+     << "      return New.getValue(N.ResNo);\n"
+     << "    } else {\n"
+     << "      SDOperand Flag;\n"
+     << "      if (N.getOperand(2).Val) Flag = Select(N.getOperand(2));\n"
+     << "      if (Chain == N.getOperand(0) && Flag == N.getOperand(2))\n"
+     << "        return N; // No change\n"
+     << "      SDOperand New = CurDAG->getCopyFromReg(Chain, Reg, VT, Flag);\n"
+     << "      CodeGenMap[N.getValue(0)] = New;\n"
+     << "      CodeGenMap[N.getValue(1)] = New.getValue(1);\n"
+     << "      CodeGenMap[N.getValue(2)] = New.getValue(2);\n"
+     << "      return New.getValue(N.ResNo);\n"
+     << "    }\n"
      << "  }\n"
      << "  case ISD::CopyToReg: {\n"
      << "    SDOperand Chain = Select(N.getOperand(0));\n"
      << "    SDOperand Reg = N.getOperand(1);\n"
      << "    SDOperand Val = Select(N.getOperand(2));\n"
-     << "    return CodeGenMap[N] = \n"
-     << "                   CurDAG->getNode(ISD::CopyToReg, MVT::Other,\n"
-     << "                                   Chain, Reg, Val);\n"
+     << "    SDOperand Result = N;\n"
+     << "    if (N.getNumOperands() == 3) {\n"
+     << "      if (Chain != N.getOperand(0) || Val != N.getOperand(2))\n"
+     << "        Result = CurDAG->getNode(ISD::CopyToReg, MVT::Other,\n"
+     << "                                 Chain, Reg, Val);\n"
+     << "      return CodeGenMap[N] = Result;\n"
+     << "    } else {\n"
+     << "      SDOperand Flag;\n"
+     << "      if (N.getOperand(3).Val) Flag = Select(N.getOperand(3));\n"
+     << "      if (Chain != N.getOperand(0) || Val != N.getOperand(2) ||\n"
+     << "          Flag != N.getOperand(3))\n"
+     << "        Result = CurDAG->getNode(ISD::CopyToReg, MVT::Other,\n"
+     << "                                 Chain, Reg, Val, Flag);\n"
+     << "      CodeGenMap[N.getValue(0)] = Result;\n"
+     << "      CodeGenMap[N.getValue(1)] = Result.getValue(1);\n"
+     << "      return Result.getValue(N.ResNo);\n"
+     << "    }\n"
      << "  }\n";
     
   // Group the patterns by their top-level opcodes.