Refactor a bunch of includes so that TargetMachine.h doesn't have to include
[oota-llvm.git] / lib / CodeGen / SelectionDAG / TargetLowering.cpp
index 6019c8c41e44a5a8700e7598f89492a43380aa76..50d2dfb16e4ddbc55580d9df78ba839e43a1d6fe 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Target/TargetLowering.h"
+#include "llvm/Target/TargetData.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/MRegisterInfo.h"
 #include "llvm/DerivedTypes.h"
@@ -467,8 +468,14 @@ bool TargetLowering::SimplifyDemandedBits(SDOperand Op, uint64_t DemandedMask,
       HighBits <<= MVT::getSizeInBits(VT) - ShAmt;
       uint64_t TypeMask = MVT::getIntVTBitMask(VT);
       
-      if (SimplifyDemandedBits(Op.getOperand(0),
-                               (DemandedMask << ShAmt) & TypeMask,
+      uint64_t InDemandedMask = (DemandedMask << ShAmt) & TypeMask;
+
+      // If any of the demanded bits are produced by the sign extension, we also
+      // demand the input sign bit.
+      if (HighBits & DemandedMask)
+        InDemandedMask |= MVT::getIntVTSignBit(VT);
+      
+      if (SimplifyDemandedBits(Op.getOperand(0), InDemandedMask,
                                KnownZero, KnownOne, TLO, Depth+1))
         return true;
       assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
@@ -1010,19 +1017,7 @@ unsigned TargetLowering::ComputeNumSignBits(SDOperand Op, unsigned Depth) const{
     return 1;  // Limit search depth.
 
   switch (Op.getOpcode()) {
-  default: 
-    // Allow the target to implement this method for its nodes.
-    if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
-  case ISD::INTRINSIC_WO_CHAIN:
-  case ISD::INTRINSIC_W_CHAIN:
-  case ISD::INTRINSIC_VOID:
-      unsigned NumBits = ComputeNumSignBitsForTargetNode(Op, Depth);
-      if (NumBits > 1) return NumBits;
-    }
-    
-    // FIXME: Should use computemaskedbits to look at the top bits.
-    return 1;
-    
+  default: break;
   case ISD::AssertSext:
     Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(1))->getVT());
     return VTBits-Tmp+1;
@@ -1030,6 +1025,31 @@ unsigned TargetLowering::ComputeNumSignBits(SDOperand Op, unsigned Depth) const{
     Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(1))->getVT());
     return VTBits-Tmp;
 
+  case ISD::SEXTLOAD:    // '17' bits known
+    Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(3))->getVT());
+    return VTBits-Tmp+1;
+  case ISD::ZEXTLOAD:    // '16' bits known
+    Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(3))->getVT());
+    return VTBits-Tmp;
+    
+  case ISD::Constant: {
+    uint64_t Val = cast<ConstantSDNode>(Op)->getValue();
+    // If negative, invert the bits, then look at it.
+    if (Val & MVT::getIntVTSignBit(VT))
+      Val = ~Val;
+    
+    // Shift the bits so they are the leading bits in the int64_t.
+    Val <<= 64-VTBits;
+    
+    // Return # leading zeros.  We use 'min' here in case Val was zero before
+    // shifting.  We don't want to return '64' as for an i32 "0".
+    return std::min(VTBits, CountLeadingZeros_64(Val));
+  }
+    
+  case ISD::SIGN_EXTEND:
+    Tmp = VTBits-MVT::getSizeInBits(Op.getOperand(0).getValueType());
+    return ComputeNumSignBits(Op.getOperand(0), Depth+1) + Tmp;
+    
   case ISD::SIGN_EXTEND_INREG:
     // Max of the input and what this extends.
     Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(1))->getVT());
@@ -1046,31 +1066,146 @@ unsigned TargetLowering::ComputeNumSignBits(SDOperand Op, unsigned Depth) const{
       if (Tmp > VTBits) Tmp = VTBits;
     }
     return Tmp;
+  case ISD::SHL:
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+      // shl destroys sign bits.
+      Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+      if (C->getValue() >= VTBits ||      // Bad shift.
+          C->getValue() >= Tmp) break;    // Shifted all sign bits out.
+      return Tmp - C->getValue();
+    }
+    break;
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR:    // NOT is handled here.
+    // Logical binary ops preserve the number of sign bits.
+    Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    if (Tmp == 1) return 1;  // Early out.
+    Tmp2 = ComputeNumSignBits(Op.getOperand(1), Depth+1);
+    return std::min(Tmp, Tmp2);
+
+  case ISD::SELECT:
+    Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    if (Tmp == 1) return 1;  // Early out.
+    Tmp2 = ComputeNumSignBits(Op.getOperand(1), Depth+1);
+    return std::min(Tmp, Tmp2);
     
