Implement folding of expressions like 'uint cast (int* getelementptr (int*
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index 70984cd0a634daef58ae411267a1ed4026efe976..e42de2a7e3951aaca833322b83cb7ed7c6169b0e 100644 (file)
@@ -489,7 +489,7 @@ ConstRules &ConstRules::get(const Constant *V1, const Constant *V2) {
       isa<ConstantPointerRef>(V1) || isa<ConstantPointerRef>(V2))
     return EmptyR;
 
-  switch (V1->getType()->getPrimitiveID()) {
+  switch (V1->getType()->getTypeID()) {
   default: assert(0 && "Unknown value type for constant folding!");
   case Type::BoolTyID:    return BoolR;
   case Type::PointerTyID: return NullPointerR;
@@ -564,7 +564,7 @@ Constant *llvm::ConstantFoldCastInstruction(const Constant *V,
 
   ConstRules &Rules = ConstRules::get(V, V);
 
-  switch (DestTy->getPrimitiveID()) {
+  switch (DestTy->getTypeID()) {
   case Type::BoolTyID:    return Rules.castToBool(V);
   case Type::UByteTyID:   return Rules.castToUByte(V);
   case Type::SByteTyID:   return Rules.castToSByte(V);
@@ -971,6 +971,18 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
       assert(Ty != 0 && "Invalid indices for GEP!");
       return ConstantPointerNull::get(PointerType::get(Ty));
     }
+
+    if (IdxList.size() == 1) {
+      const Type *ElTy = cast<PointerType>(C->getType())->getElementType();
+      if (unsigned ElSize = ElTy->getPrimitiveSize()) {
+        // gep null, C is equal to C*sizeof(nullty).  If nullty is a known llvm
+        // type, we can statically fold this.
+        Constant *R = ConstantUInt::get(Type::UIntTy, ElSize);
+        R = ConstantExpr::getCast(R, IdxList[0]->getType());
+        R = ConstantExpr::getMul(R, IdxList[0]);
+        return ConstantExpr::getCast(R, C->getType());
+      }
+    }
   }
 
   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(const_cast<Constant*>(C))) {
@@ -993,11 +1005,14 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
         // Add the last index of the source with the first index of the new GEP.
         // Make sure to handle the case when they are actually different types.
         Constant *Combined = CE->getOperand(CE->getNumOperands()-1);
-        if (!IdxList[0]->isNullValue())   // Otherwise it must be an array
+        if (!IdxList[0]->isNullValue()) {  // Otherwise it must be an array
+          const Type *IdxTy = Combined->getType();
+          if (IdxTy != IdxList[0]->getType()) IdxTy = Type::LongTy;
           Combined = 
             ConstantExpr::get(Instruction::Add,
-                              ConstantExpr::getCast(IdxList[0], Type::LongTy),
-                              ConstantExpr::getCast(Combined, Type::LongTy));
+                              ConstantExpr::getCast(IdxList[0], IdxTy),
+                              ConstantExpr::getCast(Combined, IdxTy));
+        }
         
         NewIndices.push_back(Combined);
         NewIndices.insert(NewIndices.end(), IdxList.begin()+1, IdxList.end());