Refactor a bunch of includes so that TargetMachine.h doesn't have to include
[oota-llvm.git] / lib / CodeGen / SelectionDAG / TargetLowering.cpp
index 7a4d269f799795b4d6bc854da48837cb1f41f562..50d2dfb16e4ddbc55580d9df78ba839e43a1d6fe 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Target/TargetLowering.h"
+#include "llvm/Target/TargetData.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/MRegisterInfo.h"
+#include "llvm/DerivedTypes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/MathExtras.h"
@@ -26,8 +28,8 @@ TargetLowering::TargetLowering(TargetMachine &tm)
   // All operations default to being supported.
   memset(OpActions, 0, sizeof(OpActions));
 
-  IsLittleEndian = TD.isLittleEndian();
-  ShiftAmountTy = SetCCResultTy = PointerTy = getValueType(TD.getIntPtrType());
+  IsLittleEndian = TD->isLittleEndian();
+  ShiftAmountTy = SetCCResultTy = PointerTy = getValueType(TD->getIntPtrType());
   ShiftAmtHandling = Undefined;
   memset(RegClassForVT, 0,MVT::LAST_VALUETYPE*sizeof(TargetRegisterClass*));
   memset(TargetDAGCombineArray, 0, 
@@ -124,6 +126,14 @@ void TargetLowering::computeRegisterProperties() {
   // Set MVT::Vector to always be Expanded
   SetValueTypeAction(MVT::Vector, Expand, *this, TransformToType, 
                      ValueTypeActions);
+  
+  // Loop over all of the legal vector value types, specifying an identity type
+  // transformation.
+  for (unsigned i = MVT::FIRST_VECTOR_VALUETYPE;
+       i <= MVT::LAST_VECTOR_VALUETYPE; ++i) {
+    if (isTypeLegal((MVT::ValueType)i))
+      TransformToType[i] = (MVT::ValueType)i;
+  }
 
   assert(isTypeLegal(MVT::f64) && "Target does not support FP?");
   TransformToType[MVT::f64] = MVT::f64;
@@ -133,6 +143,50 @@ const char *TargetLowering::getTargetNodeName(unsigned Opcode) const {
   return NULL;
 }
 
+/// getPackedTypeBreakdown - Packed types are broken down into some number of
+/// legal scalar types.  For example, <8 x float> maps to 2 MVT::v2f32 values
+/// with Altivec or SSE1, or 8 promoted MVT::f64 values with the X86 FP stack.
+///
+/// This method returns the number and type of the resultant breakdown.
+///
+unsigned TargetLowering::getPackedTypeBreakdown(const PackedType *PTy, 
+                                                MVT::ValueType &PTyElementVT,
+                                      MVT::ValueType &PTyLegalElementVT) const {
+  // Figure out the right, legal destination reg to copy into.
+  unsigned NumElts = PTy->getNumElements();
+  MVT::ValueType EltTy = getValueType(PTy->getElementType());
+  
+  unsigned NumVectorRegs = 1;
+  
+  // Divide the input until we get to a supported size.  This will always
+  // end with a scalar if the target doesn't support vectors.
+  while (NumElts > 1 && !isTypeLegal(getVectorType(EltTy, NumElts))) {
+    NumElts >>= 1;
+    NumVectorRegs <<= 1;
+  }
+  
+  MVT::ValueType VT;
+  if (NumElts == 1) {
+    VT = EltTy;
+  } else {
+    VT = getVectorType(EltTy, NumElts); 
+  }
+  PTyElementVT = VT;
+
+  MVT::ValueType DestVT = getTypeToTransformTo(VT);
+  PTyLegalElementVT = DestVT;
+  if (DestVT < VT) {
+    // Value is expanded, e.g. i64 -> i16.
+    return NumVectorRegs*(MVT::getSizeInBits(VT)/MVT::getSizeInBits(DestVT));
+  } else {
+    // Otherwise, promotion or legal types use the same number of registers as
+    // the vector decimated to the appropriate level.
+    return NumVectorRegs;
+  }
+  
+  return DestVT;
+}
+
 //===----------------------------------------------------------------------===//
 //  Optimization Methods
 //===----------------------------------------------------------------------===//
@@ -414,8 +468,14 @@ bool TargetLowering::SimplifyDemandedBits(SDOperand Op, uint64_t DemandedMask,
       HighBits <<= MVT::getSizeInBits(VT) - ShAmt;
       uint64_t TypeMask = MVT::getIntVTBitMask(VT);
       
-      if (SimplifyDemandedBits(Op.getOperand(0),
-                               (DemandedMask << ShAmt) & TypeMask,
+      uint64_t InDemandedMask = (DemandedMask << ShAmt) & TypeMask;
+
+      // If any of the demanded bits are produced by the sign extension, we also
+      // demand the input sign bit.
+      if (HighBits & DemandedMask)
+        InDemandedMask |= MVT::getIntVTSignBit(VT);
+      
+      if (SimplifyDemandedBits(Op.getOperand(0), InDemandedMask,
                                KnownZero, KnownOne, TLO, Depth+1))
         return true;
       assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
@@ -554,6 +614,48 @@ bool TargetLowering::SimplifyDemandedBits(SDOperand Op, uint64_t DemandedMask,
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
     break;
   }
+  case ISD::TRUNCATE: {
+    // Simplify the input, using demanded bit information, and compute the known
+    // zero/one bits live out.
+    if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask,
+                             KnownZero, KnownOne, TLO, Depth+1))
+      return true;
+    
+    // If the input is only used by this truncate, see if we can shrink it based
+    // on the known demanded bits.
+    if (Op.getOperand(0).Val->hasOneUse()) {
+      SDOperand In = Op.getOperand(0);
+      switch (In.getOpcode()) {
+      default: break;
+      case ISD::SRL:
+        // Shrink SRL by a constant if none of the high bits shifted in are
+        // demanded.
+        if (ConstantSDNode *ShAmt = dyn_cast<ConstantSDNode>(In.getOperand(1))){
+          uint64_t HighBits = MVT::getIntVTBitMask(In.getValueType());
+          HighBits &= ~MVT::getIntVTBitMask(Op.getValueType());
+          HighBits >>= ShAmt->getValue();
+          
+          if (ShAmt->getValue() < MVT::getSizeInBits(Op.getValueType()) &&
+              (DemandedMask & HighBits) == 0) {
+            // None of the shifted in bits are needed.  Add a truncate of the
+            // shift input, then shift it.
+            SDOperand NewTrunc = TLO.DAG.getNode(ISD::TRUNCATE, 
+                                                 Op.getValueType(), 
+                                                 In.getOperand(0));
+            return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL,Op.getValueType(),
+                                                   NewTrunc, In.getOperand(1)));
+          }
+        }
+        break;
+      }
+    }
+    
+    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
+    uint64_t OutMask = MVT::getIntVTBitMask(Op.getValueType());
+    KnownZero &= OutMask;
+    KnownOne &= OutMask;
+    break;
+  }
   case ISD::AssertZext: {
     MVT::ValueType VT = cast<VTSDNode>(Op.getOperand(1))->getVT();
     uint64_t InMask = MVT::getIntVTBitMask(VT);
@@ -565,27 +667,11 @@ bool TargetLowering::SimplifyDemandedBits(SDOperand Op, uint64_t DemandedMask,
     break;
   }
   case ISD::ADD:
-    if (ConstantSDNode *AA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask, KnownZero, 
-                               KnownOne, TLO, Depth+1))
-        return true;
-      // Compute the KnownOne/KnownZero masks for the constant, so we can set
-      // KnownZero appropriately if we're adding a constant that has all low
-      // bits cleared.
-      ComputeMaskedBits(Op.getOperand(1), 
-                        MVT::getIntVTBitMask(Op.getValueType()), 
-                        KnownZero2, KnownOne2, Depth+1);
-      
-      uint64_t KnownZeroOut = std::min(CountTrailingZeros_64(~KnownZero), 
-                                       CountTrailingZeros_64(~KnownZero2));
-      KnownZero = (1ULL << KnownZeroOut) - 1;
-      KnownOne = 0;
-    }
-    break;
   case ISD::SUB:
