Support opaque type printing a little bit at least
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index 2f1fe5f8025e44d92cf6a6e9998e6640baef8411..a366970fae821d321383d6c8f7deee438e632ca3 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ConstantHandling.h"
+#include "llvm/iPHINode.h"
+#include <cmath>
 
 AnnotationID ConstRules::AID(AnnotationManager::getID("opt::ConstRules",
                                                      &ConstRules::find));
 
+// ConstantFoldInstruction - Attempt to constant fold the specified instruction.
+// If successful, the constant result is returned, if not, null is returned.
+//
+Constant *ConstantFoldInstruction(Instruction *I) {
+  if (PHINode *PN = dyn_cast<PHINode>(I)) {
+    if (PN->getNumIncomingValues() == 0)
+      return Constant::getNullValue(PN->getType());
+    
+    Constant *Result = dyn_cast<Constant>(PN->getIncomingValue(0));
+    if (Result == 0) return 0;
+
+    // Handle PHI nodes specially here...
+    for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
+      if (PN->getIncomingValue(i) != Result)
+        return 0;   // Not all the same incoming constants...
+
+    // If we reach here, all incoming values are the same constant.
+    return Result;
+  }
+
+  Constant *Op0 = 0;
+  Constant *Op1 = 0;
+
+  if (I->getNumOperands() != 0) {    // Get first operand if it's a constant...
+    Op0 = dyn_cast<Constant>(I->getOperand(0));
+    if (Op0 == 0) return 0;          // Not a constant?, can't fold
+
+    if (I->getNumOperands() != 1) {  // Get second operand if it's a constant...
+      Op1 = dyn_cast<Constant>(I->getOperand(1));
+      if (Op1 == 0) return 0;        // Not a constant?, can't fold
+    }
+  }
+
+  switch (I->getOpcode()) {
+  case Instruction::Cast:
+    return ConstRules::get(*Op0)->castTo(Op0, I->getType());
+  case Instruction::Not:     return ~*Op0;
+  case Instruction::Add:     return *Op0 + *Op1;
+  case Instruction::Sub:     return *Op0 - *Op1;
+  case Instruction::Mul:     return *Op0 * *Op1;
+  case Instruction::Div:     return *Op0 / *Op1;
+  case Instruction::Rem:     return *Op0 % *Op1;
+
+  case Instruction::SetEQ:   return *Op0 == *Op1;
+  case Instruction::SetNE:   return *Op0 != *Op1;
+  case Instruction::SetLE:   return *Op0 <= *Op1;
+  case Instruction::SetGE:   return *Op0 >= *Op1;
+  case Instruction::SetLT:   return *Op0 <  *Op1;
+  case Instruction::SetGT:   return *Op0 >  *Op1;
+  case Instruction::Shl:     return *Op0 << *Op1;
+  case Instruction::Shr:     return *Op0 >> *Op1;
+  default:
+    return 0;
+  }
+}
+
+Constant *ConstantFoldCastInstruction(const Constant *V, const Type *DestTy) {
+  return ConstRules::get(*V)->castTo(V, DestTy);
+}
+
+Constant *ConstantFoldUnaryInstruction(unsigned Opcode, const Constant *V) {
+  switch (Opcode) {
+  case Instruction::Not:  return ~*V;
+  }
+  return 0;
+}
+
+Constant *ConstantFoldBinaryInstruction(unsigned Opcode, const Constant *V1,
+                                        const Constant *V2) {
+  switch (Opcode) {
+  case Instruction::Add:     return *V1 + *V2;
+  case Instruction::Sub:     return *V1 - *V2;
+  case Instruction::Mul:     return *V1 * *V2;
+  case Instruction::Div:     return *V1 / *V2;
+  case Instruction::Rem:     return *V1 % *V2;
+
+  case Instruction::SetEQ:   return *V1 == *V2;
+  case Instruction::SetNE:   return *V1 != *V2;
+  case Instruction::SetLE:   return *V1 <= *V2;
+  case Instruction::SetGE:   return *V1 >= *V2;
+  case Instruction::SetLT:   return *V1 <  *V2;
+  case Instruction::SetGT:   return *V1 >  *V2;
+  }
+  return 0;
+}
+
+Constant *ConstantFoldShiftInstruction(unsigned Opcode, const Constant *V1, 
+                                       const Constant *V2) {
+  switch (Opcode) {
+  case Instruction::Shl:     return *V1 << *V2;
+  case Instruction::Shr:     return *V1 >> *V2;
+  default:                   return 0;
+  }
+}
+
+
 //===----------------------------------------------------------------------===//
 //                             TemplateRules Class
 //===----------------------------------------------------------------------===//