+  case ISD::SETCC:
+    // If setcc returns 0/-1, all bits are sign bits.
+    if (getSetCCResultContents() == ZeroOrNegativeOneSetCCResult)
+      return VTBits;
+    break;
+  case ISD::ROTL:
+  case ISD::ROTR:
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+      unsigned RotAmt = C->getValue() & (VTBits-1);
+      
+      // Handle rotate right by N like a rotate left by 32-N.
+      if (Op.getOpcode() == ISD::ROTR)
+        RotAmt = (VTBits-RotAmt) & (VTBits-1);
+
+      // If we aren't rotating out all of the known-in sign bits, return the
+      // number that are left.  This handles rotl(sext(x), 1) for example.
+      Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+      if (Tmp > RotAmt+1) return Tmp-RotAmt;
+    }
+    break;
   case ISD::ADD:
-  case ISD::SUB:
-    // Add and sub can have at most one carry bit.  Thus we know that the output
+    // Add can have at most one carry bit.  Thus we know that the output
     // is, at worst, one more bit than the inputs.
     Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
-    if (Tmp == 1) return 1;
+    if (Tmp == 1) return 1;  // Early out.
+      
+    // Special case decrementing a value (ADD X, -1):
+    if (ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(Op.getOperand(0)))
+      if (CRHS->isAllOnesValue()) {
+        uint64_t KnownZero, KnownOne;
+        uint64_t Mask = MVT::getIntVTBitMask(VT);
+        ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
+        
+        // If the input is known to be 0 or 1, the output is 0/-1, which is all
+        // sign bits set.
+        if ((KnownZero|1) == Mask)
+          return VTBits;
+        
+        // If we are subtracting one from a positive number, there is no carry
+        // out of the result.
+        if (KnownZero & MVT::getIntVTSignBit(VT))
+          return Tmp;
+      }
+      
+    Tmp2 = ComputeNumSignBits(Op.getOperand(1), Depth+1);
+    if (Tmp2 == 1) return 1;
+      return std::min(Tmp, Tmp2)-1;
+    break;
+    
+  case ISD::SUB:
     Tmp2 = ComputeNumSignBits(Op.getOperand(1), Depth+1);
     if (Tmp2 == 1) return 1;
-    return std::min(Tmp, Tmp2)-1;
+      
+    // Handle NEG.
+    if (ConstantSDNode *CLHS = dyn_cast<ConstantSDNode>(Op.getOperand(0)))
+      if (CLHS->getValue() == 0) {
+        uint64_t KnownZero, KnownOne;
+        uint64_t Mask = MVT::getIntVTBitMask(VT);
+        ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
+        // If the input is known to be 0 or 1, the output is 0/-1, which is all
+        // sign bits set.
+        if ((KnownZero|1) == Mask)
+          return VTBits;
+        
+        // If the input is known to be positive (the sign bit is known clear),
+        // the output of the NEG has the same number of sign bits as the input.
+        if (KnownZero & MVT::getIntVTSignBit(VT))
+          return Tmp2;
+        
+        // Otherwise, we treat this like a SUB.
+      }
     
-  //case ISD::ZEXTLOAD:   // 16 bits known
-  //case ISD::SEXTLOAD:   // 17 bits known
-  //case ISD::Constant:
-  //case ISD::SIGN_EXTEND:
-  //
+    // Sub can have at most one carry bit.  Thus we know that the output
+    // is, at worst, one more bit than the inputs.
+    Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    if (Tmp == 1) return 1;  // Early out.
+      return std::min(Tmp, Tmp2)-1;
+    break;
+  case ISD::TRUNCATE:
+    // FIXME: it's tricky to do anything useful for this, but it is an important
+    // case for targets like X86.
+    break;
+  }
+  
+  // Allow the target to implement this method for its nodes.
+  if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
+      Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN || 
+      Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
+      Op.getOpcode() == ISD::INTRINSIC_VOID) {
+    unsigned NumBits = ComputeNumSignBitsForTargetNode(Op, Depth);
+    if (NumBits > 1) return NumBits;
+  }
+  
+  // Finally, if we can prove that the top bits of the result are 0's or 1's,
+  // use this information.
+  uint64_t KnownZero, KnownOne;
+  uint64_t Mask = MVT::getIntVTBitMask(VT);
+  ComputeMaskedBits(Op, Mask, KnownZero, KnownOne, Depth);
+  
+  uint64_t SignBit = MVT::getIntVTSignBit(VT);
+  if (KnownZero & SignBit) {        // SignBit is 0
+    Mask = KnownZero;
+  } else if (KnownOne & SignBit) {  // SignBit is 1;
+    Mask = KnownOne;
+  } else {
+    // Nothing known.
+    return 1;
   }
   
-#if 0
-  // fold (sext_in_reg (setcc x)) -> setcc x iff (setcc x) == 0 or -1
-  if (N0.getOpcode() == ISD::SETCC &&
-      TLI.getSetCCResultContents() == 
-      TargetLowering::ZeroOrNegativeOneSetCCResult)
-    return N0;
-#endif
+  // Okay, we know that the sign bit in Mask is set.  Use CLZ to determine
+  // the number of identical bits in the top of the input value.
+  Mask ^= ~0ULL;
+  Mask <<= 64-VTBits;
+  // Return # leading zeros.  We use 'min' here in case Val was zero before
+  // shifting.  We don't want to return '64' as for an i32 "0".
+  return std::min(VTBits, CountLeadingZeros_64(Mask));
 }