-    // Just use ComputeMaskedBits to compute output bits, there are no
-    // simplifications that can be done here, and sub always demands all input
-    // bits.
+  case ISD::INTRINSIC_WO_CHAIN:
+  case ISD::INTRINSIC_W_CHAIN:
+  case ISD::INTRINSIC_VOID:
+    // Just use ComputeMaskedBits to compute output bits.
     ComputeMaskedBits(Op, DemandedMask, KnownZero, KnownOne, Depth);
     break;
   }
@@ -827,6 +913,14 @@ void TargetLowering::ComputeMaskedBits(SDOperand Op, uint64_t Mask,
                       KnownZero, KnownOne, Depth+1);
     return;
   }
+  case ISD::TRUNCATE: {
+    ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
+    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
+    uint64_t OutMask = MVT::getIntVTBitMask(Op.getValueType());
+    KnownZero &= OutMask;
+    KnownOne &= OutMask;
+    break;
+  }
   case ISD::AssertZext: {
     MVT::ValueType VT = cast<VTSDNode>(Op.getOperand(1))->getVT();
     uint64_t InMask = MVT::getIntVTBitMask(VT);
@@ -843,7 +937,8 @@ void TargetLowering::ComputeMaskedBits(SDOperand Op, uint64_t Mask,
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); 
     
     // Output known-0 bits are known if clear or set in both the low clear bits
-    // common to both LHS & RHS;
+    // common to both LHS & RHS.  For example, 8+(X<<3) is known to have the
+    // low 3 bits clear.
     uint64_t KnownZeroOut = std::min(CountTrailingZeros_64(~KnownZero), 
                                      CountTrailingZeros_64(~KnownZero2));
     
@@ -879,8 +974,12 @@ void TargetLowering::ComputeMaskedBits(SDOperand Op, uint64_t Mask,
   }
   default:
     // Allow the target to implement this method for its nodes.
-    if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
+    if (Op.getOpcode() >= ISD::BUILTIN_OP_END) {
+  case ISD::INTRINSIC_WO_CHAIN:
+  case ISD::INTRINSIC_W_CHAIN:
+  case ISD::INTRINSIC_VOID:
       computeMaskedBitsForTargetNode(Op, Mask, KnownZero, KnownOne);
+    }
     return;
   }
 }