@@ -51,6 +149,18 @@ class TemplateRules : public ConstRules {
                         const Constant *V2) const { 
     return SubClassName::Div((const ArgType *)V1, (const ArgType *)V2);  
   }
+  virtual Constant *rem(const Constant *V1, 
+                        const Constant *V2) const { 
+    return SubClassName::Rem((const ArgType *)V1, (const ArgType *)V2);  
+  }
+  virtual Constant *shl(const Constant *V1, 
+                        const Constant *V2) const { 
+    return SubClassName::Shl((const ArgType *)V1, (const ArgType *)V2);  
+  }
+  virtual Constant *shr(const Constant *V1, 
+                        const Constant *V2) const { 
+    return SubClassName::Shr((const ArgType *)V1, (const ArgType *)V2);  
+  }
 
   virtual ConstantBool *lessthan(const Constant *V1, 
                                  const Constant *V2) const { 
@@ -114,6 +224,15 @@ class TemplateRules : public ConstRules {
   inline static Constant *Div(const ArgType *V1, const ArgType *V2) {
     return 0;
   }
+  inline static Constant *Rem(const ArgType *V1, const ArgType *V2) {
+    return 0;
+  }
+  inline static Constant *Shl(const ArgType *V1, const ArgType *V2) {
+    return 0;
+  }
+  inline static Constant *Shr(const ArgType *V1, const ArgType *V2) {
+    return 0;
+  }
   inline static ConstantBool *LessThan(const ArgType *V1, const ArgType *V2) {
     return 0;
   }
@@ -225,6 +344,8 @@ struct PointerRules : public TemplateRules<ConstantPointer, PointerRules> {
 
   inline static ConstantPointer *CastToPointer(const ConstantPointer *V,
                                                const PointerType *PTy) {
+    if (V->getType() == PTy)
+      return const_cast<ConstantPointer*>(V);  // Allow cast %PTy %ptr to %PTy
     if (V->isNullValue())
       return ConstantPointerNull::get(PTy);
     return 0;  // Can't const prop other types of pointers
@@ -240,15 +361,8 @@ struct PointerRules : public TemplateRules<ConstantPointer, PointerRules> {
 // different types.  This allows the C++ compiler to automatically generate our
 // constant handling operations in a typesafe and accurate manner.
 //
-template<class ConstantClass, class BuiltinType, Type **Ty>
-struct DirectRules 
-  : public TemplateRules<ConstantClass, 
-                         DirectRules<ConstantClass, BuiltinType, Ty> > {
-
-  inline static Constant *Not(const ConstantClass *V) { 
-    return ConstantClass::get(*Ty, !(BuiltinType)V->getValue());;
-  }
-
+template<class ConstantClass, class BuiltinType, Type **Ty, class SuperClass>
+struct DirectRules : public TemplateRules<ConstantClass, SuperClass> {
   inline static Constant *Add(const ConstantClass *V1, 
                               const ConstantClass *V2) {
     BuiltinType Result = (BuiltinType)V1->getValue() + 
@@ -270,8 +384,9 @@ struct DirectRules
     return ConstantClass::get(*Ty, Result);
   }
 
-  inline static Constant *Div(const ConstantClass *V1, 
+  inline static Constant *Div(const ConstantClass *V1,
                               const ConstantClass *V2) {
+    if (V2->isNullValue()) return 0;
     BuiltinType Result = (BuiltinType)V1->getValue() /
                          (BuiltinType)V2->getValue();
     return ConstantClass::get(*Ty, Result);
@@ -310,6 +425,67 @@ struct DirectRules
 #undef DEF_CAST
 };
 
+
+//===----------------------------------------------------------------------===//
+//                           DirectIntRules Class
+//===----------------------------------------------------------------------===//
+//
+// DirectIntRules provides implementations of functions that are valid on
+// integer types, but not all types in general.
+//
+template <class ConstantClass, class BuiltinType, Type **Ty>
+struct DirectIntRules
+  : public DirectRules<ConstantClass, BuiltinType, Ty,
+                       DirectIntRules<ConstantClass, BuiltinType, Ty> > {
+  inline static Constant *Not(const ConstantClass *V) { 
+    return ConstantClass::get(*Ty, ~(BuiltinType)V->getValue());;
+  }
+
+  inline static Constant *Rem(const ConstantClass *V1,
+                              const ConstantClass *V2) {
+    if (V2->isNullValue()) return 0;
+    BuiltinType Result = (BuiltinType)V1->getValue() %
+                         (BuiltinType)V2->getValue();
+    return ConstantClass::get(*Ty, Result);
+  }
+
+  inline static Constant *Shl(const ConstantClass *V1,
+                              const ConstantClass *V2) {
+    BuiltinType Result = (BuiltinType)V1->getValue() <<
+                         (BuiltinType)V2->getValue();
+    return ConstantClass::get(*Ty, Result);
+  }
+
+  inline static Constant *Shr(const ConstantClass *V1,
+                              const ConstantClass *V2) {
+    BuiltinType Result = (BuiltinType)V1->getValue() >>
+                         (BuiltinType)V2->getValue();
+    return ConstantClass::get(*Ty, Result);
+  }
+};
+
+
+//===----------------------------------------------------------------------===//
+//                           DirectFPRules Class
+//===----------------------------------------------------------------------===//
+//
+// DirectFPRules provides implementations of functions that are valid on
+// floating point types, but not all types in general.
+//
+template <class ConstantClass, class BuiltinType, Type **Ty>
+struct DirectFPRules
+  : public DirectRules<ConstantClass, BuiltinType, Ty,
+                       DirectFPRules<ConstantClass, BuiltinType, Ty> > {
+  inline static Constant *Rem(const ConstantClass *V1,
+                              const ConstantClass *V2) {
+    if (V2->isNullValue()) return 0;
+    BuiltinType Result = std::fmod((BuiltinType)V1->getValue(),
+                                   (BuiltinType)V2->getValue());
+    return ConstantClass::get(*Ty, Result);
+  }
+};
+
+
 //===----------------------------------------------------------------------===//
 //                            DirectRules Subclasses
 //===----------------------------------------------------------------------===//
@@ -330,25 +506,25 @@ Annotation *ConstRules::find(AnnotationID AID, const Annotable *TyA, void *) {
   case Type::BoolTyID:    return new BoolRules();
   case Type::PointerTyID: return new PointerRules();
   case Type::SByteTyID:
-    return new DirectRules<ConstantSInt,   signed char , &Type::SByteTy>();
+    return new DirectIntRules<ConstantSInt,   signed char , &Type::SByteTy>();
   case Type::UByteTyID:
-    return new DirectRules<ConstantUInt, unsigned char , &Type::UByteTy>();
+    return new DirectIntRules<ConstantUInt, unsigned char , &Type::UByteTy>();
   case Type::ShortTyID:
-    return new DirectRules<ConstantSInt,   signed short, &Type::ShortTy>();
+    return new DirectIntRules<ConstantSInt,   signed short, &Type::ShortTy>();
   case Type::UShortTyID:
-    return new DirectRules<ConstantUInt, unsigned short, &Type::UShortTy>();
+    return new DirectIntRules<ConstantUInt, unsigned short, &Type::UShortTy>();
   case Type::IntTyID:
-    return new DirectRules<ConstantSInt,   signed int  , &Type::IntTy>();
+    return new DirectIntRules<ConstantSInt,   signed int  , &Type::IntTy>();
   case Type::UIntTyID:
-    return new DirectRules<ConstantUInt, unsigned int  , &Type::UIntTy>();
+    return new DirectIntRules<ConstantUInt, unsigned int  , &Type::UIntTy>();
   case Type::LongTyID:
-    return new DirectRules<ConstantSInt,  int64_t      , &Type::LongTy>();
+    return new DirectIntRules<ConstantSInt,  int64_t      , &Type::LongTy>();
   case Type::ULongTyID:
-    return new DirectRules<ConstantUInt, uint64_t      , &Type::ULongTy>();
+    return new DirectIntRules<ConstantUInt, uint64_t      , &Type::ULongTy>();
   case Type::FloatTyID:
-    return new DirectRules<ConstantFP  , float         , &Type::FloatTy>();
+    return new DirectFPRules<ConstantFP  , float         , &Type::FloatTy>();
   case Type::DoubleTyID:
-    return new DirectRules<ConstantFP  , double        , &Type::DoubleTy>();
+    return new DirectFPRules<ConstantFP  , double        , &Type::DoubleTy>();
   default:
     return new EmptyRules();
   }