teach the type inference code how to infer types for instructions and node
authorChris Lattner <sabre@nondot.org>
Thu, 15 Sep 2005 22:23:50 +0000 (22:23 +0000)
committerChris Lattner <sabre@nondot.org>
Thu, 15 Sep 2005 22:23:50 +0000 (22:23 +0000)
xforms.  Run type inference on result patterns, so we always have fully typed
results (and to catch errors in .td files).

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@23369 91177308-0d34-0410-b5e6-96231b3b80d8

utils/TableGen/DAGISelEmitter.cpp
utils/TableGen/DAGISelEmitter.h

index e5a436bb133de64cfc7954c9f035478aa9455ba6..79ef2311c3d7711f9cc8096236c4264f58cbee0b 100644 (file)
@@ -324,14 +324,34 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP) {
     for (unsigned i = 0, e = getNumChildren(); i != e; ++i)
       MadeChange |= getChild(i)->ApplyTypeConstraints(TP);
     return MadeChange;  
-  } else {
-    assert(getOperator()->isSubClassOf("Instruction") && "Unknown node type!");
-    
+  } else if (getOperator()->isSubClassOf("Instruction")) {
     const DAGInstruction &Inst =
       TP.getDAGISelEmitter().getInstruction(getOperator());
     
-    // TODO: type inference for instructions.
-    return false;
+    assert(Inst.getNumResults() == 1 && "Only supports one result instrs!");
+    // Apply the result type to the node
+    bool MadeChange = UpdateNodeType(Inst.getResultType(0), TP);
+
+    if (getNumChildren() != Inst.getNumOperands())
+      TP.error("Instruction '" + getOperator()->getName() + " expects " +
+               utostr(Inst.getNumOperands()) + " operands, not " +
+               utostr(getNumChildren()) + " operands!");
+    for (unsigned i = 0, e = getNumChildren(); i != e; ++i) {
+      MadeChange |= getChild(i)->UpdateNodeType(Inst.getOperandType(i), TP);
+      MadeChange |= getChild(i)->ApplyTypeConstraints(TP);
+    }
+    return MadeChange;
+  } else {
+    assert(getOperator()->isSubClassOf("SDNodeXForm") && "Unknown node type!");
+    
+    // Node transforms always take one operand, and take and return the same
+    // type.
+    if (getNumChildren() != 1)
+      TP.error("Node transform '" + getOperator()->getName() +
+               "' requires one operand!");
+    bool MadeChange = UpdateNodeType(getChild(0)->getType(), TP);
+    MadeChange |= getChild(0)->UpdateNodeType(getType(), TP);
+    return MadeChange;
   }
 }
 