@@ -893,13 +992,239 @@ void TargetLowering::computeMaskedBitsForTargetNode(const SDOperand Op,
                                                     uint64_t &KnownZero, 
                                                     uint64_t &KnownOne,
                                                     unsigned Depth) const {
-  assert(Op.getOpcode() >= ISD::BUILTIN_OP_END &&
+  assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
+          Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
+          Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
+          Op.getOpcode() == ISD::INTRINSIC_VOID) &&
          "Should use MaskedValueIsZero if you don't know whether Op"
          " is a target node!");
   KnownZero = 0;
   KnownOne = 0;
 }
 
+/// ComputeNumSignBits - Return the number of times the sign bit of the
+/// register is replicated into the other bits.  We know that at least 1 bit
+/// is always equal to the sign bit (itself), but other cases can give us
+/// information.  For example, immediately after an "SRA X, 2", we know that
+/// the top 3 bits are all equal to each other, so we return 3.
+unsigned TargetLowering::ComputeNumSignBits(SDOperand Op, unsigned Depth) const{
+  MVT::ValueType VT = Op.getValueType();
+  assert(MVT::isInteger(VT) && "Invalid VT!");
+  unsigned VTBits = MVT::getSizeInBits(VT);
+  unsigned Tmp, Tmp2;
+  
+  if (Depth == 6)
+    return 1;  // Limit search depth.
+
+  switch (Op.getOpcode()) {
+  default: break;
+  case ISD::AssertSext:
+    Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(1))->getVT());
+    return VTBits-Tmp+1;
+  case ISD::AssertZext:
+    Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(1))->getVT());
+    return VTBits-Tmp;
+
+  case ISD::SEXTLOAD:    // '17' bits known
+    Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(3))->getVT());
+    return VTBits-Tmp+1;
+  case ISD::ZEXTLOAD:    // '16' bits known
+    Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(3))->getVT());
+    return VTBits-Tmp;
+    
+  case ISD::Constant: {
+    uint64_t Val = cast<ConstantSDNode>(Op)->getValue();
+    // If negative, invert the bits, then look at it.
+    if (Val & MVT::getIntVTSignBit(VT))
+      Val = ~Val;
+    
+    // Shift the bits so they are the leading bits in the int64_t.
+    Val <<= 64-VTBits;
+    
+    // Return # leading zeros.  We use 'min' here in case Val was zero before
+    // shifting.  We don't want to return '64' as for an i32 "0".
+    return std::min(VTBits, CountLeadingZeros_64(Val));
+  }
+    
+  case ISD::SIGN_EXTEND:
+    Tmp = VTBits-MVT::getSizeInBits(Op.getOperand(0).getValueType());
+    return ComputeNumSignBits(Op.getOperand(0), Depth+1) + Tmp;
+    
+  case ISD::SIGN_EXTEND_INREG:
+    // Max of the input and what this extends.
+    Tmp = MVT::getSizeInBits(cast<VTSDNode>(Op.getOperand(1))->getVT());
+    Tmp = VTBits-Tmp+1;
+    
+    Tmp2 = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    return std::max(Tmp, Tmp2);
+
+  case ISD::SRA:
+    Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    // SRA X, C   -> adds C sign bits.
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+      Tmp += C->getValue();
+      if (Tmp > VTBits) Tmp = VTBits;
+    }
+    return Tmp;
+  case ISD::SHL:
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+      // shl destroys sign bits.
+      Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+      if (C->getValue() >= VTBits ||      // Bad shift.
+          C->getValue() >= Tmp) break;    // Shifted all sign bits out.
+      return Tmp - C->getValue();
+    }
+    break;
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR:    // NOT is handled here.
+    // Logical binary ops preserve the number of sign bits.
+    Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    if (Tmp == 1) return 1;  // Early out.
+    Tmp2 = ComputeNumSignBits(Op.getOperand(1), Depth+1);
+    return std::min(Tmp, Tmp2);
+
+  case ISD::SELECT:
+    Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    if (Tmp == 1) return 1;  // Early out.
+    Tmp2 = ComputeNumSignBits(Op.getOperand(1), Depth+1);
+    return std::min(Tmp, Tmp2);
+    
+  case ISD::SETCC:
+    // If setcc returns 0/-1, all bits are sign bits.
+    if (getSetCCResultContents() == ZeroOrNegativeOneSetCCResult)
+      return VTBits;
+    break;
+  case ISD::ROTL:
+  case ISD::ROTR:
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+      unsigned RotAmt = C->getValue() & (VTBits-1);
+      
+      // Handle rotate right by N like a rotate left by 32-N.
+      if (Op.getOpcode() == ISD::ROTR)
+        RotAmt = (VTBits-RotAmt) & (VTBits-1);
+
+      // If we aren't rotating out all of the known-in sign bits, return the
+      // number that are left.  This handles rotl(sext(x), 1) for example.
+      Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+      if (Tmp > RotAmt+1) return Tmp-RotAmt;
+    }
+    break;
+  case ISD::ADD:
+    // Add can have at most one carry bit.  Thus we know that the output
+    // is, at worst, one more bit than the inputs.
+    Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    if (Tmp == 1) return 1;  // Early out.
+      
+    // Special case decrementing a value (ADD X, -1):
+    if (ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(Op.getOperand(0)))
+      if (CRHS->isAllOnesValue()) {
+        uint64_t KnownZero, KnownOne;
+        uint64_t Mask = MVT::getIntVTBitMask(VT);
+        ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
+        
+        // If the input is known to be 0 or 1, the output is 0/-1, which is all
+        // sign bits set.
+        if ((KnownZero|1) == Mask)
+          return VTBits;
+        
+        // If we are subtracting one from a positive number, there is no carry
+        // out of the result.
+        if (KnownZero & MVT::getIntVTSignBit(VT))
+          return Tmp;
+      }
+      
+    Tmp2 = ComputeNumSignBits(Op.getOperand(1), Depth+1);
+    if (Tmp2 == 1) return 1;
+      return std::min(Tmp, Tmp2)-1;
+    break;
+    
+  case ISD::SUB:
+    Tmp2 = ComputeNumSignBits(Op.getOperand(1), Depth+1);
+    if (Tmp2 == 1) return 1;
+      
+    // Handle NEG.
+    if (ConstantSDNode *CLHS = dyn_cast<ConstantSDNode>(Op.getOperand(0)))
+      if (CLHS->getValue() == 0) {
+        uint64_t KnownZero, KnownOne;
+        uint64_t Mask = MVT::getIntVTBitMask(VT);
+        ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
+        // If the input is known to be 0 or 1, the output is 0/-1, which is all
+        // sign bits set.
+        if ((KnownZero|1) == Mask)
+          return VTBits;
+        
+        // If the input is known to be positive (the sign bit is known clear),
+        // the output of the NEG has the same number of sign bits as the input.
+        if (KnownZero & MVT::getIntVTSignBit(VT))
+          return Tmp2;
+        
+        // Otherwise, we treat this like a SUB.
+      }
+    
+    // Sub can have at most one carry bit.  Thus we know that the output
+    // is, at worst, one more bit than the inputs.
+    Tmp = ComputeNumSignBits(Op.getOperand(0), Depth+1);
+    if (Tmp == 1) return 1;  // Early out.
+      return std::min(Tmp, Tmp2)-1;
+    break;
+  case ISD::TRUNCATE:
+    // FIXME: it's tricky to do anything useful for this, but it is an important
+    // case for targets like X86.
+    break;
+  }
+  
+  // Allow the target to implement this method for its nodes.
+  if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
+      Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN || 
+      Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
+      Op.getOpcode() == ISD::INTRINSIC_VOID) {
+    unsigned NumBits = ComputeNumSignBitsForTargetNode(Op, Depth);
+    if (NumBits > 1) return NumBits;
+  }
+  
+  // Finally, if we can prove that the top bits of the result are 0's or 1's,
+  // use this information.
+  uint64_t KnownZero, KnownOne;
+  uint64_t Mask = MVT::getIntVTBitMask(VT);
+  ComputeMaskedBits(Op, Mask, KnownZero, KnownOne, Depth);
+  
+  uint64_t SignBit = MVT::getIntVTSignBit(VT);
+  if (KnownZero & SignBit) {        // SignBit is 0
+    Mask = KnownZero;
+  } else if (KnownOne & SignBit) {  // SignBit is 1;
+    Mask = KnownOne;
+  } else {
+    // Nothing known.
+    return 1;
+  }
+  
+  // Okay, we know that the sign bit in Mask is set.  Use CLZ to determine
+  // the number of identical bits in the top of the input value.
+  Mask ^= ~0ULL;
+  Mask <<= 64-VTBits;
+  // Return # leading zeros.  We use 'min' here in case Val was zero before
+  // shifting.  We don't want to return '64' as for an i32 "0".
+  return std::min(VTBits, CountLeadingZeros_64(Mask));
+}
+
+
+
+/// ComputeNumSignBitsForTargetNode - This method can be implemented by
+/// targets that want to expose additional information about sign bits to the
+/// DAG Combiner.
+unsigned TargetLowering::ComputeNumSignBitsForTargetNode(SDOperand Op,
+                                                         unsigned Depth) const {
+  assert((Op.getOpcode() >= ISD::BUILTIN_OP_END ||
+          Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
+          Op.getOpcode() == ISD::INTRINSIC_W_CHAIN ||
+          Op.getOpcode() == ISD::INTRINSIC_VOID) &&
+         "Should use ComputeNumSignBits if you don't know whether Op"
+         " is a target node!");
+  return 1;
+}
+
+
 SDOperand TargetLowering::
 PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const {
   // Default implementation: no optimization.
@@ -992,3 +1317,16 @@ getRegForInlineAsmConstraint(const std::string &Constraint,
   
   return std::pair<unsigned, const TargetRegisterClass*>(0, 0);
 }
+
+//===----------------------------------------------------------------------===//
+//  Loop Strength Reduction hooks
+//===----------------------------------------------------------------------===//
+
+/// isLegalAddressImmediate - Return true if the integer value or
+/// GlobalValue can be used as the offset of the target addressing mode.
+bool TargetLowering::isLegalAddressImmediate(int64_t V) const {
+  return false;
+}
+bool TargetLowering::isLegalAddressImmediate(GlobalValue *GV) const {
+  return false;
+}