Remove a whole bunch more ugliness. This is actually getting to the point of
authorChris Lattner <sabre@nondot.org>
Mon, 12 Jan 2004 21:02:29 +0000 (21:02 +0000)
committerChris Lattner <sabre@nondot.org>
Mon, 12 Jan 2004 21:02:29 +0000 (21:02 +0000)
this whole refactoring: allow constant folding methods to return something
other than predefined classes, allow them to return generic Constant*'s.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@10806 91177308-0d34-0410-b5e6-96231b3b80d8

lib/VMCore/ConstantFold.cpp
lib/VMCore/ConstantFold.h
lib/VMCore/ConstantFolding.h

index 45b021bb7dc80bfa9296064efc5b2a6ba125d3f0..ddec284942a238ae7665bb56b6a184b4c384ad81 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "ConstantHandling.h"
+#include "llvm/Constants.h"
 #include "llvm/iPHINode.h"
 #include "llvm/InstrTypes.h"
 #include "llvm/DerivedTypes.h"
@@ -57,7 +58,24 @@ Constant *llvm::ConstantFoldCastInstruction(const Constant *V,
         return ConstantExpr::getCast(CE->getOperand(0), DestTy);
     }
 
-  return ConstRules::get(V, V).castTo(V, DestTy);
+  ConstRules &Rules = ConstRules::get(V, V);
+
+  switch (DestTy->getPrimitiveID()) {
+  case Type::BoolTyID:    return Rules.castToBool(V);
+  case Type::UByteTyID:   return Rules.castToUByte(V);
+  case Type::SByteTyID:   return Rules.castToSByte(V);
+  case Type::UShortTyID:  return Rules.castToUShort(V);
+  case Type::ShortTyID:   return Rules.castToShort(V);
+  case Type::UIntTyID:    return Rules.castToUInt(V);
+  case Type::IntTyID:     return Rules.castToInt(V);
+  case Type::ULongTyID:   return Rules.castToULong(V);
+  case Type::LongTyID:    return Rules.castToLong(V);
+  case Type::FloatTyID:   return Rules.castToFloat(V);
+  case Type::DoubleTyID:  return Rules.castToDouble(V);
+  case Type::PointerTyID:
+    return Rules.castToPointer(V, cast<PointerType>(DestTy));
+  default: return 0;
+  }
 }
 
 Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
@@ -209,47 +227,45 @@ class TemplateRules : public ConstRules {
     return SubClassName::Shr((const ArgType *)V1, (const ArgType *)V2);  
   }
 
