This analysis doesn't take 'throwing' into consideration, it looks at
[oota-llvm.git] / lib / Transforms / IPO / SimplifyLibCalls.cpp
1 //===- SimplifyLibCalls.cpp - Optimize specific well-known library calls --===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Reid Spencer and is distributed under the 
6 // University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a variety of small optimizations for calls to specific
11 // well-known (e.g. runtime library) function calls. For example, a call to the
12 // function "exit(3)" that occurs within the main() function can be transformed
13 // into a simple "return 3" instruction. Any optimization that takes this form
14 // (replace call to library function with simpler code that provides same 
15 // result) belongs in this file. 
16 //
17 //===----------------------------------------------------------------------===//
18
19 #define DEBUG_TYPE "simplify-libcalls"
20 #include "llvm/Constants.h"
21 #include "llvm/DerivedTypes.h"
22 #include "llvm/Instructions.h"
23 #include "llvm/Module.h"
24 #include "llvm/Pass.h"
25 #include "llvm/ADT/hash_map"
26 #include "llvm/ADT/Statistic.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Target/TargetData.h"
29 #include "llvm/Transforms/IPO.h"
30 #include <iostream>
31 using namespace llvm;
32
33 namespace {
34   Statistic<> SimplifiedLibCalls("simplified-lib-calls", 
35       "Number of well-known library calls simplified");
36
37   /// This class is the base class for a set of small but important 
38   /// optimizations of calls to well-known functions, such as those in the c
39   /// library. This class provides the basic infrastructure for handling 
40   /// runOnModule. Subclasses register themselves and provide two methods:
41   /// RecognizeCall and OptimizeCall. Whenever this class finds a function call,
42   /// it asks the subclasses to recognize the call. If it is recognized, then
43   /// the OptimizeCall method is called on that subclass instance. In this way
44   /// the subclasses implement the calling conditions on which they trigger and
45   /// the action to perform, making it easy to add new optimizations of this
46   /// form.
47   /// @brief A ModulePass for optimizing well-known function calls
48   struct SimplifyLibCalls : public ModulePass {
49
50     /// We need some target data for accurate signature details that are
51     /// target dependent. So we require target data in our AnalysisUsage.
52     virtual void getAnalysisUsage(AnalysisUsage& Info) const;
53
54     /// For this pass, process all of the function calls in the module, calling
55     /// RecognizeCall and OptimizeCall as appropriate.
56     virtual bool runOnModule(Module &M);
57
58   };
59
60   RegisterOpt<SimplifyLibCalls> 
61     X("simplify-libcalls","Simplify well-known library calls");
62
63   struct CallOptimizer
64   {
65     /// @brief Constructor that registers the optimization
66     CallOptimizer(const char * fname );
67
68     virtual ~CallOptimizer();
69
70     /// The implementation of this function in subclasses should determine if
71     /// \p F is suitable for the optimization. This method is called by 
72     /// runOnModule to short circuit visiting all the call sites of such a
73     /// function if that function is not suitable in the first place.
74     /// If the called function is suitabe, this method should return true;
75     /// false, otherwise. This function should also perform any lazy 
76     /// initialization that the CallOptimizer needs to do, if its to return 
77     /// true. This avoids doing initialization until the optimizer is actually
78     /// going to be called upon to do some optimization.
79     virtual bool ValidateCalledFunction(
80       const Function* F,   ///< The function that is the target of call sites
81       const TargetData& TD ///< Information about the target
82     ) = 0;
83
84     /// The implementations of this function in subclasses is the heart of the 
85     /// SimplifyLibCalls algorithm. Sublcasses of this class implement 
86     /// OptimizeCall to determine if (a) the conditions are right for optimizing
87     /// the call and (b) to perform the optimization. If an action is taken 
88     /// against ci, the subclass is responsible for returning true and ensuring
89     /// that ci is erased from its parent.
90     /// @param ci the call instruction under consideration
91     /// @param f the function that ci calls.
92     /// @brief Optimize a call, if possible.
93     virtual bool OptimizeCall(
94       CallInst* ci,         ///< The call instruction that should be optimized.
95       const TargetData& TD  ///< Information about the target
96     ) = 0;
97
98     const char * getFunctionName() const { return func_name; }
99   private:
100     const char* func_name;
101   };
102
103   /// @brief The list of optimizations deriving from CallOptimizer
104
105   hash_map<std::string,CallOptimizer*> optlist;
106
107   CallOptimizer::CallOptimizer(const char* fname)
108     : func_name(fname)
109   {
110     // Register this call optimizer
111     optlist[func_name] = this;
112   }
113
114   /// Make sure we get our virtual table in this file.
115   CallOptimizer::~CallOptimizer() { }
116
117 }
118
119 ModulePass *llvm::createSimplifyLibCallsPass() 
120
121   return new SimplifyLibCalls(); 
122 }
123
124 void SimplifyLibCalls::getAnalysisUsage(AnalysisUsage& Info) const
125 {
126   // Ask that the TargetData analysis be performed before us so we can use
127   // the target data.
128   Info.addRequired<TargetData>();
129 }
130
131 bool SimplifyLibCalls::runOnModule(Module &M) 
132 {
133   TargetData& TD = getAnalysis<TargetData>();
134
135   bool result = false;
136
137   // The call optimizations can be recursive. That is, the optimization might
138   // generate a call to another function which can also be optimized. This way
139   // we make the CallOptimizer instances very specific to the case they handle.
140   // It also means we need to keep running over the function calls in the module
141   // until we don't get any more optimizations possible.
142   bool found_optimization = false;
143   do
144   {
145     found_optimization = false;
146     for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI)
147     {
148       // All the "well-known" functions are external and have external linkage
149       // because they live in a runtime library somewhere and were (probably) 
150       // not compiled by LLVM.  So, we only act on external functions that have 
151       // external linkage and non-empty uses.
152       if (FI->isExternal() && FI->hasExternalLinkage() && !FI->use_empty())
153       {
154         // Get the optimization class that pertains to this function
155         if (CallOptimizer* CO = optlist[FI->getName().c_str()] )
156         {
157           // Make sure the called function is suitable for the optimization
158           if (CO->ValidateCalledFunction(FI,TD))
159           {
160             // Loop over each of the uses of the function
161             for (Value::use_iterator UI = FI->use_begin(), UE = FI->use_end(); 
162                  UI != UE ; )
163             {
164               // If the use of the function is a call instruction
165               if (CallInst* CI = dyn_cast<CallInst>(*UI++))
166               {
167                 // Do the optimization on the CallOptimizer.
168                 if (CO->OptimizeCall(CI,TD))
169                 {
170                   ++SimplifiedLibCalls;
171                   found_optimization = result = true;
172                   DEBUG(std::cerr << "simplify-libcall: " << CO->getFunctionName() << "\n");
173                 }
174               }
175             }
176           }
177         }
178       }
179     }
180   } while (found_optimization);
181   return result;
182 }
183
184 namespace {
185
186   /// Provide some functions for accessing standard library prototypes and
187   /// caching them so we don't have to keep recomputing them
188   FunctionType* get_strlen(const Type* IntPtrTy)
189   {
190     static FunctionType* strlen_type = 0;
191     if (!strlen_type)
192     {
193       std::vector<const Type*> args;
194       args.push_back(PointerType::get(Type::SByteTy));
195       strlen_type = FunctionType::get(IntPtrTy, args, false);
196     }
197     return strlen_type;
198   }
199
200   FunctionType* get_memcpy()
201   {
202     static FunctionType* memcpy_type = 0;
203     if (!memcpy_type)
204     {
205       // Note: this is for llvm.memcpy intrinsic
206       std::vector<const Type*> args;
207       args.push_back(PointerType::get(Type::SByteTy));
208       args.push_back(PointerType::get(Type::SByteTy));
209       args.push_back(Type::IntTy);
210       args.push_back(Type::IntTy);
211       memcpy_type = FunctionType::get(Type::VoidTy, args, false);
212     }
213     return memcpy_type;
214   }
215
216   /// A function to compute the length of a null-terminated string of integers.
217   /// This function can't rely on the size of the constant array because there 
218   /// could be a null terminator in the middle of the array. We also have to 
219   /// bail out if we find a non-integer constant initializer of one of the 
220   /// elements or if there is no null-terminator. The logic below checks
221   bool getConstantStringLength(Value* V, uint64_t& len )
222   {
223     assert(V != 0 && "Invalid args to getConstantStringLength");
224     len = 0; // make sure we initialize this 
225     User* GEP = 0;
226     // If the value is not a GEP instruction nor a constant expression with a 
227     // GEP instruction, then return false because ConstantArray can't occur 
228     // any other way
229     if (GetElementPtrInst* GEPI = dyn_cast<GetElementPtrInst>(V))
230       GEP = GEPI;
231     else if (ConstantExpr* CE = dyn_cast<ConstantExpr>(V))
232       if (CE->getOpcode() == Instruction::GetElementPtr)
233         GEP = CE;
234       else
235         return false;
236     else
237       return false;
238
239     // Make sure the GEP has exactly three arguments.
240     if (GEP->getNumOperands() != 3)
241       return false;
242
243     // Check to make sure that the first operand of the GEP is an integer and
244     // has value 0 so that we are sure we're indexing into the initializer. 
245     if (ConstantInt* op1 = dyn_cast<ConstantInt>(GEP->getOperand(1)))
246     {
247       if (!op1->isNullValue())
248         return false;
249     }
250     else
251       return false;
252
253     // Ensure that the second operand is a ConstantInt. If it isn't then this
254     // GEP is wonky and we're not really sure what were referencing into and 
255     // better of not optimizing it. While we're at it, get the second index
256     // value. We'll need this later for indexing the ConstantArray.
257     uint64_t start_idx = 0;
258     if (ConstantInt* CI = dyn_cast<ConstantInt>(GEP->getOperand(2)))
259       start_idx = CI->getRawValue();
260     else
261       return false;
262
263     // The GEP instruction, constant or instruction, must reference a global
264     // variable that is a constant and is initialized. The referenced constant
265     // initializer is the array that we'll use for optimization.
266     GlobalVariable* GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
267     if (!GV || !GV->isConstant() || !GV->hasInitializer())
268       return false;
269
270     // Get the initializer.
271     Constant* INTLZR = GV->getInitializer();
272
273     // Handle the ConstantAggregateZero case
274     if (ConstantAggregateZero* CAZ = dyn_cast<ConstantAggregateZero>(INTLZR))
275     {
276       // This is a degenerate case. The initializer is constant zero so the
277       // length of the string must be zero.
278       len = 0;
279       return true;
280     }
281
282     // Must be a Constant Array
283     ConstantArray* A = dyn_cast<ConstantArray>(INTLZR);
284     if (!A)
285       return false;
286
287     // Get the number of elements in the array
288     uint64_t max_elems = A->getType()->getNumElements();
289
290     // Traverse the constant array from start_idx (derived above) which is
291     // the place the GEP refers to in the array. 
292     for ( len = start_idx; len < max_elems; len++)
293     {
294       if (ConstantInt* CI = dyn_cast<ConstantInt>(A->getOperand(len)))
295       {
296         // Check for the null terminator
297         if (CI->isNullValue())
298           break; // we found end of string
299       }
300       else
301         return false; // This array isn't suitable, non-int initializer
302     }
303     if (len >= max_elems)
304       return false; // This array isn't null terminated
305
306     // Subtract out the initial value from the length
307     len -= start_idx;
308     return true; // success!
309   }
310
311 /// This CallOptimizer will find instances of a call to "exit" that occurs
312 /// within the "main" function and change it to a simple "ret" instruction with
313 /// the same value as passed to the exit function. It assumes that the 
314 /// instructions after the call to exit(3) can be deleted since they are 
315 /// unreachable anyway.
316 /// @brief Replace calls to exit in main with a simple return
317 struct ExitInMainOptimization : public CallOptimizer
318 {
319   ExitInMainOptimization() : CallOptimizer("exit") {}
320   virtual ~ExitInMainOptimization() {}
321
322   // Make sure the called function looks like exit (int argument, int return
323   // type, external linkage, not varargs). 
324   virtual bool ValidateCalledFunction(const Function* f, const TargetData& TD)
325   {
326     if (f->arg_size() >= 1)
327       if (f->arg_begin()->getType()->isInteger())
328         return true;
329     return false;
330   }
331
332   virtual bool OptimizeCall(CallInst* ci, const TargetData& TD)
333   {
334     // To be careful, we check that the call to exit is coming from "main", that
335     // main has external linkage, and the return type of main and the argument
336     // to exit have the same type. 
337     Function *from = ci->getParent()->getParent();
338     if (from->hasExternalLinkage())
339       if (from->getReturnType() == ci->getOperand(1)->getType())
340         if (from->getName() == "main")
341         {
342           // Okay, time to actually do the optimization. First, get the basic 
343           // block of the call instruction
344           BasicBlock* bb = ci->getParent();
345
346           // Create a return instruction that we'll replace the call with. 
347           // Note that the argument of the return is the argument of the call 
348           // instruction.
349           ReturnInst* ri = new ReturnInst(ci->getOperand(1), ci);
350
351           // Split the block at the call instruction which places it in a new
352           // basic block.
353           bb->splitBasicBlock(ci);
354
355           // The block split caused a branch instruction to be inserted into
356           // the end of the original block, right after the return instruction
357           // that we put there. That's not a valid block, so delete the branch
358           // instruction.
359           bb->getInstList().pop_back();
360
361           // Now we can finally get rid of the call instruction which now lives
362           // in the new basic block.
363           ci->eraseFromParent();
364
365           // Optimization succeeded, return true.
366           return true;
367         }
368     // We didn't pass the criteria for this optimization so return false
369     return false;
370   }
371 } ExitInMainOptimizer;
372
373 /// This CallOptimizer will simplify a call to the strcat library function. The
374 /// simplification is possible only if the string being concatenated is a 
375 /// constant array or a constant expression that results in a constant array. In
376 /// this case, if the array is small, we can generate a series of inline store
377 /// instructions to effect the concatenation without calling strcat.
378 /// @brief Simplify the strcat library function.
379 struct StrCatOptimization : public CallOptimizer
380 {
381 private:
382   Function* strlen_func;
383   Function* memcpy_func;
384 public:
385   StrCatOptimization() 
386     : CallOptimizer("strcat") 
387     , strlen_func(0)
388     , memcpy_func(0)
389     {}
390   virtual ~StrCatOptimization() {}
391
392   inline Function* get_strlen_func(Module*M,const Type* IntPtrTy)
393   {
394     if (strlen_func)
395       return strlen_func;
396     return strlen_func = M->getOrInsertFunction("strlen",get_strlen(IntPtrTy));
397   }
398
399   inline Function* get_memcpy_func(Module* M) 
400   {
401     if (memcpy_func)
402       return memcpy_func;
403     return memcpy_func = M->getOrInsertFunction("llvm.memcpy",get_memcpy());
404   }
405
406   /// @brief Make sure that the "strcat" function has the right prototype
407   virtual bool ValidateCalledFunction(const Function* f, const TargetData& TD) 
408   {
409     if (f->getReturnType() == PointerType::get(Type::SByteTy))
410       if (f->arg_size() == 2) 
411       {
412         Function::const_arg_iterator AI = f->arg_begin();
413         if (AI++->getType() == PointerType::get(Type::SByteTy))
414           if (AI->getType() == PointerType::get(Type::SByteTy))
415           {
416             // Invalidate the pre-computed strlen_func and memcpy_func Functions
417             // because, by definition, this method is only called when a new
418             // Module is being traversed. Invalidation causes re-computation for
419             // the new Module (if necessary).
420             strlen_func = 0;
421             memcpy_func = 0;
422
423             // Indicate this is a suitable call type.
424             return true;
425           }
426       }
427     return false;
428   }
429
430   /// Perform the optimization if the length of the string concatenated
431   /// is reasonably short and it is a constant array.
432   virtual bool OptimizeCall(CallInst* ci, const TargetData& TD)
433   {
434     // Extract the initializer (while making numerous checks) from the 
435     // source operand of the call to strcat. If we get null back, one of
436     // a variety of checks in get_GVInitializer failed
437     uint64_t len = 0;
438     if (!getConstantStringLength(ci->getOperand(2),len))
439       return false;
440
441     // Handle the simple, do-nothing case
442     if (len == 0)
443     {
444       ci->replaceAllUsesWith(ci->getOperand(1));
445       ci->eraseFromParent();
446       return true;
447     }
448
449     // Increment the length because we actually want to memcpy the null
450     // terminator as well.
451     len++;
452
453     // Extract some information from the instruction
454     Module* M = ci->getParent()->getParent()->getParent();
455
456     // We need to find the end of the destination string.  That's where the 
457     // memory is to be moved to. We just generate a call to strlen (further 
458     // optimized in another pass). Note that the get_strlen_func() call 
459     // caches the Function* for us.
460     CallInst* strlen_inst = 
461       new CallInst(get_strlen_func(M,TD.getIntPtrType()),
462                    ci->getOperand(1),"",ci);
463
464     // Now that we have the destination's length, we must index into the 
465     // destination's pointer to get the actual memcpy destination (end of
466     // the string .. we're concatenating).
467     std::vector<Value*> idx;
468     idx.push_back(strlen_inst);
469     GetElementPtrInst* gep = 
470       new GetElementPtrInst(ci->getOperand(1),idx,"",ci);
471
472     // We have enough information to now generate the memcpy call to
473     // do the concatenation for us.
474     std::vector<Value*> vals;
475     vals.push_back(gep); // destination
476     vals.push_back(ci->getOperand(2)); // source
477     vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length
478     vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment
479     CallInst* memcpy_inst = new CallInst(get_memcpy_func(M), vals, "", ci);
480
481     // Finally, substitute the first operand of the strcat call for the 
482     // strcat call itself since strcat returns its first operand; and, 
483     // kill the strcat CallInst.
484     ci->replaceAllUsesWith(ci->getOperand(1));
485     ci->eraseFromParent();
486     return true;
487   }
488 } StrCatOptimizer;
489
490 /// This CallOptimizer will simplify a call to the strlen library function by
491 /// replacing it with a constant value if the string provided to it is a 
492 /// constant array.
493 /// @brief Simplify the strlen library function.
494 struct StrLenOptimization : public CallOptimizer
495 {
496   StrLenOptimization() : CallOptimizer("strlen") {}
497   virtual ~StrLenOptimization() {}
498
499   /// @brief Make sure that the "strlen" function has the right prototype
500   virtual bool ValidateCalledFunction(const Function* f, const TargetData& TD)
501   {
502     if (f->getReturnType() == TD.getIntPtrType())
503       if (f->arg_size() == 1) 
504         if (Function::const_arg_iterator AI = f->arg_begin())
505           if (AI->getType() == PointerType::get(Type::SByteTy))
506             return true;
507     return false;
508   }
509
510   /// @brief Perform the strlen optimization
511   virtual bool OptimizeCall(CallInst* ci, const TargetData& TD)
512   {
513     // Get the length of the string
514     uint64_t len = 0;
515     if (!getConstantStringLength(ci->getOperand(1),len))
516       return false;
517
518     ci->replaceAllUsesWith(ConstantInt::get(TD.getIntPtrType(),len));
519     ci->eraseFromParent();
520     return true;
521   }
522 } StrLenOptimizer;
523
524 /// This CallOptimizer will simplify a call to the memcpy library function by
525 /// expanding it out to a small set of stores if the copy source is a constant
526 /// array. 
527 /// @brief Simplify the memcpy library function.
528 struct MemCpyOptimization : public CallOptimizer
529 {
530   MemCpyOptimization() : CallOptimizer("llvm.memcpy") {}
531 protected:
532   MemCpyOptimization(const char* fname) : CallOptimizer(fname) {}
533 public:
534   virtual ~MemCpyOptimization() {}
535
536   /// @brief Make sure that the "memcpy" function has the right prototype
537   virtual bool ValidateCalledFunction(const Function* f, const TargetData& TD)
538   {
539     // Just make sure this has 4 arguments per LLVM spec.
540     return (f->arg_size() == 4);
541   }
542
543   /// Because of alignment and instruction information that we don't have, we
544   /// leave the bulk of this to the code generators. The optimization here just
545   /// deals with a few degenerate cases where the length of the string and the
546   /// alignment match the sizes of our intrinsic types so we can do a load and
547   /// store instead of the memcpy call.
548   /// @brief Perform the memcpy optimization.
549   virtual bool OptimizeCall(CallInst* ci, const TargetData& TD)
550   {
551     // Make sure we have constant int values to work with
552     ConstantInt* LEN = dyn_cast<ConstantInt>(ci->getOperand(3));
553     if (!LEN)
554       return false;
555     ConstantInt* ALIGN = dyn_cast<ConstantInt>(ci->getOperand(4));
556     if (!ALIGN)
557       return false;
558
559     // If the length is larger than the alignment, we can't optimize
560     uint64_t len = LEN->getRawValue();
561     uint64_t alignment = ALIGN->getRawValue();
562     if (len > alignment)
563       return false;
564
565     Value* dest = ci->getOperand(1);
566     Value* src = ci->getOperand(2);
567     CastInst* SrcCast = 0;
568     CastInst* DestCast = 0;
569     switch (len)
570     {
571       case 0:
572         // The memcpy is a no-op so just dump its call.
573         ci->eraseFromParent();
574         return true;
575       case 1:
576         SrcCast = new CastInst(src,PointerType::get(Type::SByteTy),"",ci);
577         DestCast = new CastInst(dest,PointerType::get(Type::SByteTy),"",ci);
578         break;
579       case 2:
580         SrcCast = new CastInst(src,PointerType::get(Type::ShortTy),"",ci);
581         DestCast = new CastInst(dest,PointerType::get(Type::ShortTy),"",ci);
582         break;
583       case 4:
584         SrcCast = new CastInst(src,PointerType::get(Type::IntTy),"",ci);
585         DestCast = new CastInst(dest,PointerType::get(Type::IntTy),"",ci);
586         break;
587       case 8:
588         SrcCast = new CastInst(src,PointerType::get(Type::LongTy),"",ci);
589         DestCast = new CastInst(dest,PointerType::get(Type::LongTy),"",ci);
590         break;
591       default:
592         return false;
593     }
594     LoadInst* LI = new LoadInst(SrcCast,"",ci);
595     StoreInst* SI = new StoreInst(LI, DestCast, ci);
596     ci->eraseFromParent();
597     return true;
598   }
599 } MemCpyOptimizer;
600
601 /// This CallOptimizer will simplify a call to the memmove library function. It
602 /// is identical to MemCopyOptimization except for the name of the intrinsic.
603 /// @brief Simplify the memmove library function.
604 struct MemMoveOptimization : public MemCpyOptimization
605 {
606   MemMoveOptimization() : MemCpyOptimization("llvm.memmove") {}
607
608 } MemMoveOptimizer;
609
610 }