@@ -340,13 +360,24 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP) {
 // TreePattern implementation
 //
 
-TreePattern::TreePattern(Record *TheRec, const std::vector<DagInit *> &RawPat,
+TreePattern::TreePattern(Record *TheRec, ListInit *RawPat,
                          DAGISelEmitter &ise) : TheRecord(TheRec), ISE(ise) {
+   for (unsigned i = 0, e = RawPat->getSize(); i != e; ++i)
+     Trees.push_back(ParseTreePattern((DagInit*)RawPat->getElement(i)));
+}
 
-  for (unsigned i = 0, e = RawPat.size(); i != e; ++i)
-    Trees.push_back(ParseTreePattern(RawPat[i]));
+TreePattern::TreePattern(Record *TheRec, DagInit *Pat,
+                         DAGISelEmitter &ise) : TheRecord(TheRec), ISE(ise) {
+  Trees.push_back(ParseTreePattern(Pat));
 }
 
+TreePattern::TreePattern(Record *TheRec, TreePatternNode *Pat, 
+                         DAGISelEmitter &ise) : TheRecord(TheRec), ISE(ise) {
+  Trees.push_back(Pat);
+}
+
+
+
 void TreePattern::error(const std::string &Msg) const {
   dump();
   throw "In " + TheRecord->getName() + ": " + Msg;
@@ -550,9 +581,8 @@ void DAGISelEmitter::ParsePatternFragments(std::ostream &OS) {
   // First step, parse all of the fragments and emit predicate functions.
   OS << "\n// Predicate functions.\n";
   for (unsigned i = 0, e = Fragments.size(); i != e; ++i) {
-    std::vector<DagInit*> Trees;
-    Trees.push_back(Fragments[i]->getValueAsDag("Fragment"));
-    TreePattern *P = new TreePattern(Fragments[i], Trees, *this);
+    DagInit *Tree = Fragments[i]->getValueAsDag("Fragment");
+    TreePattern *P = new TreePattern(Fragments[i], Tree, *this);
     PatternFragments[Fragments[i]] = P;
     
     // Validate the argument list, converting it to map, to discard duplicates.
@@ -762,12 +792,8 @@ void DAGISelEmitter::ParseInstructions() {
     ListInit *LI = Instrs[i]->getValueAsListInit("Pattern");
     if (LI->getSize() == 0) continue;  // no pattern.
     
-    std::vector<DagInit*> Trees;
-    for (unsigned j = 0, e = LI->getSize(); j != e; ++j)
-      Trees.push_back((DagInit*)LI->getElement(j));
-
     // Parse the instruction.
-    TreePattern *I = new TreePattern(Instrs[i], Trees, *this);
+    TreePattern *I = new TreePattern(Instrs[i], LI, *this);
     // Inline pattern fragments into it.
     I->InlinePatternFragments();
     
@@ -876,11 +902,21 @@ void DAGISelEmitter::ParseInstructions() {
 
     TreePatternNode *ResultPattern =
       new TreePatternNode(I->getRecord(), ResultNodeOperands);
+
+    // Create and insert the instruction.
+    DAGInstruction TheInst(I, ResultTypes, OperandTypes);
+    Instructions.insert(std::make_pair(I->getRecord(), TheInst));
+
+    // Use a temporary tree pattern to infer all types and make sure that the
+    // constructed result is correct.  This depends on the instruction already
+    // being inserted into the Instructions map.
+    TreePattern Temp(I->getRecord(), ResultPattern, *this);
+    Temp.InferAllTypes();
+
+    DAGInstruction &TheInsertedInst = Instructions.find(I->getRecord())->second;
+    TheInsertedInst.setResultPattern(Temp.getOnlyTree());
     
     DEBUG(I->dump());
-    Instructions.insert(std::make_pair(I->getRecord(),
-                                       DAGInstruction(I, ResultTypes,
-                                                OperandTypes, ResultPattern)));
   }
    
   // If we can, convert the instructions to be patterns that are matched!
@@ -909,10 +945,8 @@ void DAGISelEmitter::ParsePatterns() {
   std::vector<Record*> Patterns = Records.getAllDerivedDefinitions("Pattern");
 
   for (unsigned i = 0, e = Patterns.size(); i != e; ++i) {
-    std::vector<DagInit*> Trees;
-    Trees.push_back(Patterns[i]->getValueAsDag("PatternToMatch"));
-    TreePattern *Pattern = new TreePattern(Patterns[i], Trees, *this);
-    Trees.clear();
+    DagInit *Tree = Patterns[i]->getValueAsDag("PatternToMatch");
+    TreePattern *Pattern = new TreePattern(Patterns[i], Tree, *this);
 
     // Inline pattern fragments into it.
     Pattern->InlinePatternFragments();
@@ -924,21 +958,17 @@ void DAGISelEmitter::ParsePatterns() {
     
     ListInit *LI = Patterns[i]->getValueAsListInit("ResultInstrs");
     if (LI->getSize() == 0) continue;  // no pattern.
-    for (unsigned j = 0, e = LI->getSize(); j != e; ++j)
-      Trees.push_back((DagInit*)LI->getElement(j));
     
     // Parse the instruction.
-    TreePattern *Result = new TreePattern(Patterns[i], Trees, *this);
+    TreePattern *Result = new TreePattern(Patterns[i], LI, *this);
     
     // Inline pattern fragments into it.
     Result->InlinePatternFragments();
     
     // Infer as many types as possible.  If we cannot infer all of them, we can
     // never do anything with this pattern: report it to the user.
-#if 0  // FIXME: ENABLE when we can infer though instructions!
     if (!Result->InferAllTypes())
       Result->error("Could not infer all types in pattern result!");
-#endif
    
     if (Result->getNumTrees() != 1)
       Result->error("Cannot handle instructions producing instructions "
index bd1018b5312bdc12ed794eee103791bbe0e9fb1b..de223b1cb6b25d01836639eca5b02a52dfa7a407 100644 (file)
@@ -20,6 +20,7 @@
 namespace llvm {
   class Record;
   struct Init;
+  class ListInit;
   class DagInit;
   class SDNodeInfo;
   class TreePattern;
@@ -222,8 +223,9 @@ namespace llvm {
       
     /// TreePattern constructor - Parse the specified DagInits into the
     /// current record.
-    TreePattern(Record *TheRec,
-                const std::vector<DagInit *> &RawPat, DAGISelEmitter &ise);
+    TreePattern(Record *TheRec, ListInit *RawPat, DAGISelEmitter &ise);
+    TreePattern(Record *TheRec, DagInit *Pat, DAGISelEmitter &ise);
+    TreePattern(Record *TheRec, TreePatternNode *Pat, DAGISelEmitter &ise);
         
     /// getTrees - Return the tree patterns which corresponds to this pattern.
     ///
@@ -285,15 +287,16 @@ namespace llvm {
   public:
     DAGInstruction(TreePattern *TP,
                    const std::vector<MVT::ValueType> &resultTypes,
-                   const std::vector<MVT::ValueType> &operandTypes,
-                   TreePatternNode *resultPattern)
+                   const std::vector<MVT::ValueType> &operandTypes)
       : Pattern(TP), ResultTypes(resultTypes), OperandTypes(operandTypes), 
-        ResultPattern(resultPattern) {}
+        ResultPattern(0) {}
 
     TreePattern *getPattern() const { return Pattern; }
     unsigned getNumResults() const { return ResultTypes.size(); }
     unsigned getNumOperands() const { return OperandTypes.size(); }
     
+    void setResultPattern(TreePatternNode *R) { ResultPattern = R; }
+    
     MVT::ValueType getResultType(unsigned RN) const {
       assert(RN < ResultTypes.size());
       return ResultTypes[RN];