* Make the ctor take a TargetData even though it's not using it yet
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
1 //===- ConstantHandling.cpp - Implement ConstantHandling.h ----------------===//
2 //
3 // This file implements the various intrinsic operations, on constant values.
4 //
5 //===----------------------------------------------------------------------===//
6
7 #include "llvm/ConstantHandling.h"
8 #include "llvm/iPHINode.h"
9 #include <cmath>
10
11 AnnotationID ConstRules::AID(AnnotationManager::getID("opt::ConstRules",
12                                                       &ConstRules::find));
13
14 // ConstantFoldInstruction - Attempt to constant fold the specified instruction.
15 // If successful, the constant result is returned, if not, null is returned.
16 //
17 Constant *ConstantFoldInstruction(Instruction *I) {
18   if (PHINode *PN = dyn_cast<PHINode>(I)) {
19     if (PN->getNumIncomingValues() == 0)
20       return Constant::getNullValue(PN->getType());
21     
22     Constant *Result = dyn_cast<Constant>(PN->getIncomingValue(0));
23     if (Result == 0) return 0;
24
25     // Handle PHI nodes specially here...
26     for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
27       if (PN->getIncomingValue(i) != Result)
28         return 0;   // Not all the same incoming constants...
29
30     // If we reach here, all incoming values are the same constant.
31     return Result;
32   }
33
34   Constant *Op0 = 0;
35   Constant *Op1 = 0;
36
37   if (I->getNumOperands() != 0) {    // Get first operand if it's a constant...
38     Op0 = dyn_cast<Constant>(I->getOperand(0));
39     if (Op0 == 0) return 0;          // Not a constant?, can't fold
40
41     if (I->getNumOperands() != 1) {  // Get second operand if it's a constant...
42       Op1 = dyn_cast<Constant>(I->getOperand(1));
43       if (Op1 == 0) return 0;        // Not a constant?, can't fold
44     }
45   }
46
47   switch (I->getOpcode()) {
48   case Instruction::Cast:
49     return ConstRules::get(*Op0)->castTo(Op0, I->getType());
50   case Instruction::Not:     return ~*Op0;
51   case Instruction::Add:     return *Op0 + *Op1;
52   case Instruction::Sub:     return *Op0 - *Op1;
53   case Instruction::Mul:     return *Op0 * *Op1;
54   case Instruction::Div:     return *Op0 / *Op1;
55   case Instruction::Rem:     return *Op0 % *Op1;
56
57   case Instruction::SetEQ:   return *Op0 == *Op1;
58   case Instruction::SetNE:   return *Op0 != *Op1;
59   case Instruction::SetLE:   return *Op0 <= *Op1;
60   case Instruction::SetGE:   return *Op0 >= *Op1;
61   case Instruction::SetLT:   return *Op0 <  *Op1;
62   case Instruction::SetGT:   return *Op0 >  *Op1;
63   case Instruction::Shl:     return *Op0 << *Op1;
64   case Instruction::Shr:     return *Op0 >> *Op1;
65   default:
66     return 0;
67   }
68 }
69
70 Constant *ConstantFoldCastInstruction(const Constant *V, const Type *DestTy) {
71   return ConstRules::get(*V)->castTo(V, DestTy);
72 }
73
74 Constant *ConstantFoldUnaryInstruction(unsigned Opcode, const Constant *V) {
75   switch (Opcode) {
76   case Instruction::Not:  return ~*V;
77   }
78   return 0;
79 }
80
81 Constant *ConstantFoldBinaryInstruction(unsigned Opcode, const Constant *V1,
82                                         const Constant *V2) {
83   switch (Opcode) {
84   case Instruction::Add:     return *V1 + *V2;
85   case Instruction::Sub:     return *V1 - *V2;
86   case Instruction::Mul:     return *V1 * *V2;
87   case Instruction::Div:     return *V1 / *V2;
88   case Instruction::Rem:     return *V1 % *V2;
89
90   case Instruction::SetEQ:   return *V1 == *V2;
91   case Instruction::SetNE:   return *V1 != *V2;
92   case Instruction::SetLE:   return *V1 <= *V2;
93   case Instruction::SetGE:   return *V1 >= *V2;
94   case Instruction::SetLT:   return *V1 <  *V2;
95   case Instruction::SetGT:   return *V1 >  *V2;
96   }
97   return 0;
98 }
99
100 Constant *ConstantFoldShiftInstruction(unsigned Opcode, const Constant *V1, 
101                                        const Constant *V2) {
102   switch (Opcode) {
103   case Instruction::Shl:     return *V1 << *V2;
104   case Instruction::Shr:     return *V1 >> *V2;
105   default:                   return 0;
106   }
107 }
108
109
110 //===----------------------------------------------------------------------===//
111 //                             TemplateRules Class
112 //===----------------------------------------------------------------------===//
113 //
114 // TemplateRules - Implement a subclass of ConstRules that provides all 
115 // operations as noops.  All other rules classes inherit from this class so 
116 // that if functionality is needed in the future, it can simply be added here 
117 // and to ConstRules without changing anything else...
118 // 
119 // This class also provides subclasses with typesafe implementations of methods
120 // so that don't have to do type casting.
121 //
122 template<class ArgType, class SubClassName>
123 class TemplateRules : public ConstRules {
124
125   //===--------------------------------------------------------------------===//
126   // Redirecting functions that cast to the appropriate types
127   //===--------------------------------------------------------------------===//
128
129   virtual Constant *op_not(const Constant *V) const {
130     return SubClassName::Not((const ArgType *)V);
131   }
132
133   
134   virtual Constant *add(const Constant *V1, 
135                         const Constant *V2) const { 
136     return SubClassName::Add((const ArgType *)V1, (const ArgType *)V2);  
137   }
138
139   virtual Constant *sub(const Constant *V1, 
140                         const Constant *V2) const { 
141     return SubClassName::Sub((const ArgType *)V1, (const ArgType *)V2);  
142   }
143
144   virtual Constant *mul(const Constant *V1, 
145                         const Constant *V2) const { 
146     return SubClassName::Mul((const ArgType *)V1, (const ArgType *)V2);  
147   }
148   virtual Constant *div(const Constant *V1, 
149                         const Constant *V2) const { 
150     return SubClassName::Div((const ArgType *)V1, (const ArgType *)V2);  
151   }
152   virtual Constant *rem(const Constant *V1, 
153                         const Constant *V2) const { 
154     return SubClassName::Rem((const ArgType *)V1, (const ArgType *)V2);  
155   }
156   virtual Constant *shl(const Constant *V1, 
157                         const Constant *V2) const { 
158     return SubClassName::Shl((const ArgType *)V1, (const ArgType *)V2);  
159   }
160   virtual Constant *shr(const Constant *V1, 
161                         const Constant *V2) const { 
162     return SubClassName::Shr((const ArgType *)V1, (const ArgType *)V2);  
163   }
164
165   virtual ConstantBool *lessthan(const Constant *V1, 
166                                  const Constant *V2) const { 
167     return SubClassName::LessThan((const ArgType *)V1, (const ArgType *)V2);
168   }
169
170   // Casting operators.  ick
171   virtual ConstantBool *castToBool(const Constant *V) const {
172     return SubClassName::CastToBool((const ArgType*)V);
173   }
174   virtual ConstantSInt *castToSByte(const Constant *V) const {
175     return SubClassName::CastToSByte((const ArgType*)V);
176   }
177   virtual ConstantUInt *castToUByte(const Constant *V) const {
178     return SubClassName::CastToUByte((const ArgType*)V);
179   }
180   virtual ConstantSInt *castToShort(const Constant *V) const {
181     return SubClassName::CastToShort((const ArgType*)V);
182   }
183   virtual ConstantUInt *castToUShort(const Constant *V) const {
184     return SubClassName::CastToUShort((const ArgType*)V);
185   }
186   virtual ConstantSInt *castToInt(const Constant *V) const {
187     return SubClassName::CastToInt((const ArgType*)V);
188   }
189   virtual ConstantUInt *castToUInt(const Constant *V) const {
190     return SubClassName::CastToUInt((const ArgType*)V);
191   }
192   virtual ConstantSInt *castToLong(const Constant *V) const {
193     return SubClassName::CastToLong((const ArgType*)V);
194   }
195   virtual ConstantUInt *castToULong(const Constant *V) const {
196     return SubClassName::CastToULong((const ArgType*)V);
197   }
198   virtual ConstantFP   *castToFloat(const Constant *V) const {
199     return SubClassName::CastToFloat((const ArgType*)V);
200   }
201   virtual ConstantFP   *castToDouble(const Constant *V) const {
202     return SubClassName::CastToDouble((const ArgType*)V);
203   }
204   virtual ConstantPointer *castToPointer(const Constant *V, 
205                                          const PointerType *Ty) const {
206     return SubClassName::CastToPointer((const ArgType*)V, Ty);
207   }
208
209   //===--------------------------------------------------------------------===//
210   // Default "noop" implementations
211   //===--------------------------------------------------------------------===//
212
213   inline static Constant *Not(const ArgType *V) { return 0; }
214
215   inline static Constant *Add(const ArgType *V1, const ArgType *V2) {
216     return 0;
217   }
218   inline static Constant *Sub(const ArgType *V1, const ArgType *V2) {
219     return 0;
220   }
221   inline static Constant *Mul(const ArgType *V1, const ArgType *V2) {
222     return 0;
223   }
224   inline static Constant *Div(const ArgType *V1, const ArgType *V2) {
225     return 0;
226   }
227   inline static Constant *Rem(const ArgType *V1, const ArgType *V2) {
228     return 0;
229   }
230   inline static Constant *Shl(const ArgType *V1, const ArgType *V2) {
231     return 0;
232   }
233   inline static Constant *Shr(const ArgType *V1, const ArgType *V2) {
234     return 0;
235   }
236   inline static ConstantBool *LessThan(const ArgType *V1, const ArgType *V2) {
237     return 0;
238   }
239
240   // Casting operators.  ick
241   inline static ConstantBool *CastToBool  (const Constant *V) { return 0; }
242   inline static ConstantSInt *CastToSByte (const Constant *V) { return 0; }
243   inline static ConstantUInt *CastToUByte (const Constant *V) { return 0; }
244   inline static ConstantSInt *CastToShort (const Constant *V) { return 0; }
245   inline static ConstantUInt *CastToUShort(const Constant *V) { return 0; }
246   inline static ConstantSInt *CastToInt   (const Constant *V) { return 0; }
247   inline static ConstantUInt *CastToUInt  (const Constant *V) { return 0; }
248   inline static ConstantSInt *CastToLong  (const Constant *V) { return 0; }
249   inline static ConstantUInt *CastToULong (const Constant *V) { return 0; }
250   inline static ConstantFP   *CastToFloat (const Constant *V) { return 0; }
251   inline static ConstantFP   *CastToDouble(const Constant *V) { return 0; }
252   inline static ConstantPointer *CastToPointer(const Constant *,
253                                                const PointerType *) {return 0;}
254 };
255
256
257
258 //===----------------------------------------------------------------------===//
259 //                             EmptyRules Class
260 //===----------------------------------------------------------------------===//
261 //
262 // EmptyRules provides a concrete base class of ConstRules that does nothing
263 //
264 struct EmptyRules : public TemplateRules<Constant, EmptyRules> {
265 };
266
267
268
269 //===----------------------------------------------------------------------===//
270 //                              BoolRules Class
271 //===----------------------------------------------------------------------===//
272 //
273 // BoolRules provides a concrete base class of ConstRules for the 'bool' type.
274 //
275 struct BoolRules : public TemplateRules<ConstantBool, BoolRules> {
276
277   inline static Constant *Not(const ConstantBool *V) { 
278     return ConstantBool::get(!V->getValue());
279   }
280
281   inline static Constant *Or(const ConstantBool *V1,
282                              const ConstantBool *V2) {
283     return ConstantBool::get(V1->getValue() | V2->getValue());
284   }
285
286   inline static Constant *And(const ConstantBool *V1, 
287                               const ConstantBool *V2) {
288     return ConstantBool::get(V1->getValue() & V2->getValue());
289   }
290 };
291
292
293 //===----------------------------------------------------------------------===//
294 //                            PointerRules Class
295 //===----------------------------------------------------------------------===//
296 //
297 // PointerRules provides a concrete base class of ConstRules for pointer types
298 //
299 struct PointerRules : public TemplateRules<ConstantPointer, PointerRules> {
300   inline static ConstantBool *CastToBool  (const Constant *V) {
301     if (V->isNullValue()) return ConstantBool::False;
302     return 0;  // Can't const prop other types of pointers
303   }
304   inline static ConstantSInt *CastToSByte (const Constant *V) {
305     if (V->isNullValue()) return ConstantSInt::get(Type::SByteTy, 0);
306     return 0;  // Can't const prop other types of pointers
307   }
308   inline static ConstantUInt *CastToUByte (const Constant *V) {
309     if (V->isNullValue()) return ConstantUInt::get(Type::UByteTy, 0);
310     return 0;  // Can't const prop other types of pointers
311   }
312   inline static ConstantSInt *CastToShort (const Constant *V) {
313     if (V->isNullValue()) return ConstantSInt::get(Type::ShortTy, 0);
314     return 0;  // Can't const prop other types of pointers
315   }
316   inline static ConstantUInt *CastToUShort(const Constant *V) {
317     if (V->isNullValue()) return ConstantUInt::get(Type::UShortTy, 0);
318     return 0;  // Can't const prop other types of pointers
319   }
320   inline static ConstantSInt *CastToInt   (const Constant *V) {
321     if (V->isNullValue()) return ConstantSInt::get(Type::IntTy, 0);
322     return 0;  // Can't const prop other types of pointers
323   }
324   inline static ConstantUInt *CastToUInt  (const Constant *V) {
325     if (V->isNullValue()) return ConstantUInt::get(Type::UIntTy, 0);
326     return 0;  // Can't const prop other types of pointers
327   }
328   inline static ConstantSInt *CastToLong  (const Constant *V) {
329     if (V->isNullValue()) return ConstantSInt::get(Type::LongTy, 0);
330     return 0;  // Can't const prop other types of pointers
331   }
332   inline static ConstantUInt *CastToULong (const Constant *V) {
333     if (V->isNullValue()) return ConstantUInt::get(Type::ULongTy, 0);
334     return 0;  // Can't const prop other types of pointers
335   }
336   inline static ConstantFP   *CastToFloat (const Constant *V) {
337     if (V->isNullValue()) return ConstantFP::get(Type::FloatTy, 0);
338     return 0;  // Can't const prop other types of pointers
339   }
340   inline static ConstantFP   *CastToDouble(const Constant *V) {
341     if (V->isNullValue()) return ConstantFP::get(Type::DoubleTy, 0);
342     return 0;  // Can't const prop other types of pointers
343   }
344
345   inline static ConstantPointer *CastToPointer(const ConstantPointer *V,
346                                                const PointerType *PTy) {
347     if (V->getType() == PTy)
348       return const_cast<ConstantPointer*>(V);  // Allow cast %PTy %ptr to %PTy
349     if (V->isNullValue())
350       return ConstantPointerNull::get(PTy);
351     return 0;  // Can't const prop other types of pointers
352   }
353 };
354
355
356 //===----------------------------------------------------------------------===//
357 //                             DirectRules Class
358 //===----------------------------------------------------------------------===//
359 //
360 // DirectRules provides a concrete base classes of ConstRules for a variety of
361 // different types.  This allows the C++ compiler to automatically generate our
362 // constant handling operations in a typesafe and accurate manner.
363 //
364 template<class ConstantClass, class BuiltinType, Type **Ty, class SuperClass>
365 struct DirectRules : public TemplateRules<ConstantClass, SuperClass> {
366   inline static Constant *Add(const ConstantClass *V1, 
367                               const ConstantClass *V2) {
368     BuiltinType Result = (BuiltinType)V1->getValue() + 
369                          (BuiltinType)V2->getValue();
370     return ConstantClass::get(*Ty, Result);
371   }
372
373   inline static Constant *Sub(const ConstantClass *V1, 
374                               const ConstantClass *V2) {
375     BuiltinType Result = (BuiltinType)V1->getValue() -
376                          (BuiltinType)V2->getValue();
377     return ConstantClass::get(*Ty, Result);
378   }
379
380   inline static Constant *Mul(const ConstantClass *V1, 
381                               const ConstantClass *V2) {
382     BuiltinType Result = (BuiltinType)V1->getValue() *
383                          (BuiltinType)V2->getValue();
384     return ConstantClass::get(*Ty, Result);
385   }
386
387   inline static Constant *Div(const ConstantClass *V1,
388                               const ConstantClass *V2) {
389     if (V2->isNullValue()) return 0;
390     BuiltinType Result = (BuiltinType)V1->getValue() /
391                          (BuiltinType)V2->getValue();
392     return ConstantClass::get(*Ty, Result);
393   }
394
395   inline static ConstantBool *LessThan(const ConstantClass *V1, 
396                                        const ConstantClass *V2) {
397     bool Result = (BuiltinType)V1->getValue() < (BuiltinType)V2->getValue();
398     return ConstantBool::get(Result);
399   } 
400
401   inline static ConstantPointer *CastToPointer(const ConstantClass *V,
402                                                const PointerType *PTy) {
403     if (V->isNullValue())    // Is it a FP or Integral null value?
404       return ConstantPointerNull::get(PTy);
405     return 0;  // Can't const prop other types of pointers
406   }
407
408   // Casting operators.  ick
409 #define DEF_CAST(TYPE, CLASS, CTYPE) \
410   inline static CLASS *CastTo##TYPE  (const ConstantClass *V) {    \
411     return CLASS::get(Type::TYPE##Ty, (CTYPE)(BuiltinType)V->getValue()); \
412   }
413
414   DEF_CAST(Bool  , ConstantBool, bool)
415   DEF_CAST(SByte , ConstantSInt, signed char)
416   DEF_CAST(UByte , ConstantUInt, unsigned char)
417   DEF_CAST(Short , ConstantSInt, signed short)
418   DEF_CAST(UShort, ConstantUInt, unsigned short)
419   DEF_CAST(Int   , ConstantSInt, signed int)
420   DEF_CAST(UInt  , ConstantUInt, unsigned int)
421   DEF_CAST(Long  , ConstantSInt, int64_t)
422   DEF_CAST(ULong , ConstantUInt, uint64_t)
423   DEF_CAST(Float , ConstantFP  , float)
424   DEF_CAST(Double, ConstantFP  , double)
425 #undef DEF_CAST
426 };
427
428
429 //===----------------------------------------------------------------------===//
430 //                           DirectIntRules Class
431 //===----------------------------------------------------------------------===//
432 //
433 // DirectIntRules provides implementations of functions that are valid on
434 // integer types, but not all types in general.
435 //
436 template <class ConstantClass, class BuiltinType, Type **Ty>
437 struct DirectIntRules
438   : public DirectRules<ConstantClass, BuiltinType, Ty,
439                        DirectIntRules<ConstantClass, BuiltinType, Ty> > {
440   inline static Constant *Not(const ConstantClass *V) { 
441     return ConstantClass::get(*Ty, ~(BuiltinType)V->getValue());;
442   }
443
444   inline static Constant *Rem(const ConstantClass *V1,
445                               const ConstantClass *V2) {
446     if (V2->isNullValue()) return 0;
447     BuiltinType Result = (BuiltinType)V1->getValue() %
448                          (BuiltinType)V2->getValue();
449     return ConstantClass::get(*Ty, Result);
450   }
451
452   inline static Constant *Shl(const ConstantClass *V1,
453                               const ConstantClass *V2) {
454     BuiltinType Result = (BuiltinType)V1->getValue() <<
455                          (BuiltinType)V2->getValue();
456     return ConstantClass::get(*Ty, Result);
457   }
458
459   inline static Constant *Shr(const ConstantClass *V1,
460                               const ConstantClass *V2) {
461     BuiltinType Result = (BuiltinType)V1->getValue() >>
462                          (BuiltinType)V2->getValue();
463     return ConstantClass::get(*Ty, Result);
464   }
465 };
466
467
468 //===----------------------------------------------------------------------===//
469 //                           DirectFPRules Class
470 //===----------------------------------------------------------------------===//
471 //
472 // DirectFPRules provides implementations of functions that are valid on
473 // floating point types, but not all types in general.
474 //
475 template <class ConstantClass, class BuiltinType, Type **Ty>
476 struct DirectFPRules
477   : public DirectRules<ConstantClass, BuiltinType, Ty,
478                        DirectFPRules<ConstantClass, BuiltinType, Ty> > {
479   inline static Constant *Rem(const ConstantClass *V1,
480                               const ConstantClass *V2) {
481     if (V2->isNullValue()) return 0;
482     BuiltinType Result = std::fmod((BuiltinType)V1->getValue(),
483                                    (BuiltinType)V2->getValue());
484     return ConstantClass::get(*Ty, Result);
485   }
486 };
487
488
489 //===----------------------------------------------------------------------===//
490 //                            DirectRules Subclasses
491 //===----------------------------------------------------------------------===//
492 //
493 // Given the DirectRules class we can now implement lots of types with little
494 // code.  Thank goodness C++ compilers are great at stomping out layers of 
495 // templates... can you imagine having to do this all by hand? (/me is lazy :)
496 //
497
498 // ConstRules::find - Return the constant rules that take care of the specified
499 // type.
500 //
501 Annotation *ConstRules::find(AnnotationID AID, const Annotable *TyA, void *) {
502   assert(AID == ConstRules::AID && "Bad annotation for factory!");
503   const Type *Ty = cast<Type>((const Value*)TyA);
504   
505   switch (Ty->getPrimitiveID()) {
506   case Type::BoolTyID:    return new BoolRules();
507   case Type::PointerTyID: return new PointerRules();
508   case Type::SByteTyID:
509     return new DirectIntRules<ConstantSInt,   signed char , &Type::SByteTy>();
510   case Type::UByteTyID:
511     return new DirectIntRules<ConstantUInt, unsigned char , &Type::UByteTy>();
512   case Type::ShortTyID:
513     return new DirectIntRules<ConstantSInt,   signed short, &Type::ShortTy>();
514   case Type::UShortTyID:
515     return new DirectIntRules<ConstantUInt, unsigned short, &Type::UShortTy>();
516   case Type::IntTyID:
517     return new DirectIntRules<ConstantSInt,   signed int  , &Type::IntTy>();
518   case Type::UIntTyID:
519     return new DirectIntRules<ConstantUInt, unsigned int  , &Type::UIntTy>();
520   case Type::LongTyID:
521     return new DirectIntRules<ConstantSInt,  int64_t      , &Type::LongTy>();
522   case Type::ULongTyID:
523     return new DirectIntRules<ConstantUInt, uint64_t      , &Type::ULongTy>();
524   case Type::FloatTyID:
525     return new DirectFPRules<ConstantFP  , float         , &Type::FloatTy>();
526   case Type::DoubleTyID:
527     return new DirectFPRules<ConstantFP  , double        , &Type::DoubleTy>();
528   default:
529     return new EmptyRules();
530   }
531 }