-  virtual ConstantBool *lessthan(const Constant *V1, 
-                                 const Constant *V2) const { 
+  virtual Constant *lessthan(const Constant *V1, const Constant *V2) const { 
     return SubClassName::LessThan((const ArgType *)V1, (const ArgType *)V2);
   }
-  virtual ConstantBool *equalto(const Constant *V1, 
-                                const Constant *V2) const { 
+  virtual Constant *equalto(const Constant *V1, const Constant *V2) const { 
     return SubClassName::EqualTo((const ArgType *)V1, (const ArgType *)V2);
   }
 
   // Casting operators.  ick
-  virtual ConstantBool *castToBool(const Constant *V) const {
+  virtual Constant *castToBool(const Constant *V) const {
     return SubClassName::CastToBool((const ArgType*)V);
   }
-  virtual ConstantSInt *castToSByte(const Constant *V) const {
+  virtual Constant *castToSByte(const Constant *V) const {
     return SubClassName::CastToSByte((const ArgType*)V);
   }
-  virtual ConstantUInt *castToUByte(const Constant *V) const {
+  virtual Constant *castToUByte(const Constant *V) const {
     return SubClassName::CastToUByte((const ArgType*)V);
   }
-  virtual ConstantSInt *castToShort(const Constant *V) const {
+  virtual Constant *castToShort(const Constant *V) const {
     return SubClassName::CastToShort((const ArgType*)V);
   }
-  virtual ConstantUInt *castToUShort(const Constant *V) const {
+  virtual Constant *castToUShort(const Constant *V) const {
     return SubClassName::CastToUShort((const ArgType*)V);
   }
-  virtual ConstantSInt *castToInt(const Constant *V) const {
+  virtual Constant *castToInt(const Constant *V) const {
     return SubClassName::CastToInt((const ArgType*)V);
   }
-  virtual ConstantUInt *castToUInt(const Constant *V) const {
+  virtual Constant *castToUInt(const Constant *V) const {
     return SubClassName::CastToUInt((const ArgType*)V);
   }
-  virtual ConstantSInt *castToLong(const Constant *V) const {
+  virtual Constant *castToLong(const Constant *V) const {
     return SubClassName::CastToLong((const ArgType*)V);
   }
-  virtual ConstantUInt *castToULong(const Constant *V) const {
+  virtual Constant *castToULong(const Constant *V) const {
     return SubClassName::CastToULong((const ArgType*)V);
   }
-  virtual ConstantFP   *castToFloat(const Constant *V) const {
+  virtual Constant *castToFloat(const Constant *V) const {
     return SubClassName::CastToFloat((const ArgType*)V);
   }
-  virtual ConstantFP   *castToDouble(const Constant *V) const {
+  virtual Constant *castToDouble(const Constant *V) const {
     return SubClassName::CastToDouble((const ArgType*)V);
   }
   virtual Constant *castToPointer(const Constant *V, 
@@ -271,27 +287,27 @@ class TemplateRules : public ConstRules {
   static Constant *Xor(const ArgType *V1, const ArgType *V2) { return 0; }
   static Constant *Shl(const ArgType *V1, const ArgType *V2) { return 0; }
   static Constant *Shr(const ArgType *V1, const ArgType *V2) { return 0; }
-  static ConstantBool *LessThan(const ArgType *V1, const ArgType *V2) {
+  static Constant *LessThan(const ArgType *V1, const ArgType *V2) {
     return 0;
   }
-  static ConstantBool *EqualTo(const ArgType *V1, const ArgType *V2) {
+  static Constant *EqualTo(const ArgType *V1, const ArgType *V2) {
     return 0;
   }
 
   // Casting operators.  ick
-  static ConstantBool *CastToBool  (const Constant *V) { return 0; }
-  static ConstantSInt *CastToSByte (const Constant *V) { return 0; }
-  static ConstantUInt *CastToUByte (const Constant *V) { return 0; }
-  static ConstantSInt *CastToShort (const Constant *V) { return 0; }
-  static ConstantUInt *CastToUShort(const Constant *V) { return 0; }
-  static ConstantSInt *CastToInt   (const Constant *V) { return 0; }
-  static ConstantUInt *CastToUInt  (const Constant *V) { return 0; }
-  static ConstantSInt *CastToLong  (const Constant *V) { return 0; }
-  static ConstantUInt *CastToULong (const Constant *V) { return 0; }
-  static ConstantFP   *CastToFloat (const Constant *V) { return 0; }
-  static ConstantFP   *CastToDouble(const Constant *V) { return 0; }
-  static Constant     *CastToPointer(const Constant *,
-                                     const PointerType *) {return 0;}
+  static Constant *CastToBool  (const Constant *V) { return 0; }
+  static Constant *CastToSByte (const Constant *V) { return 0; }
+  static Constant *CastToUByte (const Constant *V) { return 0; }
+  static Constant *CastToShort (const Constant *V) { return 0; }
+  static Constant *CastToUShort(const Constant *V) { return 0; }
+  static Constant *CastToInt   (const Constant *V) { return 0; }
+  static Constant *CastToUInt  (const Constant *V) { return 0; }
+  static Constant *CastToLong  (const Constant *V) { return 0; }
+  static Constant *CastToULong (const Constant *V) { return 0; }
+  static Constant *CastToFloat (const Constant *V) { return 0; }
+  static Constant *CastToDouble(const Constant *V) { return 0; }
+  static Constant *CastToPointer(const Constant *,
+                                 const PointerType *) {return 0;}
 };
 
 
@@ -303,7 +319,7 @@ class TemplateRules : public ConstRules {
 // EmptyRules provides a concrete base class of ConstRules that does nothing
 //
 struct EmptyRules : public TemplateRules<Constant, EmptyRules> {
-  static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
+  static Constant *EqualTo(const Constant *V1, const Constant *V2) {
     if (V1 == V2) return ConstantBool::True;
     return 0;
   }
@@ -319,11 +335,11 @@ struct EmptyRules : public TemplateRules<Constant, EmptyRules> {
 //
 struct BoolRules : public TemplateRules<ConstantBool, BoolRules> {
 
-  static ConstantBool *LessThan(const ConstantBool *V1, const ConstantBool *V2){
+  static Constant *LessThan(const ConstantBool *V1, const ConstantBool *V2){
     return ConstantBool::get(V1->getValue() < V2->getValue());
   }
 
-  static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
+  static Constant *EqualTo(const Constant *V1, const Constant *V2) {
     return ConstantBool::get(V1 == V2);
   }
 
@@ -341,7 +357,7 @@ struct BoolRules : public TemplateRules<ConstantBool, BoolRules> {
 
   // Casting operators.  ick
 #define DEF_CAST(TYPE, CLASS, CTYPE) \
-  static CLASS *CastTo##TYPE  (const ConstantBool *V) {    \
+  static Constant *CastTo##TYPE  (const ConstantBool *V) {    \
     return CLASS::get(Type::TYPE##Ty, (CTYPE)(bool)V->getValue()); \
   }
 
@@ -369,40 +385,40 @@ struct BoolRules : public TemplateRules<ConstantBool, BoolRules> {
 //
 struct NullPointerRules : public TemplateRules<ConstantPointerNull,
                                                NullPointerRules> {
-  static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
+  static Constant *EqualTo(const Constant *V1, const Constant *V2) {
     return ConstantBool::True;  // Null pointers are always equal
   }
-  static ConstantBool *CastToBool  (const Constant *V) {
+  static Constant *CastToBool(const Constant *V) {
     return ConstantBool::False;
   }
-  static ConstantSInt *CastToSByte (const Constant *V) {
+  static Constant *CastToSByte (const Constant *V) {
     return ConstantSInt::get(Type::SByteTy, 0);
   }
-  static ConstantUInt *CastToUByte (const Constant *V) {
+  static Constant *CastToUByte (const Constant *V) {
     return ConstantUInt::get(Type::UByteTy, 0);
   }
-  static ConstantSInt *CastToShort (const Constant *V) {
+  static Constant *CastToShort (const Constant *V) {
     return ConstantSInt::get(Type::ShortTy, 0);
   }
-  static ConstantUInt *CastToUShort(const Constant *V) {
+  static Constant *CastToUShort(const Constant *V) {
     return ConstantUInt::get(Type::UShortTy, 0);
   }
-  static ConstantSInt *CastToInt   (const Constant *V) {
+  static Constant *CastToInt   (const Constant *V) {
     return ConstantSInt::get(Type::IntTy, 0);
   }
-  static ConstantUInt *CastToUInt  (const Constant *V) {
+  static Constant *CastToUInt  (const Constant *V) {
     return ConstantUInt::get(Type::UIntTy, 0);
   }
-  static ConstantSInt *CastToLong  (const Constant *V) {
+  static Constant *CastToLong  (const Constant *V) {
     return ConstantSInt::get(Type::LongTy, 0);
   }
-  static ConstantUInt *CastToULong (const Constant *V) {
+  static Constant *CastToULong (const Constant *V) {
     return ConstantUInt::get(Type::ULongTy, 0);
   }
-  static ConstantFP   *CastToFloat (const Constant *V) {
+  static Constant *CastToFloat (const Constant *V) {
     return ConstantFP::get(Type::FloatTy, 0);
   }
-  static ConstantFP   *CastToDouble(const Constant *V) {
+  static Constant *CastToDouble(const Constant *V) {
     return ConstantFP::get(Type::DoubleTy, 0);
   }
 
@@ -444,14 +460,12 @@ struct DirectRules : public TemplateRules<ConstantClass, SuperClass> {
     return ConstantClass::get(*Ty, R);
   }
 
-  static ConstantBool *LessThan(const ConstantClass *V1,
-                                const ConstantClass *V2) {
+  static Constant *LessThan(const ConstantClass *V1, const ConstantClass *V2) {
     bool R = (BuiltinType)V1->getValue() < (BuiltinType)V2->getValue();
     return ConstantBool::get(R);
   } 
 
-  static ConstantBool *EqualTo(const ConstantClass *V1,
-                               const ConstantClass *V2) {
+  static Constant *EqualTo(const ConstantClass *V1, const ConstantClass *V2) {
     bool R = (BuiltinType)V1->getValue() == (BuiltinType)V2->getValue();
     return ConstantBool::get(R);
   }
@@ -465,7 +479,7 @@ struct DirectRules : public TemplateRules<ConstantClass, SuperClass> {
 
   // Casting operators.  ick
 #define DEF_CAST(TYPE, CLASS, CTYPE) \
-  static CLASS *CastTo##TYPE  (const ConstantClass *V) {    \
+  static Constant *CastTo##TYPE  (const ConstantClass *V) {    \
     return CLASS::get(Type::TYPE##Ty, (CTYPE)(BuiltinType)V->getValue()); \
   }
 
index 8475e44f908a413055e33e4569c7dd1210b3fb24..dc5d0cfbfaea3bdad5c06ded2eb480acb8678e23 100644 (file)
 #ifndef CONSTANTHANDLING_H
 #define CONSTANTHANDLING_H
 
-#include "llvm/Constants.h"
-#include "llvm/Type.h"
+#include <vector>
 
 namespace llvm {
-
-class PointerType;
+  class Constant;
+  class Type;
+  class PointerType;
 
 struct ConstRules {
   ConstRules() {}
@@ -37,44 +37,24 @@ struct ConstRules {
   virtual Constant *shl(const Constant *V1, const Constant *V2) const = 0;
   virtual Constant *shr(const Constant *V1, const Constant *V2) const = 0;
 
-  virtual ConstantBool *lessthan(const Constant *V1, 
-                                 const Constant *V2) const = 0;
-  virtual ConstantBool *equalto(const Constant *V1, 
-                                const Constant *V2) const = 0;
+  virtual Constant *lessthan(const Constant *V1, const Constant *V2) const = 0;
+                             
+  virtual Constant *equalto(const Constant *V1, const Constant *V2) const = 0;
 
   // Casting operators.  ick
-  virtual ConstantBool *castToBool  (const Constant *V) const = 0;
-  virtual ConstantSInt *castToSByte (const Constant *V) const = 0;
-  virtual ConstantUInt *castToUByte (const Constant *V) const = 0;
-  virtual ConstantSInt *castToShort (const Constant *V) const = 0;
-  virtual ConstantUInt *castToUShort(const Constant *V) const = 0;
-  virtual ConstantSInt *castToInt   (const Constant *V) const = 0;
-  virtual ConstantUInt *castToUInt  (const Constant *V) const = 0;
-  virtual ConstantSInt *castToLong  (const Constant *V) const = 0;
-  virtual ConstantUInt *castToULong (const Constant *V) const = 0;
-  virtual ConstantFP   *castToFloat (const Constant *V) const = 0;
-  virtual ConstantFP   *castToDouble(const Constant *V) const = 0;
-  virtual Constant     *castToPointer(const Constant *V,
-                                      const PointerType *Ty) const = 0;
-
-  inline Constant *castTo(const Constant *V, const Type *Ty) const {
-    switch (Ty->getPrimitiveID()) {
-    case Type::BoolTyID:   return castToBool(V);
-    case Type::UByteTyID:  return castToUByte(V);
-    case Type::SByteTyID:  return castToSByte(V);
-    case Type::UShortTyID: return castToUShort(V);
-    case Type::ShortTyID:  return castToShort(V);
-    case Type::UIntTyID:   return castToUInt(V);
-    case Type::IntTyID:    return castToInt(V);
-    case Type::ULongTyID:  return castToULong(V);
-    case Type::LongTyID:   return castToLong(V);
-    case Type::FloatTyID:  return castToFloat(V);
-    case Type::DoubleTyID: return castToDouble(V);
-    case Type::PointerTyID:
-      return castToPointer(V, reinterpret_cast<const PointerType*>(Ty));
-    default: return 0;
-    }
-  }
+  virtual Constant *castToBool  (const Constant *V) const = 0;
+  virtual Constant *castToSByte (const Constant *V) const = 0;
+  virtual Constant *castToUByte (const Constant *V) const = 0;
+  virtual Constant *castToShort (const Constant *V) const = 0;
+  virtual Constant *castToUShort(const Constant *V) const = 0;
+  virtual Constant *castToInt   (const Constant *V) const = 0;
+  virtual Constant *castToUInt  (const Constant *V) const = 0;
+  virtual Constant *castToLong  (const Constant *V) const = 0;
+  virtual Constant *castToULong (const Constant *V) const = 0;
+  virtual Constant *castToFloat (const Constant *V) const = 0;
+  virtual Constant *castToDouble(const Constant *V) const = 0;
+  virtual Constant *castToPointer(const Constant *V,
+                                  const PointerType *Ty) const = 0;
 
   // ConstRules::get - Return an instance of ConstRules for the specified
   // constant operands.
index 8475e44f908a413055e33e4569c7dd1210b3fb24..dc5d0cfbfaea3bdad5c06ded2eb480acb8678e23 100644 (file)
 #ifndef CONSTANTHANDLING_H
 #define CONSTANTHANDLING_H
 
-#include "llvm/Constants.h"
-#include "llvm/Type.h"
+#include <vector>
 
 namespace llvm {
-
-class PointerType;
+  class Constant;
+  class Type;
+  class PointerType;
 
 struct ConstRules {
   ConstRules() {}
@@ -37,44 +37,24 @@ struct ConstRules {
   virtual Constant *shl(const Constant *V1, const Constant *V2) const = 0;
   virtual Constant *shr(const Constant *V1, const Constant *V2) const = 0;
 
-  virtual ConstantBool *lessthan(const Constant *V1, 
-                                 const Constant *V2) const = 0;
-  virtual ConstantBool *equalto(const Constant *V1, 
-                                const Constant *V2) const = 0;
+  virtual Constant *lessthan(const Constant *V1, const Constant *V2) const = 0;
+                             
+  virtual Constant *equalto(const Constant *V1, const Constant *V2) const = 0;
 
   // Casting operators.  ick
-  virtual ConstantBool *castToBool  (const Constant *V) const = 0;
-  virtual ConstantSInt *castToSByte (const Constant *V) const = 0;
-  virtual ConstantUInt *castToUByte (const Constant *V) const = 0;
-  virtual ConstantSInt *castToShort (const Constant *V) const = 0;
-  virtual ConstantUInt *castToUShort(const Constant *V) const = 0;
-  virtual ConstantSInt *castToInt   (const Constant *V) const = 0;
-  virtual ConstantUInt *castToUInt  (const Constant *V) const = 0;
-  virtual ConstantSInt *castToLong  (const Constant *V) const = 0;
-  virtual ConstantUInt *castToULong (const Constant *V) const = 0;
-  virtual ConstantFP   *castToFloat (const Constant *V) const = 0;
-  virtual ConstantFP   *castToDouble(const Constant *V) const = 0;
-  virtual Constant     *castToPointer(const Constant *V,
-                                      const PointerType *Ty) const = 0;
-
-  inline Constant *castTo(const Constant *V, const Type *Ty) const {
-    switch (Ty->getPrimitiveID()) {
-    case Type::BoolTyID:   return castToBool(V);
-    case Type::UByteTyID:  return castToUByte(V);
-    case Type::SByteTyID:  return castToSByte(V);
-    case Type::UShortTyID: return castToUShort(V);
-    case Type::ShortTyID:  return castToShort(V);
-    case Type::UIntTyID:   return castToUInt(V);
-    case Type::IntTyID:    return castToInt(V);
-    case Type::ULongTyID:  return castToULong(V);
-    case Type::LongTyID:   return castToLong(V);
-    case Type::FloatTyID:  return castToFloat(V);
-    case Type::DoubleTyID: return castToDouble(V);
-    case Type::PointerTyID:
-      return castToPointer(V, reinterpret_cast<const PointerType*>(Ty));
-    default: return 0;
-    }
-  }
+  virtual Constant *castToBool  (const Constant *V) const = 0;
+  virtual Constant *castToSByte (const Constant *V) const = 0;
+  virtual Constant *castToUByte (const Constant *V) const = 0;
+  virtual Constant *castToShort (const Constant *V) const = 0;
+  virtual Constant *castToUShort(const Constant *V) const = 0;
+  virtual Constant *castToInt   (const Constant *V) const = 0;
+  virtual Constant *castToUInt  (const Constant *V) const = 0;
+  virtual Constant *castToLong  (const Constant *V) const = 0;
+  virtual Constant *castToULong (const Constant *V) const = 0;
+  virtual Constant *castToFloat (const Constant *V) const = 0;
+  virtual Constant *castToDouble(const Constant *V) const = 0;
+  virtual Constant *castToPointer(const Constant *V,
+                                  const PointerType *Ty) const = 0;
 
   // ConstRules::get - Return an instance of ConstRules for the specified
   // constant operands.