Use SCEVAddRecExpr::isAffine.
[oota-llvm.git] / utils / TableGen / DAGISelEmitter.cpp
index 72bd5bdb2fbabca00dc7636e6c7a0b662ee2c589..8d89eeeb7ae22f68afeedda79bc9b36784da54e0 100644 (file)
@@ -51,8 +51,8 @@ static const ComplexPattern *NodeGetComplexPattern(TreePatternNode *N,
 /// patterns before small ones.  This is used to determine the size of a
 /// pattern.
 static unsigned getPatternSize(TreePatternNode *P, CodeGenDAGPatterns &CGP) {
-  assert((MVT::isExtIntegerInVTs(P->getExtTypes()) || 
-          MVT::isExtFloatingPointInVTs(P->getExtTypes()) ||
+  assert((EMVT::isExtIntegerInVTs(P->getExtTypes()) ||
+          EMVT::isExtFloatingPointInVTs(P->getExtTypes()) ||
           P->getExtTypeNum(0) == MVT::isVoid ||
           P->getExtTypeNum(0) == MVT::Flag ||
           P->getExtTypeNum(0) == MVT::iPTR) && 
@@ -160,7 +160,7 @@ struct PatternSortingPredicate {
 
 /// getRegisterValueType - Look up and return the first ValueType of specified 
 /// RegisterClass record
-static MVT::ValueType getRegisterValueType(Record *R, const CodeGenTarget &T) {
+static MVT::SimpleValueType getRegisterValueType(Record *R, const CodeGenTarget &T) {
   if (const CodeGenRegisterClass *RC = T.getRegisterClassForRegister(R))
     return RC->getValueTypeNum(0);
   return MVT::Other;
@@ -331,6 +331,15 @@ private:
   /// instructions.
   std::vector<std::string> &TargetOpcodes;
   std::vector<std::string> &TargetVTs;
+  /// OutputIsVariadic - Records whether the instruction output pattern uses
+  /// variable_ops.  This requires that the Emit function be passed an
+  /// additional argument to indicate where the input varargs operands
+  /// begin.
+  bool &OutputIsVariadic;
+  /// NumInputRootOps - Records the number of operands the root node of the
+  /// input pattern has.  This information is used in the generated code to
+  /// pass to Emit functions when variable_ops processing is needed.
+  unsigned &NumInputRootOps;
 
   std::string ChainName;
   unsigned TmpNo;
@@ -367,10 +376,13 @@ public:
                      std::vector<std::pair<unsigned, std::string> > &gc,
                      std::set<std::string> &gd,
                      std::vector<std::string> &to,
-                     std::vector<std::string> &tv)
+                     std::vector<std::string> &tv,
+                     bool &oiv,
+                     unsigned &niro)
   : CGP(cgp), Predicates(preds), Pattern(pattern), Instruction(instr),
     GeneratedCode(gc), GeneratedDecl(gd),
     TargetOpcodes(to), TargetVTs(tv),
+    OutputIsVariadic(oiv), NumInputRootOps(niro),
     TmpNo(0), OpcNo(0), VTNo(0) {}
 
   /// EmitMatchCode - Emit a matcher for N, going to the label for PatternNo
@@ -392,6 +404,9 @@ public:
     bool isRoot = (P == NULL);
     // Emit instruction predicates. Each predicate is just a string for now.
     if (isRoot) {
+      // Record input varargs info.
+      NumInputRootOps = N->getNumChildren();
+
       std::string PredicateCheck;
       for (unsigned i = 0, e = Predicates->getSize(); i != e; ++i) {
         if (DefInit *Pred = dynamic_cast<DefInit*>(Predicates->getElement(i))) {
@@ -887,7 +902,7 @@ public:
       if (InstPatNode && InstPatNode->getOperator()->getName() == "set") {
         InstPatNode = InstPatNode->getChild(InstPatNode->getNumChildren()-1);
       }
-      bool HasVarOps     = isRoot && II.isVariadic;
+      bool IsVariadic = isRoot && II.isVariadic;
       // FIXME: fix how we deal with physical register operands.
       bool HasImpInputs  = isRoot && Inst.getNumImpOperands() > 0;
       bool HasImpResults = isRoot && DstRegs.size() > 0;
@@ -904,17 +919,20 @@ public:
       unsigned NumResults = Inst.getNumResults();    
       unsigned NumDstRegs = HasImpResults ? DstRegs.size() : 0;
 
+      // Record output varargs info.
+      OutputIsVariadic = IsVariadic;
+
       if (NodeHasOptInFlag) {
         emitCode("bool HasInFlag = "
            "(N.getOperand(N.getNumOperands()-1).getValueType() == MVT::Flag);");
       }
-      if (HasVarOps)
+      if (IsVariadic)
         emitCode("SmallVector<SDOperand, 8> Ops" + utostr(OpcNo) + ";");
 
       // How many results is this pattern expected to produce?
       unsigned NumPatResults = 0;
       for (unsigned i = 0, e = Pattern->getExtTypes().size(); i != e; i++) {
-        MVT::ValueType VT = Pattern->getTypeNum(i);
+        MVT::SimpleValueType VT = Pattern->getTypeNum(i);
         if (VT != MVT::isVoid && VT != MVT::Flag)
           NumPatResults++;
       }
@@ -946,23 +964,16 @@ public:
       // in the 'execute always' values.  Match up the node operands to the
       // instruction operands to do this.
       std::vector<std::string> AllOps;
-      unsigned NumEAInputs = 0; // # of synthesized 'execute always' inputs.
       for (unsigned ChildNo = 0, InstOpNo = NumResults;
            InstOpNo != II.OperandList.size(); ++InstOpNo) {
         std::vector<std::string> Ops;
         
-        // If this is a normal operand or a predicate operand without
-        // 'execute always', emit it.
+        // Determine what to emit for this operand.
         Record *OperandNode = II.OperandList[InstOpNo].Rec;
-        if ((!OperandNode->isSubClassOf("PredicateOperand") &&
-             !OperandNode->isSubClassOf("OptionalDefOperand")) ||
-            CGP.getDefaultOperand(OperandNode).DefaultOps.empty()) {
-          Ops = EmitResultCode(N->getChild(ChildNo), DstRegs,
-                               InFlagDecled, ResNodeDecled);
-          AllOps.insert(AllOps.end(), Ops.begin(), Ops.end());
-          ++ChildNo;
-        } else {
-          // Otherwise, this is a predicate or optional def operand, emit the
+        if ((OperandNode->isSubClassOf("PredicateOperand") ||
+             OperandNode->isSubClassOf("OptionalDefOperand")) &&
+            !CGP.getDefaultOperand(OperandNode).DefaultOps.empty()) {
+          // This is a predicate or optional def operand; emit the
           // 'default ops' operands.
           const DAGDefaultOperand &DefaultOp =
             CGP.getDefaultOperand(II.OperandList[InstOpNo].Rec);
@@ -970,20 +981,14 @@ public:
             Ops = EmitResultCode(DefaultOp.DefaultOps[i], DstRegs,
                                  InFlagDecled, ResNodeDecled);
             AllOps.insert(AllOps.end(), Ops.begin(), Ops.end());
-            NumEAInputs += Ops.size();
           }
-        }
-      }
-
-      // Generate MemOperandSDNodes nodes for each memory accesses covered by 
-      // this pattern.
-      if (II.isSimpleLoad | II.mayLoad | II.mayStore) {
-        std::vector<std::string>::const_iterator mi, mie;
-        for (mi = LSI.begin(), mie = LSI.end(); mi != mie; ++mi) {
-          emitCode("SDOperand LSI_" + *mi + " = "
-                   "CurDAG->getMemOperand(cast<LSBaseSDNode>(" +
-                   *mi + ")->getMemOperand());");
-          AllOps.push_back("LSI_" + *mi);
+        } else {
+          // Otherwise this is a normal operand or a predicate operand without
+          // 'execute always'; emit it.
+          Ops = EmitResultCode(N->getChild(ChildNo), DstRegs,
+                               InFlagDecled, ResNodeDecled);
+          AllOps.insert(AllOps.end(), Ops.begin(), Ops.end());
+          ++ChildNo;
         }
       }
 
@@ -1040,7 +1045,7 @@ public:
         for (unsigned i = 0; i < NumDstRegs; i++) {
           Record *RR = DstRegs[i];
           if (RR->isSubClassOf("Register")) {
-            MVT::ValueType RVT = getRegisterValueType(RR, CGT);
+            MVT::SimpleValueType RVT = getRegisterValueType(RR, CGT);
             Code += ", " + getEnumName(RVT);
           }
         }
@@ -1049,19 +1054,12 @@ public:
         if (NodeHasOutFlag)
           Code += ", MVT::Flag";
 
-        // Figure out how many fixed inputs the node has.  This is important to
-        // know which inputs are the variable ones if present.
-        unsigned NumInputs = AllOps.size();
-        NumInputs += NodeHasChain;
-        
         // Inputs.
-        if (HasVarOps) {
+        if (IsVariadic) {
           for (unsigned i = 0, e = AllOps.size(); i != e; ++i)
             emitCode("Ops" + utostr(OpsNo) + ".push_back(" + AllOps[i] + ");");
           AllOps.clear();
-        }
 
-        if (HasVarOps) {
           // Figure out whether any operands at the end of the op list are not
           // part of the variable section.
           std::string EndAdjust;
@@ -1070,7 +1068,7 @@ public:
           else if (NodeHasOptInFlag)
             EndAdjust = "-(HasInFlag?1:0)"; // May have a flag.
 
-          emitCode("for (unsigned i = " + utostr(NumInputs - NumEAInputs) +
+          emitCode("for (unsigned i = NumInputRootOps + " + utostr(NodeHasChain) +
                    ", e = N.getNumOperands()" + EndAdjust + "; i != e; ++i) {");
 
           emitCode("  AddToISelQueue(N.getOperand(i));");
@@ -1078,14 +1076,29 @@ public:
           emitCode("}");
         }
 
+        // Generate MemOperandSDNodes nodes for each memory accesses covered by 
+        // this pattern.
+        if (II.isSimpleLoad | II.mayLoad | II.mayStore) {
+          std::vector<std::string>::const_iterator mi, mie;
+          for (mi = LSI.begin(), mie = LSI.end(); mi != mie; ++mi) {
+            emitCode("SDOperand LSI_" + *mi + " = "
+                     "CurDAG->getMemOperand(cast<LSBaseSDNode>(" +
+                     *mi + ")->getMemOperand());");
+            if (IsVariadic)
+              emitCode("Ops" + utostr(OpsNo) + ".push_back(LSI_" + *mi + ");");
+            else
+              AllOps.push_back("LSI_" + *mi);
+          }
+        }
+
         if (NodeHasChain) {
-          if (HasVarOps)
+          if (IsVariadic)
             emitCode("Ops" + utostr(OpsNo) + ".push_back(" + ChainName + ");");
           else
             AllOps.push_back(ChainName);
         }
 
-        if (HasVarOps) {
+        if (IsVariadic) {
           if (NodeHasInFlag || HasImpInputs)
             emitCode("Ops" + utostr(OpsNo) + ".push_back(InFlag);");
           else if (NodeHasOptInFlag) {
@@ -1298,7 +1311,7 @@ private:
 
           Record *RR = DI->getDef();
           if (RR->isSubClassOf("Register")) {
-            MVT::ValueType RVT = getRegisterValueType(RR, T);
+            MVT::SimpleValueType RVT = getRegisterValueType(RR, T);
             if (RVT == MVT::Flag) {
               if (!InFlagDecled) {
                 emitCode("SDOperand InFlag = " + RootName + utostr(OpNo) + ";");
@@ -1350,11 +1363,17 @@ void DAGISelEmitter::GenerateCodeForPattern(const PatternToMatch &Pattern,
                   std::vector<std::pair<unsigned, std::string> > &GeneratedCode,
                                            std::set<std::string> &GeneratedDecl,
                                         std::vector<std::string> &TargetOpcodes,
-                                          std::vector<std::string> &TargetVTs) {
+                                            std::vector<std::string> &TargetVTs,
+                                            bool &OutputIsVariadic,
+                                            unsigned &NumInputRootOps) {
+  OutputIsVariadic = false;
+  NumInputRootOps = 0;
+
   PatternCodeEmitter Emitter(CGP, Pattern.getPredicates(),
                              Pattern.getSrcPattern(), Pattern.getDstPattern(),
                              GeneratedCode, GeneratedDecl,
-                             TargetOpcodes, TargetVTs);
+                             TargetOpcodes, TargetVTs,
+                             OutputIsVariadic, NumInputRootOps);
 
   // Emit the matcher, capturing named arguments in VariableMap.
   bool FoundChain = false;
@@ -1615,12 +1634,13 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
                      PatternSortingPredicate(CGP));
 
     // Split them into groups by type.
-    std::map<MVT::ValueType, std::vector<const PatternToMatch*> >PatternsByType;
+    std::map<MVT::SimpleValueType,
+             std::vector<const PatternToMatch*> > PatternsByType;
     for (unsigned i = 0, e = PatternsOfOp.size(); i != e; ++i) {
       const PatternToMatch *Pat = PatternsOfOp[i];
       TreePatternNode *SrcPat = Pat->getSrcPattern();
-      MVT::ValueType VT = SrcPat->getTypeNum(0);
-      std::map<MVT::ValueType, 
+      MVT::SimpleValueType VT = SrcPat->getTypeNum(0);
+      std::map<MVT::SimpleValueType,
                std::vector<const PatternToMatch*> >::iterator TI = 
         PatternsByType.find(VT);
       if (TI != PatternsByType.end())
@@ -1632,10 +1652,11 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
       }
     }
 
-    for (std::map<MVT::ValueType, std::vector<const PatternToMatch*> >::iterator
+    for (std::map<MVT::SimpleValueType,
+                  std::vector<const PatternToMatch*> >::iterator
            II = PatternsByType.begin(), EE = PatternsByType.end(); II != EE;
          ++II) {
-      MVT::ValueType OpVT = II->first;
+      MVT::SimpleValueType OpVT = II->first;
       std::vector<const PatternToMatch*> &Patterns = II->second;
       typedef std::vector<std::pair<unsigned,std::string> > CodeList;
       typedef std::vector<std::pair<unsigned,std::string> >::iterator CodeListI;
@@ -1644,17 +1665,24 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
       std::vector<std::vector<std::string> > PatternOpcodes;
       std::vector<std::vector<std::string> > PatternVTs;
       std::vector<std::set<std::string> > PatternDecls;
+      std::vector<bool> OutputIsVariadicFlags;
+      std::vector<unsigned> NumInputRootOpsCounts;
       for (unsigned i = 0, e = Patterns.size(); i != e; ++i) {
         CodeList GeneratedCode;
         std::set<std::string> GeneratedDecl;
         std::vector<std::string> TargetOpcodes;
         std::vector<std::string> TargetVTs;
+        bool OutputIsVariadic;
+        unsigned NumInputRootOps;
         GenerateCodeForPattern(*Patterns[i], GeneratedCode, GeneratedDecl,
-                               TargetOpcodes, TargetVTs);
+                               TargetOpcodes, TargetVTs,
+                               OutputIsVariadic, NumInputRootOps);
         CodeForPatterns.push_back(std::make_pair(Patterns[i], GeneratedCode));
         PatternDecls.push_back(GeneratedDecl);
         PatternOpcodes.push_back(TargetOpcodes);
         PatternVTs.push_back(TargetVTs);
+        OutputIsVariadicFlags.push_back(OutputIsVariadic);
+        NumInputRootOpsCounts.push_back(NumInputRootOps);
       }
     
       // Scan the code to see if all of the patterns are reachable and if it is
@@ -1689,6 +1717,8 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
         std::vector<std::string> &TargetOpcodes = PatternOpcodes[i];
         std::vector<std::string> &TargetVTs = PatternVTs[i];
         std::set<std::string> Decls = PatternDecls[i];
+        bool OutputIsVariadic = OutputIsVariadicFlags[i];
+        unsigned NumInputRootOps = NumInputRootOpsCounts[i];
         std::vector<std::string> AddedInits;
         int CodeSize = (int)GeneratedCode.size();
         int LastPred = -1;
@@ -1706,7 +1736,7 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
           CallerCode += ", " + TargetOpcodes[j];
         }
         for (unsigned j = 0, e = TargetVTs.size(); j != e; ++j) {
-          CalleeCode += ", MVT::ValueType VT" + utostr(j);
+          CalleeCode += ", MVT VT" + utostr(j);
           CallerCode += ", " + TargetVTs[j];
         }
         for (std::set<std::string>::iterator
@@ -1715,6 +1745,12 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
           CalleeCode += ", SDOperand &" + Name;
           CallerCode += ", " + Name;
         }
+
+        if (OutputIsVariadic) {
+          CalleeCode += ", unsigned NumInputRootOps";
+          CallerCode += ", " + utostr(NumInputRootOps);
+        }
+
         CallerCode += ");";
         CalleeCode += ") ";
         // Prevent emission routines from being inlined to reduce selection
@@ -1818,7 +1854,7 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
      << "  for (unsigned j = 0, e = Ops.size(); j != e; ++j)\n"
      << "    AddToISelQueue(Ops[j]);\n\n"
     
-     << "  std::vector<MVT::ValueType> VTs;\n"
+     << "  std::vector<MVT> VTs;\n"
      << "  VTs.push_back(MVT::Other);\n"
      << "  VTs.push_back(MVT::Flag);\n"
      << "  SDOperand New = CurDAG->getNode(ISD::INLINEASM, VTs, &Ops[0], "
@@ -1897,7 +1933,7 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
      << "INSTRUCTION_LIST_END)) {\n"
      << "    return NULL;   // Already selected.\n"
      << "  }\n\n"
-     << "  MVT::ValueType NVT = N.Val->getValueType(0);\n"
+     << "  MVT::SimpleValueType NVT = N.Val->getValueType(0).getSimpleVT();\n"
      << "  switch (N.getOpcode()) {\n"
      << "  default: break;\n"
      << "  case ISD::EntryToken:       // These leaves remain the same.\n"
@@ -1974,7 +2010,7 @@ void DAGISelEmitter::EmitInstructionSelector(std::ostream &OS) {
       
     // If there is an iPTR result version of this pattern, emit it here.
     if (HasPtrPattern) {
-      OS << "      if (NVT == TLI.getPointerTy())\n";
+      OS << "      if (TLI.getPointerTy() == NVT)\n";
       OS << "        return Select_" << getLegalCName(OpName) <<"_iPTR(N);\n";
     }
     if (HasDefaultPattern) {