Convert to SymbolTable's new iteration interface.
[oota-llvm.git] / lib / VMCore / Constants.cpp
index dac23aea55c07ea8b021fbb2bf13438c81535fea..a393be98ae3899434f3f245bf92aa13da27e999b 100644 (file)
@@ -267,14 +267,26 @@ ConstantStruct::ConstantStruct(const StructType *T,
 
 ConstantPointerRef::ConstantPointerRef(GlobalValue *GV)
   : Constant(GV->getType()) {
+  Operands.reserve(1);
   Operands.push_back(Use(GV, this));
 }
 
 ConstantExpr::ConstantExpr(unsigned Opcode, Constant *C, const Type *Ty)
   : Constant(Ty), iType(Opcode) {
+  Operands.reserve(1);
   Operands.push_back(Use(C, this));
 }
 
+// Select instruction creation ctor
+ConstantExpr::ConstantExpr(Constant *C, Constant *V1, Constant *V2)
+  : Constant(V1->getType()), iType(Instruction::Select) {
+  Operands.reserve(3);
+  Operands.push_back(Use(C, this));
+  Operands.push_back(Use(V1, this));
+  Operands.push_back(Use(V2, this));
+}
+
+
 static bool isSetCC(unsigned Opcode) {
   return Opcode == Instruction::SetEQ || Opcode == Instruction::SetNE ||
          Opcode == Instruction::SetLT || Opcode == Instruction::SetGT ||
@@ -283,6 +295,7 @@ static bool isSetCC(unsigned Opcode) {
 
 ConstantExpr::ConstantExpr(unsigned Opcode, Constant *C1, Constant *C2)
   : Constant(isSetCC(Opcode) ? Type::BoolTy : C1->getType()), iType(Opcode) {
+  Operands.reserve(2);
   Operands.push_back(Use(C1, this));
   Operands.push_back(Use(C2, this));
 }
@@ -296,6 +309,80 @@ ConstantExpr::ConstantExpr(Constant *C, const std::vector<Constant*> &IdxList,
     Operands.push_back(Use(IdxList[i], this));
 }
 
+/// ConstantExpr::get* - Return some common constants without having to
+/// specify the full Instruction::OPCODE identifier.
+///
+Constant *ConstantExpr::getNeg(Constant *C) {
+  if (!C->getType()->isFloatingPoint())
+    return get(Instruction::Sub, getNullValue(C->getType()), C);
+  else
+    return get(Instruction::Sub, ConstantFP::get(C->getType(), -0.0), C);
+}
+Constant *ConstantExpr::getNot(Constant *C) {
+  assert(isa<ConstantIntegral>(C) && "Cannot NOT a nonintegral type!");
+  return get(Instruction::Xor, C,
+             ConstantIntegral::getAllOnesValue(C->getType()));
+}
+Constant *ConstantExpr::getAdd(Constant *C1, Constant *C2) {
+  return get(Instruction::Add, C1, C2);
+}
+Constant *ConstantExpr::getSub(Constant *C1, Constant *C2) {
+  return get(Instruction::Sub, C1, C2);
+}
+Constant *ConstantExpr::getMul(Constant *C1, Constant *C2) {
+  return get(Instruction::Mul, C1, C2);
+}
+Constant *ConstantExpr::getDiv(Constant *C1, Constant *C2) {
+  return get(Instruction::Div, C1, C2);
+}
+Constant *ConstantExpr::getRem(Constant *C1, Constant *C2) {
+  return get(Instruction::Rem, C1, C2);
+}
+Constant *ConstantExpr::getAnd(Constant *C1, Constant *C2) {
+  return get(Instruction::And, C1, C2);
+}
+Constant *ConstantExpr::getOr(Constant *C1, Constant *C2) {
+  return get(Instruction::Or, C1, C2);
+}
+Constant *ConstantExpr::getXor(Constant *C1, Constant *C2) {
+  return get(Instruction::Xor, C1, C2);
+}
+Constant *ConstantExpr::getSetEQ(Constant *C1, Constant *C2) {
+  return get(Instruction::SetEQ, C1, C2);
+}
+Constant *ConstantExpr::getSetNE(Constant *C1, Constant *C2) {
+  return get(Instruction::SetNE, C1, C2);
+}
+Constant *ConstantExpr::getSetLT(Constant *C1, Constant *C2) {
+  return get(Instruction::SetLT, C1, C2);
+}
+Constant *ConstantExpr::getSetGT(Constant *C1, Constant *C2) {
+  return get(Instruction::SetGT, C1, C2);
+}
+Constant *ConstantExpr::getSetLE(Constant *C1, Constant *C2) {
+  return get(Instruction::SetLE, C1, C2);
+}
+Constant *ConstantExpr::getSetGE(Constant *C1, Constant *C2) {
+  return get(Instruction::SetGE, C1, C2);
+}
+Constant *ConstantExpr::getShl(Constant *C1, Constant *C2) {
+  return get(Instruction::Shl, C1, C2);
+}
+Constant *ConstantExpr::getShr(Constant *C1, Constant *C2) {
+  return get(Instruction::Shr, C1, C2);
+}
+
+Constant *ConstantExpr::getUShr(Constant *C1, Constant *C2) {
+  if (C1->getType()->isUnsigned()) return getShr(C1, C2);
+  return getCast(getShr(getCast(C1,
+                    C1->getType()->getUnsignedVersion()), C2), C1->getType());
+}
+
+Constant *ConstantExpr::getSShr(Constant *C1, Constant *C2) {
+  if (C1->getType()->isSigned()) return getShr(C1, C2);
+  return getCast(getShr(getCast(C1,
+                        C1->getType()->getSignedVersion()), C2), C1->getType());
+}
 
 
 //===----------------------------------------------------------------------===//
@@ -494,6 +581,14 @@ void ConstantExpr::replaceUsesOfWithOnConstant(Value *From, Value *ToV,
   } else if (getOpcode() == Instruction::Cast) {
     assert(getOperand(0) == From && "Cast only has one use!");
     Replacement = ConstantExpr::getCast(To, getType());
+  } else if (getOpcode() == Instruction::Select) {
+    Constant *C1 = getOperand(0);
+    Constant *C2 = getOperand(1);
+    Constant *C3 = getOperand(2);
+    if (C1 == From) C1 = To;
+    if (C2 == From) C2 = To;
+    if (C3 == From) C3 = To;
+    Replacement = ConstantExpr::getSelect(C1, C2, C3);
   } else if (getNumOperands() == 2) {
     Constant *C1 = getOperand(0);
     Constant *C2 = getOperand(1);
@@ -981,6 +1076,8 @@ namespace llvm {
            V.first < Instruction::BinaryOpsEnd) ||
           V.first == Instruction::Shl || V.first == Instruction::Shr)
         return new ConstantExpr(V.first, V.second[0], V.second[1]);
+      if (V.first == Instruction::Select)
+        return new ConstantExpr(V.second[0], V.second[1], V.second[2]);
       
       assert(V.first == Instruction::GetElementPtr && "Invalid ConstantExpr!");
       
@@ -997,6 +1094,11 @@ namespace llvm {
       case Instruction::Cast:
         New = ConstantExpr::getCast(OldC->getOperand(0), NewTy);
         break;
+      case Instruction::Select:
+        New = ConstantExpr::getSelectTy(NewTy, OldC->getOperand(0),
+                                        OldC->getOperand(1),
+                                        OldC->getOperand(2));
+        break;
       case Instruction::Shl:
       case Instruction::Shr:
         New = ConstantExpr::getShiftTy(NewTy, OldC->getOpcode(),
@@ -1039,6 +1141,22 @@ Constant *ConstantExpr::getCast(Constant *C, const Type *Ty) {
   return ExprConstants.getOrCreate(Ty, Key);
 }
 
+Constant *ConstantExpr::getSignExtend(Constant *C, const Type *Ty) {
+  assert(C->getType()->isInteger() && Ty->isInteger() &&
+         C->getType()->getPrimitiveSize() <= Ty->getPrimitiveSize() &&
+         "This is an illegal sign extension!");
+  C = ConstantExpr::getCast(C, C->getType()->getSignedVersion());
+  return ConstantExpr::getCast(C, Ty);
+}
+
+Constant *ConstantExpr::getZeroExtend(Constant *C, const Type *Ty) {
+  assert(C->getType()->isInteger() && Ty->isInteger() &&
+         C->getType()->getPrimitiveSize() <= Ty->getPrimitiveSize() &&
+         "This is an illegal zero extension!");
+  C = ConstantExpr::getCast(C, C->getType()->getUnsignedVersion());
+  return ConstantExpr::getCast(C, Ty);
+}
+
 Constant *ConstantExpr::getTy(const Type *ReqTy, unsigned Opcode,
                               Constant *C1, Constant *C2) {
   if (Opcode == Instruction::Shl || Opcode == Instruction::Shr)
@@ -1059,6 +1177,23 @@ Constant *ConstantExpr::getTy(const Type *ReqTy, unsigned Opcode,
   return ExprConstants.getOrCreate(ReqTy, Key);
 }
 
+Constant *ConstantExpr::getSelectTy(const Type *ReqTy, Constant *C,
+                                    Constant *V1, Constant *V2) {
+  assert(C->getType() == Type::BoolTy && "Select condition must be bool!");
+  assert(V1->getType() == V2->getType() && "Select value types must match!");
+  assert(V1->getType()->isFirstClassType() && "Cannot select aggregate type!");
+
+  if (ReqTy == V1->getType())
+    if (Constant *SC = ConstantFoldSelectInstruction(C, V1, V2))
+      return SC;        // Fold common cases
+
+  std::vector<Constant*> argVec(3, C);
+  argVec[1] = V1;
+  argVec[2] = V2;
+  ExprMapKeyType Key = std::make_pair(Instruction::Select, argVec);
+  return ExprConstants.getOrCreate(ReqTy, Key);
+}
+
 /// getShiftTy - Return a shift left or shift right constant expr
 Constant *ConstantExpr::getShiftTy(const Type *ReqTy, unsigned Opcode,
                                    Constant *C1, Constant *C2) {
@@ -1081,11 +1216,15 @@ Constant *ConstantExpr::getShiftTy(const Type *ReqTy, unsigned Opcode,
 
 Constant *ConstantExpr::getGetElementPtrTy(const Type *ReqTy, Constant *C,
                                         const std::vector<Constant*> &IdxList) {
+  assert(GetElementPtrInst::getIndexedType(C->getType(),
+                   std::vector<Value*>(IdxList.begin(), IdxList.end()), true) &&
+         "GEP indices invalid!");
+
   if (Constant *FC = ConstantFoldGetElementPtr(C, IdxList))
     return FC;          // Fold a few common cases...
+
   assert(isa<PointerType>(C->getType()) &&
          "Non-pointer type for constant GetElementPtr expression");
-
   // Look up the constant in the table first to ensure uniqueness
   std::vector<Constant*> argVec(1, C);
   argVec.insert(argVec.end(), IdxList.begin(), IdxList.end());
@@ -1101,17 +1240,6 @@ Constant *ConstantExpr::getGetElementPtr(Constant *C,
   const Type *Ty = GetElementPtrInst::getIndexedType(C->getType(), VIdxList,
                                                      true);
   assert(Ty && "GEP indices invalid!");
-
-  if (C->isNullValue()) {
-    bool isNull = true;
-    for (unsigned i = 0, e = IdxList.size(); i != e; ++i)
-      if (!IdxList[i]->isNullValue()) {
-        isNull = false;
-        break;
-      }
-    if (isNull) return ConstantPointerNull::get(PointerType::get(Ty));
-  }
-
   return getGetElementPtrTy(PointerType::get(Ty), C, IdxList);
 }
 
@@ -1126,26 +1254,3 @@ void ConstantExpr::destroyConstant() {
 const char *ConstantExpr::getOpcodeName() const {
   return Instruction::getOpcodeName(getOpcode());
 }
-
-unsigned Constant::mutateReferences(Value *OldV, Value *NewV) {
-  // Uses of constant pointer refs are global values, not constants!
-  if (ConstantPointerRef *CPR = dyn_cast<ConstantPointerRef>(this)) {
-    GlobalValue *NewGV = cast<GlobalValue>(NewV);
-    GlobalValue *OldGV = CPR->getValue();
-
-    assert(OldGV == OldV && "Cannot mutate old value if I'm not using it!");
-    Operands[0] = NewGV;
-    OldGV->getParent()->mutateConstantPointerRef(OldGV, NewGV);
-    return 1;
-  } else {
-    Constant *NewC = cast<Constant>(NewV);
-    unsigned NumReplaced = 0;
-    for (unsigned i = 0, N = getNumOperands(); i != N; ++i)
-      if (Operands[i] == OldV) {
-        ++NumReplaced;
-        Operands[i] = NewC;
-      }
-    return NumReplaced;
-  }
-}
-