Fix a problem that nate reduced for me.
[oota-llvm.git] / lib / CodeGen / SelectionDAG / LegalizeDAG.cpp
index 5cc148a84f55d02bfd7f2f0d5febebd36c9c1dd4..793037930c0159f80c2cd75ff0f68cf86eec9dc2 100644 (file)
@@ -162,12 +162,13 @@ void SelectionDAGLegalize::LegalizeDAG() {
 SDOperand SelectionDAGLegalize::LegalizeOp(SDOperand Op) {
   assert(getTypeAction(Op.getValueType()) == Legal &&
          "Caller should expand or promote operands that are not legal!");
+  SDNode *Node = Op.Val;
 
   // If this operation defines any values that cannot be represented in a
   // register on this target, make sure to expand or promote them.
-  if (Op.Val->getNumValues() > 1) {
-    for (unsigned i = 0, e = Op.Val->getNumValues(); i != e; ++i)
-      switch (getTypeAction(Op.Val->getValueType(i))) {
+  if (Node->getNumValues() > 1) {
+    for (unsigned i = 0, e = Node->getNumValues(); i != e; ++i)
+      switch (getTypeAction(Node->getValueType(i))) {
       case Legal: break;  // Nothing to do.
       case Expand: {
         SDOperand T1, T2;
@@ -184,13 +185,14 @@ SDOperand SelectionDAGLegalize::LegalizeOp(SDOperand Op) {
       }
   }
 
+  // Note that LegalizeOp may be reentered even from single-use nodes, which
+  // means that we always must cache transformed nodes.
   std::map<SDOperand, SDOperand>::iterator I = LegalizedNodes.find(Op);
   if (I != LegalizedNodes.end()) return I->second;
 
   SDOperand Tmp1, Tmp2, Tmp3;
 
   SDOperand Result = Op;
-  SDNode *Node = Op.Val;
 
   switch (Node->getOpcode()) {
   default:
@@ -321,13 +323,23 @@ SDOperand SelectionDAGLegalize::LegalizeOp(SDOperand Op) {
     break;
   }
 
-  case ISD::ADJCALLSTACKDOWN:
-  case ISD::ADJCALLSTACKUP:
+  case ISD::CALLSEQ_START:
+  case ISD::CALLSEQ_END:
     Tmp1 = LegalizeOp(Node->getOperand(0));  // Legalize the chain.
-    // There is no need to legalize the size argument (Operand #1)
-    if (Tmp1 != Node->getOperand(0))
+    // Do not try to legalize the target-specific arguments (#1+)
+    Tmp2 = Node->getOperand(0);
+    if (Tmp1 != Tmp2) {
       Node->setAdjCallChain(Tmp1);
-    // Note that we do not create new ADJCALLSTACK DOWN/UP nodes here.  These
+
+      // If moving the operand from pointing to Tmp2 dropped its use count to 1,
+      // this will cause the maps used to memoize results to get confused.
+      // Create and add a dummy use, just to increase its use count.  This will
+      // be removed at the end of legalize when dead nodes are removed.
+      if (Tmp2.Val->hasOneUse())
+        DAG.getNode(ISD::PCMARKER, MVT::Other, Tmp2,
+                    DAG.getConstant(0, MVT::i32));
+    }
+    // Note that we do not create new CALLSEQ_DOWN/UP nodes here.  These
     // nodes are treated specially and are mutated in place.  This makes the dag
     // legalization process more efficient and also makes libcall insertion
     // easier.
@@ -914,9 +926,9 @@ SDOperand SelectionDAGLegalize::LegalizeOp(SDOperand Op) {
       } else {
         assert(0 && "Unknown op!");
       }
-      // FIXME: THESE SHOULD USE ExpandLibCall ??!?
+
       std::pair<SDOperand,SDOperand> CallResult =
-        TLI.LowerCallTo(Tmp1, Type::VoidTy, false,
+        TLI.LowerCallTo(Tmp1, Type::VoidTy, false, 0,
                         DAG.getExternalSymbol(FnName, IntPtr), Args, DAG);
       Result = LegalizeOp(CallResult.second);
       break;
@@ -1098,7 +1110,7 @@ SDOperand SelectionDAGLegalize::LegalizeOp(SDOperand Op) {
         break;
       case ISD::CTTZ:
         //if Tmp1 == sizeinbits(NVT) then Tmp1 = sizeinbits(Old VT)
-        Tmp2 = DAG.getSetCC(ISD::SETEQ, MVT::i1, Tmp1, 
+        Tmp2 = DAG.getSetCC(ISD::SETEQ, TLI.getSetCCResultTy(), Tmp1, 
                             DAG.getConstant(getSizeInBits(NVT), NVT));
         Result = DAG.getNode(ISD::SELECT, NVT, Tmp2, 
                            DAG.getConstant(getSizeInBits(OVT),NVT), Tmp1);
@@ -1241,7 +1253,7 @@ SDOperand SelectionDAGLegalize::LegalizeOp(SDOperand Op) {
         Args.push_back(std::make_pair(Tmp1, T));
         // FIXME: should use ExpandLibCall!
         std::pair<SDOperand,SDOperand> CallResult =
-          TLI.LowerCallTo(DAG.getEntryNode(), T, false,
+          TLI.LowerCallTo(DAG.getEntryNode(), T, false, 0,
                           DAG.getExternalSymbol(FnName, VT), Args, DAG);
         Result = LegalizeOp(CallResult.first);
         break;
@@ -1389,9 +1401,9 @@ SDOperand SelectionDAGLegalize::LegalizeOp(SDOperand Op) {
   }
   }
 
-  if (!Op.Val->hasOneUse())
-    AddLegalizedOperand(Op, Result);
-
+  // Note that LegalizeOp may be reentered even from single-use nodes, which
+  // means that we always must cache transformed nodes.
+  AddLegalizedOperand(Op, Result);
   return Result;
 }
 
@@ -1407,14 +1419,18 @@ SDOperand SelectionDAGLegalize::PromoteOp(SDOperand Op) {
   assert(NVT > VT && MVT::isInteger(NVT) == MVT::isInteger(VT) &&
          "Cannot promote to smaller type!");
 
-  std::map<SDOperand, SDOperand>::iterator I = PromotedNodes.find(Op);
-  if (I != PromotedNodes.end()) return I->second;
-
   SDOperand Tmp1, Tmp2, Tmp3;
 
   SDOperand Result;
   SDNode *Node = Op.Val;
 
+  if (!Node->hasOneUse()) {
+    std::map<SDOperand, SDOperand>::iterator I = PromotedNodes.find(Op);
+    if (I != PromotedNodes.end()) return I->second;
+  } else {
+    assert(!PromotedNodes.count(Op) && "Repromoted this node??");
+  }
+
   // Promotion needs an optimization step to clean up after it, and is not
   // careful to avoid operations the target does not support.  Make sure that
   // all generated operations are legalized in the next iteration.
@@ -1923,15 +1939,15 @@ bool SelectionDAGLegalize::ExpandShift(unsigned Opc, SDOperand Op,SDOperand Amt,
   return true;
 }
 
-/// FindLatestAdjCallStackDown - Scan up the dag to find the latest (highest
-/// NodeDepth) node that is an AdjCallStackDown operation and occurs later than
+/// FindLatestCallSeqStart - Scan up the dag to find the latest (highest
+/// NodeDepth) node that is an CallSeqStart operation and occurs later than
 /// Found.
-static void FindLatestAdjCallStackDown(SDNode *Node, SDNode *&Found) {
+static void FindLatestCallSeqStart(SDNode *Node, SDNode *&Found) {
   if (Node->getNodeDepth() <= Found->getNodeDepth()) return;
 
-  // If we found an ADJCALLSTACKDOWN, we already know this node occurs later
+  // If we found an CALLSEQ_START, we already know this node occurs later
   // than the Found node. Just remember this node and return.
-  if (Node->getOpcode() == ISD::ADJCALLSTACKDOWN) {
+  if (Node->getOpcode() == ISD::CALLSEQ_START) {
     Found = Node;
     return;
   }
@@ -1940,23 +1956,23 @@ static void FindLatestAdjCallStackDown(SDNode *Node, SDNode *&Found) {
   assert(Node->getNumOperands() != 0 &&
          "All leaves should have depth equal to the entry node!");
   for (unsigned i = 0, e = Node->getNumOperands()-1; i != e; ++i)
-    FindLatestAdjCallStackDown(Node->getOperand(i).Val, Found);
+    FindLatestCallSeqStart(Node->getOperand(i).Val, Found);
 
   // Tail recurse for the last iteration.
-  FindLatestAdjCallStackDown(Node->getOperand(Node->getNumOperands()-1).Val,
+  FindLatestCallSeqStart(Node->getOperand(Node->getNumOperands()-1).Val,
                              Found);
 }
 
 
-/// FindEarliestAdjCallStackUp - Scan down the dag to find the earliest (lowest
-/// NodeDepth) node that is an AdjCallStackUp operation and occurs more recent
+/// FindEarliestCallSeqEnd - Scan down the dag to find the earliest (lowest
+/// NodeDepth) node that is an CallSeqEnd operation and occurs more recent
 /// than Found.
-static void FindEarliestAdjCallStackUp(SDNode *Node, SDNode *&Found) {
+static void FindEarliestCallSeqEnd(SDNode *Node, SDNode *&Found) {
   if (Found && Node->getNodeDepth() >= Found->getNodeDepth()) return;
 
-  // If we found an ADJCALLSTACKUP, we already know this node occurs earlier
+  // If we found an CALLSEQ_END, we already know this node occurs earlier
   // than the Found node. Just remember this node and return.
-  if (Node->getOpcode() == ISD::ADJCALLSTACKUP) {
+  if (Node->getOpcode() == ISD::CALLSEQ_END) {
     Found = Node;
     return;
   }
@@ -1965,49 +1981,50 @@ static void FindEarliestAdjCallStackUp(SDNode *Node, SDNode *&Found) {
   SDNode::use_iterator UI = Node->use_begin(), E = Node->use_end();
   if (UI == E) return;
   for (--E; UI != E; ++UI)
-    FindEarliestAdjCallStackUp(*UI, Found);
+    FindEarliestCallSeqEnd(*UI, Found);
 
   // Tail recurse for the last iteration.
-  FindEarliestAdjCallStackUp(*UI, Found);
+  FindEarliestCallSeqEnd(*UI, Found);
 }
 
-/// FindAdjCallStackUp - Given a chained node that is part of a call sequence,
-/// find the ADJCALLSTACKUP node that terminates the call sequence.
-static SDNode *FindAdjCallStackUp(SDNode *Node) {
-  if (Node->getOpcode() == ISD::ADJCALLSTACKUP)
+/// FindCallSeqEnd - Given a chained node that is part of a call sequence,
+/// find the CALLSEQ_END node that terminates the call sequence.
+static SDNode *FindCallSeqEnd(SDNode *Node) {
+  if (Node->getOpcode() == ISD::CALLSEQ_END)
     return Node;
   if (Node->use_empty())
-    return 0;   // No adjcallstackup
+    return 0;   // No CallSeqEnd
 
   if (Node->hasOneUse())  // Simple case, only has one user to check.
-    return FindAdjCallStackUp(*Node->use_begin());
+    return FindCallSeqEnd(*Node->use_begin());
 
   SDOperand TheChain(Node, Node->getNumValues()-1);
   assert(TheChain.getValueType() == MVT::Other && "Is not a token chain!");
 
   for (SDNode::use_iterator UI = Node->use_begin(),
          E = Node->use_end(); ; ++UI) {
-    assert(UI != E && "Didn't find a user of the tokchain, no ADJCALLSTACKUP!");
+    assert(UI != E && "Didn't find a user of the tokchain, no CALLSEQ_END!");
 
     // Make sure to only follow users of our token chain.
     SDNode *User = *UI;
     for (unsigned i = 0, e = User->getNumOperands(); i != e; ++i)
       if (User->getOperand(i) == TheChain)
-        return FindAdjCallStackUp(User);
+        if (SDNode *Result = FindCallSeqEnd(User))
+          return Result;
   }
   assert(0 && "Unreachable");
   abort();
 }
 
-/// FindAdjCallStackDown - Given a chained node that is part of a call sequence,
-/// find the ADJCALLSTACKDOWN node that initiates the call sequence.
-static SDNode *FindAdjCallStackDown(SDNode *Node) {
-  assert(Node && "Didn't find adjcallstackdown for a call??");
-  if (Node->getOpcode() == ISD::ADJCALLSTACKDOWN) return Node;
+/// FindCallSeqStart - Given a chained node that is part of a call sequence,
+/// find the CALLSEQ_START node that initiates the call sequence.
+static SDNode *FindCallSeqStart(SDNode *Node) {
+  assert(Node && "Didn't find callseq_start for a call??");
+  if (Node->getOpcode() == ISD::CALLSEQ_START) return Node;
 
   assert(Node->getOperand(0).getValueType() == MVT::Other &&
          "Node doesn't have a token chain argument!");
-  return FindAdjCallStackDown(Node->getOperand(0).Val);
+  return FindCallSeqStart(Node->getOperand(0).Val);
 }
 
 
@@ -2019,31 +2036,31 @@ static SDNode *FindAdjCallStackDown(SDNode *Node) {
 /// end of the call chain.
 static SDOperand FindInputOutputChains(SDNode *OpNode, SDNode *&OutChain,
                                        SDOperand Entry) {
-  SDNode *LatestAdjCallStackDown = Entry.Val;
-  SDNode *LatestAdjCallStackUp = 0;
-  FindLatestAdjCallStackDown(OpNode, LatestAdjCallStackDown);
-  //std::cerr<<"Found node: "; LatestAdjCallStackDown->dump(); std::cerr <<"\n";
+  SDNode *LatestCallSeqStart = Entry.Val;
+  SDNode *LatestCallSeqEnd = 0;
+  FindLatestCallSeqStart(OpNode, LatestCallSeqStart);
+  //std::cerr<<"Found node: "; LatestCallSeqStart->dump(); std::cerr <<"\n";
 
-  // It is possible that no ISD::ADJCALLSTACKDOWN was found because there is no
+  // It is possible that no ISD::CALLSEQ_START was found because there is no
   // previous call in the function.  LatestCallStackDown may in that case be
-  // the entry node itself.  Do not attempt to find a matching ADJCALLSTACKUP
-  // unless LatestCallStackDown is an ADJCALLSTACKDOWN.
-  if (LatestAdjCallStackDown->getOpcode() == ISD::ADJCALLSTACKDOWN)
-    LatestAdjCallStackUp = FindAdjCallStackUp(LatestAdjCallStackDown);
+  // the entry node itself.  Do not attempt to find a matching CALLSEQ_END
+  // unless LatestCallStackDown is an CALLSEQ_START.
+  if (LatestCallSeqStart->getOpcode() == ISD::CALLSEQ_START)
+    LatestCallSeqEnd = FindCallSeqEnd(LatestCallSeqStart);
   else
-    LatestAdjCallStackUp = Entry.Val;
-  assert(LatestAdjCallStackUp && "NULL return from FindAdjCallStackUp");
+    LatestCallSeqEnd = Entry.Val;
+  assert(LatestCallSeqEnd && "NULL return from FindCallSeqEnd");
 
   // Finally, find the first call that this must come before, first we find the
-  // adjcallstackup that ends the call.
+  // CallSeqEnd that ends the call.
   OutChain = 0;
-  FindEarliestAdjCallStackUp(OpNode, OutChain);
+  FindEarliestCallSeqEnd(OpNode, OutChain);
 
-  // If we found one, translate from the adj up to the adjdown.
+  // If we found one, translate from the adj up to the callseq_start.
   if (OutChain)
-    OutChain = FindAdjCallStackDown(OutChain);
+    OutChain = FindCallSeqStart(OutChain);
 
-  return SDOperand(LatestAdjCallStackUp, 0);
+  return SDOperand(LatestCallSeqEnd, 0);
 }
 
 /// SpliceCallInto - Given the result chain of a libcall (CallResult), and a 
@@ -2086,7 +2103,7 @@ SDOperand SelectionDAGLegalize::ExpandLibCall(const char *Name, SDNode *Node,
   // Splice the libcall in wherever FindInputOutputChains tells us to.
   const Type *RetTy = MVT::getTypeForValueType(Node->getValueType(0));
   std::pair<SDOperand,SDOperand> CallInfo =
-    TLI.LowerCallTo(InChain, RetTy, false, Callee, Args, DAG);
+    TLI.LowerCallTo(InChain, RetTy, false, 0, Callee, Args, DAG);
   SpliceCallInto(CallInfo.second, OutChain);
 
   NeedsAnotherIteration = true;
@@ -2115,10 +2132,6 @@ ExpandIntToFP(bool isSigned, MVT::ValueType DestTy, SDOperand Source) {
   assert(Source.getValueType() == MVT::i64 && "Only handle expand from i64!");
 
   if (!isSigned) {
-    // If this is unsigned, and not supported, first perform the conversion to
-    // signed, then adjust the result if the sign bit is set.
-    SDOperand SignedConv = ExpandIntToFP(true, DestTy, Source);
-
     assert(Source.getValueType() == MVT::i64 &&
            "This only works for 64-bit -> FP");
     // The 64-bit value loaded will be incorrectly if the 'sign bit' of the
@@ -2127,15 +2140,19 @@ ExpandIntToFP(bool isSigned, MVT::ValueType DestTy, SDOperand Source) {
     SDOperand Lo, Hi;
     ExpandOp(Source, Lo, Hi);
 
+    // If this is unsigned, and not supported, first perform the conversion to
+    // signed, then adjust the result if the sign bit is set.
+    SDOperand SignedConv = ExpandIntToFP(true, DestTy,
+                   DAG.getNode(ISD::BUILD_PAIR, Source.getValueType(), Lo, Hi));
+
     SDOperand SignSet = DAG.getSetCC(ISD::SETLT, TLI.getSetCCResultTy(), Hi,
                                      DAG.getConstant(0, Hi.getValueType()));
     SDOperand Zero = getIntPtrConstant(0), Four = getIntPtrConstant(4);
     SDOperand CstOffset = DAG.getNode(ISD::SELECT, Zero.getValueType(),
                                       SignSet, Four, Zero);
-    // FIXME: This is almost certainly broken for big-endian systems.  Should
-    // this just put the fudge factor in the low bits of the uint64 constant or?
-    static Constant *FudgeFactor =
-      ConstantUInt::get(Type::ULongTy, 0x5f800000ULL << 32);
+    uint64_t FF = 0x5f800000ULL;
+    if (TLI.isLittleEndian()) FF <<= 32;
+    static Constant *FudgeFactor = ConstantUInt::get(Type::ULongTy, FF);
 
     MachineConstantPool *CP = DAG.getMachineFunction().getConstantPool();
     SDOperand CPIdx = DAG.getConstantPool(CP->getConstantPoolIndex(FudgeFactor),
@@ -2153,6 +2170,12 @@ ExpandIntToFP(bool isSigned, MVT::ValueType DestTy, SDOperand Source) {
     return DAG.getNode(ISD::ADD, DestTy, SignedConv, FudgeInReg);
   }
 
+  // Expand the source, then glue it back together for the call.  We must expand
+  // the source in case it is shared (this pass of legalize must traverse it).
+  SDOperand SrcLo, SrcHi;
+  ExpandOp(Source, SrcLo, SrcHi);
+  Source = DAG.getNode(ISD::BUILD_PAIR, Source.getValueType(), SrcLo, SrcHi);
+
   SDNode *OutChain = 0;
   SDOperand InChain = FindInputOutputChains(Source.Val, OutChain,
                                             DAG.getEntryNode());
@@ -2169,12 +2192,6 @@ ExpandIntToFP(bool isSigned, MVT::ValueType DestTy, SDOperand Source) {
   TargetLowering::ArgListTy Args;
   const Type *ArgTy = MVT::getTypeForValueType(Source.getValueType());
 
-  // Expand the source, then glue it back together for the call.  We must expand
-  // the source in case it is shared (this pass of legalize must traverse it).
-  SDOperand SrcLo, SrcHi;
-  ExpandOp(Source, SrcLo, SrcHi);
-  Source = DAG.getNode(ISD::BUILD_PAIR, Source.getValueType(), SrcLo, SrcHi);
-
   Args.push_back(std::make_pair(Source, ArgTy));
 
   // We don't care about token chains for libcalls.  We just use the entry
@@ -2183,7 +2200,7 @@ ExpandIntToFP(bool isSigned, MVT::ValueType DestTy, SDOperand Source) {
   const Type *RetTy = MVT::getTypeForValueType(DestTy);
 
   std::pair<SDOperand,SDOperand> CallResult =
-    TLI.LowerCallTo(InChain, RetTy, false, Callee, Args, DAG);
+    TLI.LowerCallTo(InChain, RetTy, false, 0, Callee, Args, DAG);
 
   SpliceCallInto(CallResult.second, OutChain);
   return CallResult.first;
@@ -2216,6 +2233,8 @@ void SelectionDAGLegalize::ExpandOp(SDOperand Op, SDOperand &Lo, SDOperand &Hi){
       Hi = I->second.second;
       return;
     }
+  } else {
+    assert(!ExpandedNodes.count(Op) && "Re-expanding a node!");
   }
 
   // Expanding to multiple registers needs to perform an optimization step, and
@@ -2268,9 +2287,35 @@ void SelectionDAGLegalize::ExpandOp(SDOperand Op, SDOperand &Lo, SDOperand &Hi){
     Hi = DAG.getConstant(0, NVT);
     break;
 
-  case ISD::CTTZ:
-  case ISD::CTLZ:
-    assert(0 && "ct intrinsics cannot be expanded!");
+  case ISD::CTLZ: {
+    // ctlz (HL) -> ctlz(H) != 32 ? ctlz(H) : (ctlz(L)+32)
+    ExpandOp(Node->getOperand(0), Lo, Hi);
+    SDOperand BitsC = DAG.getConstant(MVT::getSizeInBits(NVT), NVT);
+    SDOperand HLZ = DAG.getNode(ISD::CTLZ, NVT, Hi);
+    SDOperand TopNotZero = DAG.getSetCC(ISD::SETNE, TLI.getSetCCResultTy(),
+                                        HLZ, BitsC);
+    SDOperand LowPart = DAG.getNode(ISD::CTLZ, NVT, Lo);
+    LowPart = DAG.getNode(ISD::ADD, NVT, LowPart, BitsC);
+
+    Lo = DAG.getNode(ISD::SELECT, NVT, TopNotZero, HLZ, LowPart);
+    Hi = DAG.getConstant(0, NVT);
+    break;
+  }
+
+  case ISD::CTTZ: {
+    // cttz (HL) -> cttz(L) != 32 ? cttz(L) : (cttz(H)+32)
+    ExpandOp(Node->getOperand(0), Lo, Hi);
+    SDOperand BitsC = DAG.getConstant(MVT::getSizeInBits(NVT), NVT);
+    SDOperand LTZ = DAG.getNode(ISD::CTTZ, NVT, Lo);
+    SDOperand BotNotZero = DAG.getSetCC(ISD::SETNE, TLI.getSetCCResultTy(),
+                                        LTZ, BitsC);
+    SDOperand HiPart = DAG.getNode(ISD::CTTZ, NVT, Hi);
+    HiPart = DAG.getNode(ISD::ADD, NVT, HiPart, BitsC);
+
+    Lo = DAG.getNode(ISD::SELECT, NVT, BotNotZero, LTZ, HiPart);
+    Hi = DAG.getConstant(0, NVT);
+    break;
+  }
 
   case ISD::LOAD: {
     SDOperand Ch = LegalizeOp(Node->getOperand(0));   // Legalize the chain.