Add a ConstantDataVector::getSplatValue() method, for parity with
authorChris Lattner <sabre@nondot.org>
Thu, 26 Jan 2012 02:31:22 +0000 (02:31 +0000)
committerChris Lattner <sabre@nondot.org>
Thu, 26 Jan 2012 02:31:22 +0000 (02:31 +0000)
ConstantVector.  Fix some outright bugs in the implementation of
ConstantArray and Constant struct, which would cause us to not make
one big UndefValue when asking for an array/struct with all undef
elements.  Enhance Constant::isAllOnesValue to work with
ConstantDataVector.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@149021 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Constants.h
lib/VMCore/Constants.cpp

index a685ccc31c6790d26ce977f502fa166ffb3e50b9..1180b955f2719c5b565979e96a6f1a6a9921e38c 100644 (file)
@@ -766,6 +766,10 @@ public:
   /// i32/i64/float/double) and must be a ConstantFP or ConstantInt.
   static Constant *getSplat(unsigned NumElts, Constant *Elt);
 
+  /// getSplatValue - If this is a splat constant, meaning that all of the
+  /// elements have the same value, return that value. Otherwise return NULL.
+  Constant *getSplatValue() const;
+  
   /// getType - Specialize the getType() method to always return a VectorType,
   /// which reduces the amount of casting needed in parts of the compiler.
   ///
index c1d6da5465c31efa3ff0c74ae4e07e370f2a78e1..f5239e15697944f571c0e7ccf4c511ba38f24ed7 100644 (file)
@@ -78,6 +78,11 @@ bool Constant::isAllOnesValue() const {
     if (Constant *Splat = CV->getSplatValue())
       return Splat->isAllOnesValue();
 
+  // Check for constant vectors which are splats of -1 values.
+  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
+    if (Constant *Splat = CV->getSplatValue())
+      return Splat->isAllOnesValue();
+
   return false;
 }
 
@@ -711,17 +716,27 @@ Constant *ConstantArray::get(ArrayType *Ty, ArrayRef<Constant*> V) {
   }
   LLVMContextImpl *pImpl = Ty->getContext().pImpl;
   // If this is an all-zero array, return a ConstantAggregateZero object
+  bool isAllZero = true;
+  bool isUndef = false;
   if (!V.empty()) {
     Constant *C = V[0];
-    if (!C->isNullValue())
-      return pImpl->ArrayConstants.getOrCreate(Ty, V);
-    
-    for (unsigned i = 1, e = V.size(); i != e; ++i)
-      if (V[i] != C)
-        return pImpl->ArrayConstants.getOrCreate(Ty, V);
+    isAllZero = C->isNullValue();
+    isUndef = isa<UndefValue>(C);
+
+    if (isAllZero || isUndef)
+      for (unsigned i = 1, e = V.size(); i != e; ++i)
+        if (V[i] != C) {
+          isAllZero = false;
+          isUndef = false;
+          break;
+        }
   }
-  
-  return ConstantAggregateZero::get(Ty);
+
+  if (isAllZero)
+    return ConstantAggregateZero::get(Ty);
+  if (isUndef)
+    return UndefValue::get(Ty);
+  return pImpl->ArrayConstants.getOrCreate(Ty, V);
 }
 
 /// ConstantArray::get(const string&) - Return an array that is initialized to
