X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FVMCore%2FConstantFold.cpp;h=a366970fae821d321383d6c8f7deee438e632ca3;hb=9e77f77687bdeece2a66ed9103379f6da3bbc46e;hp=54a79804818262cc53f0264a315fe9d3a5de1229;hpb=7e02b7e600ce8b719b34e2df7a7e44310229564d;p=oota-llvm.git diff --git a/lib/VMCore/ConstantFold.cpp b/lib/VMCore/ConstantFold.cpp index 54a79804818..a366970fae8 100644 --- a/lib/VMCore/ConstantFold.cpp +++ b/lib/VMCore/ConstantFold.cpp @@ -4,9 +4,108 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Optimizations/ConstantHandling.h" +#include "llvm/ConstantHandling.h" +#include "llvm/iPHINode.h" +#include + +AnnotationID ConstRules::AID(AnnotationManager::getID("opt::ConstRules", + &ConstRules::find)); + +// ConstantFoldInstruction - Attempt to constant fold the specified instruction. +// If successful, the constant result is returned, if not, null is returned. +// +Constant *ConstantFoldInstruction(Instruction *I) { + if (PHINode *PN = dyn_cast(I)) { + if (PN->getNumIncomingValues() == 0) + return Constant::getNullValue(PN->getType()); + + Constant *Result = dyn_cast(PN->getIncomingValue(0)); + if (Result == 0) return 0; + + // Handle PHI nodes specially here... + for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) != Result) + return 0; // Not all the same incoming constants... + + // If we reach here, all incoming values are the same constant. + return Result; + } + + Constant *Op0 = 0; + Constant *Op1 = 0; + + if (I->getNumOperands() != 0) { // Get first operand if it's a constant... + Op0 = dyn_cast(I->getOperand(0)); + if (Op0 == 0) return 0; // Not a constant?, can't fold + + if (I->getNumOperands() != 1) { // Get second operand if it's a constant... + Op1 = dyn_cast(I->getOperand(1)); + if (Op1 == 0) return 0; // Not a constant?, can't fold + } + } + + switch (I->getOpcode()) { + case Instruction::Cast: + return ConstRules::get(*Op0)->castTo(Op0, I->getType()); + case Instruction::Not: return ~*Op0; + case Instruction::Add: return *Op0 + *Op1; + case Instruction::Sub: return *Op0 - *Op1; + case Instruction::Mul: return *Op0 * *Op1; + case Instruction::Div: return *Op0 / *Op1; + case Instruction::Rem: return *Op0 % *Op1; + + case Instruction::SetEQ: return *Op0 == *Op1; + case Instruction::SetNE: return *Op0 != *Op1; + case Instruction::SetLE: return *Op0 <= *Op1; + case Instruction::SetGE: return *Op0 >= *Op1; + case Instruction::SetLT: return *Op0 < *Op1; + case Instruction::SetGT: return *Op0 > *Op1; + case Instruction::Shl: return *Op0 << *Op1; + case Instruction::Shr: return *Op0 >> *Op1; + default: + return 0; + } +} + +Constant *ConstantFoldCastInstruction(const Constant *V, const Type *DestTy) { + return ConstRules::get(*V)->castTo(V, DestTy); +} + +Constant *ConstantFoldUnaryInstruction(unsigned Opcode, const Constant *V) { + switch (Opcode) { + case Instruction::Not: return ~*V; + } + return 0; +} + +Constant *ConstantFoldBinaryInstruction(unsigned Opcode, const Constant *V1, + const Constant *V2) { + switch (Opcode) { + case Instruction::Add: return *V1 + *V2; + case Instruction::Sub: return *V1 - *V2; + case Instruction::Mul: return *V1 * *V2; + case Instruction::Div: return *V1 / *V2; + case Instruction::Rem: return *V1 % *V2; + + case Instruction::SetEQ: return *V1 == *V2; + case Instruction::SetNE: return *V1 != *V2; + case Instruction::SetLE: return *V1 <= *V2; + case Instruction::SetGE: return *V1 >= *V2; + case Instruction::SetLT: return *V1 < *V2; + case Instruction::SetGT: return *V1 > *V2; + } + return 0; +} + +Constant *ConstantFoldShiftInstruction(unsigned Opcode, const Constant *V1, + const Constant *V2) { + switch (Opcode) { + case Instruction::Shl: return *V1 << *V2; + case Instruction::Shr: return *V1 >> *V2; + default: return 0; + } +} -namespace opt { //===----------------------------------------------------------------------===// // TemplateRules Class @@ -27,48 +126,131 @@ class TemplateRules : public ConstRules { // Redirecting functions that cast to the appropriate types //===--------------------------------------------------------------------===// - virtual ConstPoolVal *neg(const ConstPoolVal *V) const { - return SubClassName::Neg((const ArgType *)V); - } - - virtual ConstPoolVal *not(const ConstPoolVal *V) const { + virtual Constant *op_not(const Constant *V) const { return SubClassName::Not((const ArgType *)V); } - virtual ConstPoolVal *add(const ConstPoolVal *V1, - const ConstPoolVal *V2) const { + virtual Constant *add(const Constant *V1, + const Constant *V2) const { return SubClassName::Add((const ArgType *)V1, (const ArgType *)V2); } - virtual ConstPoolVal *sub(const ConstPoolVal *V1, - const ConstPoolVal *V2) const { + virtual Constant *sub(const Constant *V1, + const Constant *V2) const { return SubClassName::Sub((const ArgType *)V1, (const ArgType *)V2); } - virtual ConstPoolBool *lessthan(const ConstPoolVal *V1, - const ConstPoolVal *V2) const { + virtual Constant *mul(const Constant *V1, + const Constant *V2) const { + return SubClassName::Mul((const ArgType *)V1, (const ArgType *)V2); + } + virtual Constant *div(const Constant *V1, + const Constant *V2) const { + return SubClassName::Div((const ArgType *)V1, (const ArgType *)V2); + } + virtual Constant *rem(const Constant *V1, + const Constant *V2) const { + return SubClassName::Rem((const ArgType *)V1, (const ArgType *)V2); + } + virtual Constant *shl(const Constant *V1, + const Constant *V2) const { + return SubClassName::Shl((const ArgType *)V1, (const ArgType *)V2); + } + virtual Constant *shr(const Constant *V1, + const Constant *V2) const { + return SubClassName::Shr((const ArgType *)V1, (const ArgType *)V2); + } + + virtual ConstantBool *lessthan(const Constant *V1, + const Constant *V2) const { return SubClassName::LessThan((const ArgType *)V1, (const ArgType *)V2); } + // Casting operators. ick + virtual ConstantBool *castToBool(const Constant *V) const { + return SubClassName::CastToBool((const ArgType*)V); + } + virtual ConstantSInt *castToSByte(const Constant *V) const { + return SubClassName::CastToSByte((const ArgType*)V); + } + virtual ConstantUInt *castToUByte(const Constant *V) const { + return SubClassName::CastToUByte((const ArgType*)V); + } + virtual ConstantSInt *castToShort(const Constant *V) const { + return SubClassName::CastToShort((const ArgType*)V); + } + virtual ConstantUInt *castToUShort(const Constant *V) const { + return SubClassName::CastToUShort((const ArgType*)V); + } + virtual ConstantSInt *castToInt(const Constant *V) const { + return SubClassName::CastToInt((const ArgType*)V); + } + virtual ConstantUInt *castToUInt(const Constant *V) const { + return SubClassName::CastToUInt((const ArgType*)V); + } + virtual ConstantSInt *castToLong(const Constant *V) const { + return SubClassName::CastToLong((const ArgType*)V); + } + virtual ConstantUInt *castToULong(const Constant *V) const { + return SubClassName::CastToULong((const ArgType*)V); + } + virtual ConstantFP *castToFloat(const Constant *V) const { + return SubClassName::CastToFloat((const ArgType*)V); + } + virtual ConstantFP *castToDouble(const Constant *V) const { + return SubClassName::CastToDouble((const ArgType*)V); + } + virtual ConstantPointer *castToPointer(const Constant *V, + const PointerType *Ty) const { + return SubClassName::CastToPointer((const ArgType*)V, Ty); + } + //===--------------------------------------------------------------------===// // Default "noop" implementations //===--------------------------------------------------------------------===// - inline static ConstPoolVal *Neg(const ArgType *V) { return 0; } - inline static ConstPoolVal *Not(const ArgType *V) { return 0; } + inline static Constant *Not(const ArgType *V) { return 0; } - inline static ConstPoolVal *Add(const ArgType *V1, const ArgType *V2) { + inline static Constant *Add(const ArgType *V1, const ArgType *V2) { return 0; } - - inline static ConstPoolVal *Sub(const ArgType *V1, const ArgType *V2) { + inline static Constant *Sub(const ArgType *V1, const ArgType *V2) { return 0; } - - inline static ConstPoolBool *LessThan(const ArgType *V1, const ArgType *V2) { + inline static Constant *Mul(const ArgType *V1, const ArgType *V2) { return 0; } + inline static Constant *Div(const ArgType *V1, const ArgType *V2) { + return 0; + } + inline static Constant *Rem(const ArgType *V1, const ArgType *V2) { + return 0; + } + inline static Constant *Shl(const ArgType *V1, const ArgType *V2) { + return 0; + } + inline static Constant *Shr(const ArgType *V1, const ArgType *V2) { + return 0; + } + inline static ConstantBool *LessThan(const ArgType *V1, const ArgType *V2) { + return 0; + } + + // Casting operators. ick + inline static ConstantBool *CastToBool (const Constant *V) { return 0; } + inline static ConstantSInt *CastToSByte (const Constant *V) { return 0; } + inline static ConstantUInt *CastToUByte (const Constant *V) { return 0; } + inline static ConstantSInt *CastToShort (const Constant *V) { return 0; } + inline static ConstantUInt *CastToUShort(const Constant *V) { return 0; } + inline static ConstantSInt *CastToInt (const Constant *V) { return 0; } + inline static ConstantUInt *CastToUInt (const Constant *V) { return 0; } + inline static ConstantSInt *CastToLong (const Constant *V) { return 0; } + inline static ConstantUInt *CastToULong (const Constant *V) { return 0; } + inline static ConstantFP *CastToFloat (const Constant *V) { return 0; } + inline static ConstantFP *CastToDouble(const Constant *V) { return 0; } + inline static ConstantPointer *CastToPointer(const Constant *, + const PointerType *) {return 0;} }; @@ -79,9 +261,8 @@ class TemplateRules : public ConstRules { // // EmptyRules provides a concrete base class of ConstRules that does nothing // -static // EmptyInst is static -struct EmptyRules : public TemplateRules { -} EmptyInst; +struct EmptyRules : public TemplateRules { +}; @@ -91,25 +272,85 @@ struct EmptyRules : public TemplateRules { // // BoolRules provides a concrete base class of ConstRules for the 'bool' type. // -static // BoolTyInst is static... -struct BoolRules : public TemplateRules { +struct BoolRules : public TemplateRules { - inline static ConstPoolVal *Not(const ConstPoolBool *V) { - return new ConstPoolBool(!V->getValue()); + inline static Constant *Not(const ConstantBool *V) { + return ConstantBool::get(!V->getValue()); } - inline static ConstPoolVal *Or(const ConstPoolBool *V1, - const ConstPoolBool *V2) { - bool Result = V1->getValue() | V2->getValue(); - return new ConstPoolBool(Result); + inline static Constant *Or(const ConstantBool *V1, + const ConstantBool *V2) { + return ConstantBool::get(V1->getValue() | V2->getValue()); } - inline static ConstPoolVal *And(const ConstPoolBool *V1, - const ConstPoolBool *V2) { - bool Result = V1->getValue() & V2->getValue(); - return new ConstPoolBool(Result); + inline static Constant *And(const ConstantBool *V1, + const ConstantBool *V2) { + return ConstantBool::get(V1->getValue() & V2->getValue()); } -} BoolTyInst; +}; + + +//===----------------------------------------------------------------------===// +// PointerRules Class +//===----------------------------------------------------------------------===// +// +// PointerRules provides a concrete base class of ConstRules for pointer types +// +struct PointerRules : public TemplateRules { + inline static ConstantBool *CastToBool (const Constant *V) { + if (V->isNullValue()) return ConstantBool::False; + return 0; // Can't const prop other types of pointers + } + inline static ConstantSInt *CastToSByte (const Constant *V) { + if (V->isNullValue()) return ConstantSInt::get(Type::SByteTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantUInt *CastToUByte (const Constant *V) { + if (V->isNullValue()) return ConstantUInt::get(Type::UByteTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantSInt *CastToShort (const Constant *V) { + if (V->isNullValue()) return ConstantSInt::get(Type::ShortTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantUInt *CastToUShort(const Constant *V) { + if (V->isNullValue()) return ConstantUInt::get(Type::UShortTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantSInt *CastToInt (const Constant *V) { + if (V->isNullValue()) return ConstantSInt::get(Type::IntTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantUInt *CastToUInt (const Constant *V) { + if (V->isNullValue()) return ConstantUInt::get(Type::UIntTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantSInt *CastToLong (const Constant *V) { + if (V->isNullValue()) return ConstantSInt::get(Type::LongTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantUInt *CastToULong (const Constant *V) { + if (V->isNullValue()) return ConstantUInt::get(Type::ULongTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantFP *CastToFloat (const Constant *V) { + if (V->isNullValue()) return ConstantFP::get(Type::FloatTy, 0); + return 0; // Can't const prop other types of pointers + } + inline static ConstantFP *CastToDouble(const Constant *V) { + if (V->isNullValue()) return ConstantFP::get(Type::DoubleTy, 0); + return 0; // Can't const prop other types of pointers + } + + inline static ConstantPointer *CastToPointer(const ConstantPointer *V, + const PointerType *PTy) { + if (V->getType() == PTy) + return const_cast(V); // Allow cast %PTy %ptr to %PTy + if (V->isNullValue()) + return ConstantPointerNull::get(PTy); + return 0; // Can't const prop other types of pointers + } +}; //===----------------------------------------------------------------------===// @@ -120,39 +361,131 @@ struct BoolRules : public TemplateRules { // different types. This allows the C++ compiler to automatically generate our // constant handling operations in a typesafe and accurate manner. // -template -struct DirectRules - : public TemplateRules > { - - inline static ConstPoolVal *Neg(const ConstPoolClass *V) { - return new ConstPoolClass(*Ty, -(BuiltinType)V->getValue());; +template +struct DirectRules : public TemplateRules { + inline static Constant *Add(const ConstantClass *V1, + const ConstantClass *V2) { + BuiltinType Result = (BuiltinType)V1->getValue() + + (BuiltinType)V2->getValue(); + return ConstantClass::get(*Ty, Result); } - inline static ConstPoolVal *Not(const ConstPoolClass *V) { - return new ConstPoolClass(*Ty, !(BuiltinType)V->getValue());; + + inline static Constant *Sub(const ConstantClass *V1, + const ConstantClass *V2) { + BuiltinType Result = (BuiltinType)V1->getValue() - + (BuiltinType)V2->getValue(); + return ConstantClass::get(*Ty, Result); } - inline static ConstPoolVal *Add(const ConstPoolClass *V1, - const ConstPoolClass *V2) { - BuiltinType Result = (BuiltinType)V1->getValue() + + inline static Constant *Mul(const ConstantClass *V1, + const ConstantClass *V2) { + BuiltinType Result = (BuiltinType)V1->getValue() * (BuiltinType)V2->getValue(); - return new ConstPoolClass(*Ty, Result); + return ConstantClass::get(*Ty, Result); } - inline static ConstPoolVal *Sub(const ConstPoolClass *V1, - const ConstPoolClass *V2) { - BuiltinType Result = (BuiltinType)V1->getValue() - + inline static Constant *Div(const ConstantClass *V1, + const ConstantClass *V2) { + if (V2->isNullValue()) return 0; + BuiltinType Result = (BuiltinType)V1->getValue() / (BuiltinType)V2->getValue(); - return new ConstPoolClass(*Ty, Result); + return ConstantClass::get(*Ty, Result); } - inline static ConstPoolBool *LessThan(const ConstPoolClass *V1, - const ConstPoolClass *V2) { + inline static ConstantBool *LessThan(const ConstantClass *V1, + const ConstantClass *V2) { bool Result = (BuiltinType)V1->getValue() < (BuiltinType)V2->getValue(); - return new ConstPoolBool(Result); + return ConstantBool::get(Result); } + + inline static ConstantPointer *CastToPointer(const ConstantClass *V, + const PointerType *PTy) { + if (V->isNullValue()) // Is it a FP or Integral null value? + return ConstantPointerNull::get(PTy); + return 0; // Can't const prop other types of pointers + } + + // Casting operators. ick +#define DEF_CAST(TYPE, CLASS, CTYPE) \ + inline static CLASS *CastTo##TYPE (const ConstantClass *V) { \ + return CLASS::get(Type::TYPE##Ty, (CTYPE)(BuiltinType)V->getValue()); \ + } + + DEF_CAST(Bool , ConstantBool, bool) + DEF_CAST(SByte , ConstantSInt, signed char) + DEF_CAST(UByte , ConstantUInt, unsigned char) + DEF_CAST(Short , ConstantSInt, signed short) + DEF_CAST(UShort, ConstantUInt, unsigned short) + DEF_CAST(Int , ConstantSInt, signed int) + DEF_CAST(UInt , ConstantUInt, unsigned int) + DEF_CAST(Long , ConstantSInt, int64_t) + DEF_CAST(ULong , ConstantUInt, uint64_t) + DEF_CAST(Float , ConstantFP , float) + DEF_CAST(Double, ConstantFP , double) +#undef DEF_CAST }; + +//===----------------------------------------------------------------------===// +// DirectIntRules Class +//===----------------------------------------------------------------------===// +// +// DirectIntRules provides implementations of functions that are valid on +// integer types, but not all types in general. +// +template +struct DirectIntRules + : public DirectRules > { + inline static Constant *Not(const ConstantClass *V) { + return ConstantClass::get(*Ty, ~(BuiltinType)V->getValue());; + } + + inline static Constant *Rem(const ConstantClass *V1, + const ConstantClass *V2) { + if (V2->isNullValue()) return 0; + BuiltinType Result = (BuiltinType)V1->getValue() % + (BuiltinType)V2->getValue(); + return ConstantClass::get(*Ty, Result); + } + + inline static Constant *Shl(const ConstantClass *V1, + const ConstantClass *V2) { + BuiltinType Result = (BuiltinType)V1->getValue() << + (BuiltinType)V2->getValue(); + return ConstantClass::get(*Ty, Result); + } + + inline static Constant *Shr(const ConstantClass *V1, + const ConstantClass *V2) { + BuiltinType Result = (BuiltinType)V1->getValue() >> + (BuiltinType)V2->getValue(); + return ConstantClass::get(*Ty, Result); + } +}; + + +//===----------------------------------------------------------------------===// +// DirectFPRules Class +//===----------------------------------------------------------------------===// +// +// DirectFPRules provides implementations of functions that are valid on +// floating point types, but not all types in general. +// +template +struct DirectFPRules + : public DirectRules > { + inline static Constant *Rem(const ConstantClass *V1, + const ConstantClass *V2) { + if (V2->isNullValue()) return 0; + BuiltinType Result = std::fmod((BuiltinType)V1->getValue(), + (BuiltinType)V2->getValue()); + return ConstantClass::get(*Ty, Result); + } +}; + + //===----------------------------------------------------------------------===// // DirectRules Subclasses //===----------------------------------------------------------------------===// @@ -161,42 +494,38 @@ struct DirectRules // code. Thank goodness C++ compilers are great at stomping out layers of // templates... can you imagine having to do this all by hand? (/me is lazy :) // -static DirectRules SByteTyInst; -static DirectRules UByteTyInst; -static DirectRules ShortTyInst; -static DirectRules UShortTyInst; -static DirectRules IntTyInst; -static DirectRules UIntTyInst; -static DirectRules LongTyInst; -static DirectRules ULongTyInst; -static DirectRules FloatTyInst; -static DirectRules DoubleTyInst; - // ConstRules::find - Return the constant rules that take care of the specified -// type. Note that this is cached in the Type value itself, so switch statement -// is only hit at most once per type. +// type. // -const ConstRules *ConstRules::find(const Type *Ty) { - const ConstRules *Result; +Annotation *ConstRules::find(AnnotationID AID, const Annotable *TyA, void *) { + assert(AID == ConstRules::AID && "Bad annotation for factory!"); + const Type *Ty = cast((const Value*)TyA); + switch (Ty->getPrimitiveID()) { - case Type::BoolTyID: Result = &BoolTyInst; break; - case Type::SByteTyID: Result = &SByteTyInst; break; - case Type::UByteTyID: Result = &UByteTyInst; break; - case Type::ShortTyID: Result = &ShortTyInst; break; - case Type::UShortTyID: Result = &UShortTyInst; break; - case Type::IntTyID: Result = &IntTyInst; break; - case Type::UIntTyID: Result = &UIntTyInst; break; - case Type::LongTyID: Result = &LongTyInst; break; - case Type::ULongTyID: Result = &ULongTyInst; break; - case Type::FloatTyID: Result = &FloatTyInst; break; - case Type::DoubleTyID: Result = &DoubleTyInst; break; - default: Result = &EmptyInst; break; - } - - Ty->setConstRules(Result); // Cache the value for future short circuiting! - return Result; + case Type::BoolTyID: return new BoolRules(); + case Type::PointerTyID: return new PointerRules(); + case Type::SByteTyID: + return new DirectIntRules(); + case Type::UByteTyID: + return new DirectIntRules(); + case Type::ShortTyID: + return new DirectIntRules(); + case Type::UShortTyID: + return new DirectIntRules(); + case Type::IntTyID: + return new DirectIntRules(); + case Type::UIntTyID: + return new DirectIntRules(); + case Type::LongTyID: + return new DirectIntRules(); + case Type::ULongTyID: + return new DirectIntRules(); + case Type::FloatTyID: + return new DirectFPRules(); + case Type::DoubleTyID: + return new DirectFPRules(); + default: + return new EmptyRules(); + } } - - -} // End namespace opt