@@ -780,14 +795,31 @@ ConstantStruct::ConstantStruct(StructType *T, ArrayRef<Constant *> V)
 
 // ConstantStruct accessors.
 Constant *ConstantStruct::get(StructType *ST, ArrayRef<Constant*> V) {
-  // Create a ConstantAggregateZero value if all elements are zeros.
-  for (unsigned i = 0, e = V.size(); i != e; ++i)
-    if (!V[i]->isNullValue())
-      return ST->getContext().pImpl->StructConstants.getOrCreate(ST, V);
-
   assert((ST->isOpaque() || ST->getNumElements() == V.size()) &&
          "Incorrect # elements specified to ConstantStruct::get");
-  return ConstantAggregateZero::get(ST);
+
+  // Create a ConstantAggregateZero value if all elements are zeros.
+  bool isZero = true;
+  bool isUndef = false;
+  
+  if (!V.empty()) {
+    isUndef = isa<UndefValue>(V[0]);
+    isZero = V[0]->isNullValue();
+    if (isUndef || isZero) {
+      for (unsigned i = 0, e = V.size(); i != e; ++i) {
+        if (!V[i]->isNullValue())
+          isZero = false;
+        if (!isa<UndefValue>(V[i]))
+          isUndef = false;
+      }
+    }
+  }  
+  if (isZero)
+    return ConstantAggregateZero::get(ST);
+  if (isUndef)
+    return UndefValue::get(ST);
+    
+  return ST->getContext().pImpl->StructConstants.getOrCreate(ST, V);
 }
 
 Constant *ConstantStruct::get(StructType *T, ...) {
@@ -2329,6 +2361,20 @@ bool ConstantDataSequential::isCString() const {
   return Str.drop_back().find(0) == StringRef::npos;
 }
 
+/// getSplatValue - If this is a splat constant, meaning that all of the
+/// elements have the same value, return that value. Otherwise return NULL.
+Constant *ConstantDataVector::getSplatValue() const {
+  const char *Base = getRawDataValues().data();
+  
+  // Compare elements 1+ to the 0'th element.
+  unsigned EltSize = getElementByteSize();
+  for (unsigned i = 1, e = getNumElements(); i != e; ++i)
+    if (memcmp(Base, Base+i*EltSize, EltSize))
+      return 0;
+  
+  // If they're all the same, return the 0th one as a representative.
+  return getElementAsConstant(0);
+}
 
 //===----------------------------------------------------------------------===//
 //                replaceUsesOfWithOnConstant implementations
@@ -2360,33 +2406,25 @@ void ConstantArray::replaceUsesOfWithOnConstant(Value *From, Value *To,
 
   // Fill values with the modified operands of the constant array.  Also, 
   // compute whether this turns into an all-zeros array.
-  bool isAllZeros = false;
   unsigned NumUpdated = 0;
-  if (!ToC->isNullValue()) {
-    for (Use *O = OperandList, *E = OperandList+getNumOperands(); O != E; ++O) {
-      Constant *Val = cast<Constant>(O->get());
-      if (Val == From) {
-        Val = ToC;
-        ++NumUpdated;
-      }
-      Values.push_back(Val);
-    }
-  } else {
-    isAllZeros = true;
-    for (Use *O = OperandList, *E = OperandList+getNumOperands();O != E; ++O) {
-      Constant *Val = cast<Constant>(O->get());
-      if (Val == From) {
-        Val = ToC;
-        ++NumUpdated;
-      }
-      Values.push_back(Val);
-      if (isAllZeros) isAllZeros = Val->isNullValue();
+  
+  // Keep track of whether all the values in the array are "ToC".
+  bool AllSame = true;
+  for (Use *O = OperandList, *E = OperandList+getNumOperands(); O != E; ++O) {
+    Constant *Val = cast<Constant>(O->get());
+    if (Val == From) {
+      Val = ToC;
+      ++NumUpdated;
     }
+    Values.push_back(Val);
+    AllSame = Val == ToC;
   }
   
   Constant *Replacement = 0;
-  if (isAllZeros) {
+  if (AllSame && ToC->isNullValue()) {
     Replacement = ConstantAggregateZero::get(getType());
+  } else if (AllSame && isa<UndefValue>(ToC)) {
+    Replacement = UndefValue::get(getType());
   } else {
     // Check to see if we have this array type already.
     bool Exists;
@@ -2446,16 +2484,24 @@ void ConstantStruct::replaceUsesOfWithOnConstant(Value *From, Value *To,
   // Fill values with the modified operands of the constant struct.  Also, 
   // compute whether this turns into an all-zeros struct.
   bool isAllZeros = false;
-  if (!ToC->isNullValue()) {
-    for (Use *O = OperandList, *E = OperandList + getNumOperands(); O != E; ++O)
-      Values.push_back(cast<Constant>(O->get()));
-  } else {
+  bool isAllUndef = false;
+  if (ToC->isNullValue()) {
     isAllZeros = true;
     for (Use *O = OperandList, *E = OperandList+getNumOperands(); O != E; ++O) {
       Constant *Val = cast<Constant>(O->get());
       Values.push_back(Val);
       if (isAllZeros) isAllZeros = Val->isNullValue();
     }
+  } else if (isa<UndefValue>(ToC)) {
+    isAllUndef = true;
+    for (Use *O = OperandList, *E = OperandList+getNumOperands(); O != E; ++O) {
+      Constant *Val = cast<Constant>(O->get());
+      Values.push_back(Val);
+      if (isAllUndef) isAllUndef = isa<UndefValue>(Val);
+    }
+  } else {
+    for (Use *O = OperandList, *E = OperandList + getNumOperands(); O != E; ++O)
+      Values.push_back(cast<Constant>(O->get()));
   }
   Values[OperandToUpdate] = ToC;
   
@@ -2464,6 +2510,8 @@ void ConstantStruct::replaceUsesOfWithOnConstant(Value *From, Value *To,
   Constant *Replacement = 0;
   if (isAllZeros) {
     Replacement = ConstantAggregateZero::get(getType());
+  } else if (isAllUndef) {
+    Replacement = UndefValue::get(getType());
   } else {
     // Check to see if we have this struct type already.
     bool Exists;