* Merge get_GVInitializer and getCharArrayLength into a single function
[oota-llvm.git] / lib / Transforms / IPO / SimplifyLibCalls.cpp
1 //===- SimplifyLibCalls.cpp - Optimize specific well-known library calls --===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Reid Spencer and is distributed under the 
6 // University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a variety of small optimizations for calls to specific
11 // well-known (e.g. runtime library) function calls. For example, a call to the
12 // function "exit(3)" that occurs within the main() function can be transformed
13 // into a simple "return 3" instruction. Any optimization that takes this form
14 // (replace call to library function with simpler code that provides same 
15 // result) belongs in this file. 
16 //
17 //===----------------------------------------------------------------------===//
18
19 #include "llvm/Transforms/IPO.h"
20 #include "llvm/Module.h"
21 #include "llvm/Pass.h"
22 #include "llvm/DerivedTypes.h"
23 #include "llvm/Constants.h"
24 #include "llvm/Instructions.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/ADT/hash_map"
27 #include <iostream>
28 using namespace llvm;
29
30 namespace {
31   Statistic<> SimplifiedLibCalls("simplified-lib-calls", 
32       "Number of well-known library calls simplified");
33
34   /// This class is the base class for a set of small but important 
35   /// optimizations of calls to well-known functions, such as those in the c
36   /// library. This class provides the basic infrastructure for handling 
37   /// runOnModule. Subclasses register themselves and provide two methods:
38   /// RecognizeCall and OptimizeCall. Whenever this class finds a function call,
39   /// it asks the subclasses to recognize the call. If it is recognized, then
40   /// the OptimizeCall method is called on that subclass instance. In this way
41   /// the subclasses implement the calling conditions on which they trigger and
42   /// the action to perform, making it easy to add new optimizations of this
43   /// form.
44   /// @brief A ModulePass for optimizing well-known function calls
45   struct SimplifyLibCalls : public ModulePass {
46
47
48     /// For this pass, process all of the function calls in the module, calling
49     /// RecognizeCall and OptimizeCall as appropriate.
50     virtual bool runOnModule(Module &M);
51
52   };
53
54   RegisterOpt<SimplifyLibCalls> 
55     X("simplify-libcalls","Simplify well-known library calls");
56
57   struct CallOptimizer
58   {
59     /// @brief Constructor that registers the optimization
60     CallOptimizer(const char * fname );
61
62     virtual ~CallOptimizer();
63
64     /// The implementation of this function in subclasses should determine if
65     /// \p F is suitable for the optimization. This method is called by 
66     /// runOnModule to short circuit visiting all the call sites of such a
67     /// function if that function is not suitable in the first place.
68     /// If the called function is suitabe, this method should return true;
69     /// false, otherwise. This function should also perform any lazy 
70     /// initialization that the CallOptimizer needs to do, if its to return 
71     /// true. This avoids doing initialization until the optimizer is actually
72     /// going to be called upon to do some optimization.
73     virtual bool ValidateCalledFunction(
74       const Function* F ///< The function that is the target of call sites
75     ) = 0;
76
77     /// The implementations of this function in subclasses is the heart of the 
78     /// SimplifyLibCalls algorithm. Sublcasses of this class implement 
79     /// OptimizeCall to determine if (a) the conditions are right for optimizing
80     /// the call and (b) to perform the optimization. If an action is taken 
81     /// against ci, the subclass is responsible for returning true and ensuring
82     /// that ci is erased from its parent.
83     /// @param ci the call instruction under consideration
84     /// @param f the function that ci calls.
85     /// @brief Optimize a call, if possible.
86     virtual bool OptimizeCall(
87       CallInst* ci ///< The call instruction that should be optimized.
88     ) = 0;
89
90     const char * getFunctionName() const { return func_name; }
91   private:
92     const char* func_name;
93   };
94
95   /// @brief The list of optimizations deriving from CallOptimizer
96
97   hash_map<std::string,CallOptimizer*> optlist;
98
99   CallOptimizer::CallOptimizer(const char* fname)
100     : func_name(fname)
101   {
102     // Register this call optimizer
103     optlist[func_name] = this;
104   }
105
106   /// Make sure we get our virtual table in this file.
107   CallOptimizer::~CallOptimizer() { }
108
109   /// Provide some functions for accessing standard library prototypes and
110   /// caching them so we don't have to keep recomputing them
111   FunctionType* get_strlen()
112   {
113     static FunctionType* strlen_type = 0;
114     if (!strlen_type)
115     {
116       std::vector<const Type*> args;
117       args.push_back(PointerType::get(Type::SByteTy));
118       strlen_type = FunctionType::get(Type::IntTy, args, false);
119     }
120     return strlen_type;
121   }
122
123   FunctionType* get_memcpy()
124   {
125     static FunctionType* memcpy_type = 0;
126     if (!memcpy_type)
127     {
128       // Note: this is for llvm.memcpy intrinsic
129       std::vector<const Type*> args;
130       args.push_back(PointerType::get(Type::SByteTy));
131       args.push_back(PointerType::get(Type::SByteTy));
132       args.push_back(Type::IntTy);
133       args.push_back(Type::IntTy);
134       memcpy_type = FunctionType::get(
135         PointerType::get(Type::SByteTy), args, false);
136     }
137     return memcpy_type;
138   }
139
140   /// A function to compute the length of a null-terminated string of integers.
141   /// This function can't rely on the size of the constant array because there 
142   /// could be a null terminator in the middle of the array. We also have to 
143   /// bail out if we find a non-integer constant initializer of one of the 
144   /// elements or if there is no null-terminator. The logic below checks
145   bool getConstantStringLength(Value* V, uint64_t& len )
146   {
147     assert(V != 0 && "Invalid args to getCharArrayLength");
148     len = 0; // make sure we initialize this 
149     User* GEP = 0;
150     // If the value is not a GEP instruction nor a constant expression with a 
151     // GEP instruction, then return false because ConstantArray can't occur 
152     // any other way
153     if (GetElementPtrInst* GEPI = dyn_cast<GetElementPtrInst>(V))
154       GEP = GEPI;
155     else if (ConstantExpr* CE = dyn_cast<ConstantExpr>(V))
156       if (CE->getOpcode() == Instruction::GetElementPtr)
157         GEP = CE;
158       else
159         return false;
160     else
161       return false;
162
163     // Check to make sure that the first operand of the GEP is an integer and
164     // has value 0 so that we are sure we're indexing into the initializer. 
165     if (ConstantInt* op1 = dyn_cast<ConstantInt>(GEP->getOperand(1)))
166     {
167       if (!op1->isNullValue())
168         return false;
169     }
170     else
171       return false;
172
173     // Ensure that the second operand is a ConstantInt. If it isn't then this
174     // GEP is wonky and we're not really sure what were referencing into and 
175     // better of not optimizing it. While we're at it, get the second index
176     // value. We'll need this later for indexing the ConstantArray.
177     uint64_t start_idx = 0;
178     if (ConstantInt* CI = dyn_cast<ConstantInt>(GEP->getOperand(2)))
179       start_idx = CI->getRawValue();
180     else
181       return false;
182
183     // The GEP instruction, constant or instruction, must reference a global
184     // variable that is a constant and is initialized. The referenced constant
185     // initializer is the array that we'll use for optimization.
186     GlobalVariable* GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
187     if (!GV || !GV->isConstant() || !GV->hasInitializer())
188       return false;
189
190     // Get the initializer and make sure its valid.
191     Constant* INTLZR = GV->getInitializer();
192     if (!INTLZR)
193       return false;
194
195     // Handle the ConstantAggregateZero case
196     if (ConstantAggregateZero* CAZ = dyn_cast<ConstantAggregateZero>(INTLZR))
197     {
198       // This is a degenerate case. The initializer is constant zero so the
199       // length of the string must be zero.
200       len = 0;
201       return true;
202     }
203
204     // Must be a Constant Array
205     ConstantArray* A = dyn_cast<ConstantArray>(INTLZR);
206     if (!A)
207       return false;
208
209     // Get the number of elements in the array
210     uint64_t max_elems = A->getType()->getNumElements();
211
212     // Traverse the constant array from start_idx (derived above) which is
213     // the place the GEP refers to in the array. 
214     for ( len = start_idx; len < max_elems; len++)
215     {
216       if (ConstantInt* CI = dyn_cast<ConstantInt>(A->getOperand(len)))
217       {
218         // Check for the null terminator
219         if (CI->isNullValue())
220           break; // we found end of string
221       }
222       else
223         return false; // This array isn't suitable, non-int initializer
224     }
225     if (len >= max_elems)
226       return false; // This array isn't null terminated
227
228     // Subtract out the initial value from the length
229     len -= start_idx;
230     return true; // success!
231   }
232 }
233
234 ModulePass *llvm::createSimplifyLibCallsPass() 
235
236   return new SimplifyLibCalls(); 
237 }
238
239 bool SimplifyLibCalls::runOnModule(Module &M) 
240 {
241   bool result = false;
242
243   // The call optimizations can be recursive. That is, the optimization might
244   // generate a call to another function which can also be optimized. This way
245   // we make the CallOptimizer instances very specific to the case they handle.
246   // It also means we need to keep running over the function calls in the module
247   // until we don't get any more optimizations possible.
248   bool found_optimization = false;
249   do
250   {
251     found_optimization = false;
252     for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI)
253     {
254       // All the "well-known" functions are external and have external linkage
255       // because they live in a runtime library somewhere and were (probably) 
256       // not compiled by LLVM.  So, we only act on external functions that have 
257       // external linkage and non-empty uses.
258       if (FI->isExternal() && FI->hasExternalLinkage() && !FI->use_empty())
259       {
260         // Get the optimization class that pertains to this function
261         if (CallOptimizer* CO = optlist[FI->getName().c_str()] )
262         {
263           // Make sure the called function is suitable for the optimization
264           if (CO->ValidateCalledFunction(FI))
265           {
266             // Loop over each of the uses of the function
267             for (Value::use_iterator UI = FI->use_begin(), UE = FI->use_end(); 
268                  UI != UE ; )
269             {
270               // If the use of the function is a call instruction
271               if (CallInst* CI = dyn_cast<CallInst>(*UI++))
272               {
273                 // Do the optimization on the CallOptimizer.
274                 if (CO->OptimizeCall(CI))
275                 {
276                   ++SimplifiedLibCalls;
277                   found_optimization = result = true;
278                 }
279               }
280             }
281           }
282         }
283       }
284     }
285   } while (found_optimization);
286   return result;
287 }
288
289 namespace {
290
291 /// This CallOptimizer will find instances of a call to "exit" that occurs
292 /// within the "main" function and change it to a simple "ret" instruction with
293 /// the same value as passed to the exit function. It assumes that the 
294 /// instructions after the call to exit(3) can be deleted since they are 
295 /// unreachable anyway.
296 /// @brief Replace calls to exit in main with a simple return
297 struct ExitInMainOptimization : public CallOptimizer
298 {
299   ExitInMainOptimization() : CallOptimizer("exit") {}
300   virtual ~ExitInMainOptimization() {}
301
302   // Make sure the called function looks like exit (int argument, int return
303   // type, external linkage, not varargs). 
304   virtual bool ValidateCalledFunction(const Function* f)
305   {
306     if (f->arg_size() >= 1)
307       if (f->arg_begin()->getType()->isInteger())
308         return true;
309     return false;
310   }
311
312   virtual bool OptimizeCall(CallInst* ci)
313   {
314     // To be careful, we check that the call to exit is coming from "main", that
315     // main has external linkage, and the return type of main and the argument
316     // to exit have the same type. 
317     Function *from = ci->getParent()->getParent();
318     if (from->hasExternalLinkage())
319       if (from->getReturnType() == ci->getOperand(1)->getType())
320         if (from->getName() == "main")
321         {
322           // Okay, time to actually do the optimization. First, get the basic 
323           // block of the call instruction
324           BasicBlock* bb = ci->getParent();
325
326           // Create a return instruction that we'll replace the call with. 
327           // Note that the argument of the return is the argument of the call 
328           // instruction.
329           ReturnInst* ri = new ReturnInst(ci->getOperand(1), ci);
330
331           // Split the block at the call instruction which places it in a new
332           // basic block.
333           bb->splitBasicBlock(ci);
334
335           // The block split caused a branch instruction to be inserted into
336           // the end of the original block, right after the return instruction
337           // that we put there. That's not a valid block, so delete the branch
338           // instruction.
339           bb->getInstList().pop_back();
340
341           // Now we can finally get rid of the call instruction which now lives
342           // in the new basic block.
343           ci->eraseFromParent();
344
345           // Optimization succeeded, return true.
346           return true;
347         }
348     // We didn't pass the criteria for this optimization so return false
349     return false;
350   }
351 } ExitInMainOptimizer;
352
353 /// This CallOptimizer will simplify a call to the strcat library function. The
354 /// simplification is possible only if the string being concatenated is a 
355 /// constant array or a constant expression that results in a constant array. In
356 /// this case, if the array is small, we can generate a series of inline store
357 /// instructions to effect the concatenation without calling strcat.
358 /// @brief Simplify the strcat library function.
359 struct StrCatOptimization : public CallOptimizer
360 {
361 private:
362   Function* strlen_func;
363   Function* memcpy_func;
364 public:
365   StrCatOptimization() 
366     : CallOptimizer("strcat") 
367     , strlen_func(0)
368     , memcpy_func(0)
369     {}
370   virtual ~StrCatOptimization() {}
371
372   inline Function* get_strlen_func(Module*M)
373   {
374     if (strlen_func)
375       return strlen_func;
376     return strlen_func = M->getOrInsertFunction("strlen",get_strlen());
377   }
378
379   inline Function* get_memcpy_func(Module* M) 
380   {
381     if (memcpy_func)
382       return memcpy_func;
383     return memcpy_func = M->getOrInsertFunction("llvm.memcpy",get_memcpy());
384   }
385
386   /// @brief Make sure that the "strcat" function has the right prototype
387   virtual bool ValidateCalledFunction(const Function* f) 
388   {
389     if (f->getReturnType() == PointerType::get(Type::SByteTy))
390       if (f->arg_size() == 2) 
391       {
392         Function::const_arg_iterator AI = f->arg_begin();
393         if (AI++->getType() == PointerType::get(Type::SByteTy))
394           if (AI->getType() == PointerType::get(Type::SByteTy))
395           {
396             // Invalidate the pre-computed strlen_func and memcpy_func Functions
397             // because, by definition, this method is only called when a new
398             // Module is being traversed. Invalidation causes re-computation for
399             // the new Module (if necessary).
400             strlen_func = 0;
401             memcpy_func = 0;
402
403             // Indicate this is a suitable call type.
404             return true;
405           }
406       }
407     return false;
408   }
409
410   /// Perform the optimization if the length of the string concatenated
411   /// is reasonably short and it is a constant array.
412   virtual bool OptimizeCall(CallInst* ci)
413   {
414     // Extract the initializer (while making numerous checks) from the 
415     // source operand of the call to strcat. If we get null back, one of
416     // a variety of checks in get_GVInitializer failed
417     uint64_t len = 0;
418     if (!getConstantStringLength(ci->getOperand(2),len))
419       return false;
420
421     // Handle the simple, do-nothing case
422     if (len == 0)
423     {
424       ci->replaceAllUsesWith(ci->getOperand(1));
425       ci->eraseFromParent();
426       return true;
427     }
428
429     // Increment the length because we actually want to memcpy the null
430     // terminator as well.
431     len++;
432
433     // Extract some information from the instruction
434     Module* M = ci->getParent()->getParent()->getParent();
435
436     // We need to find the end of the destination string.  That's where the 
437     // memory is to be moved to. We just generate a call to strlen (further 
438     // optimized in another pass). Note that the get_strlen_func() call 
439     // caches the Function* for us.
440     CallInst* strlen_inst = 
441       new CallInst(get_strlen_func(M),ci->getOperand(1),"",ci);
442
443     // Now that we have the destination's length, we must index into the 
444     // destination's pointer to get the actual memcpy destination (end of
445     // the string .. we're concatenating).
446     std::vector<Value*> idx;
447     idx.push_back(strlen_inst);
448     GetElementPtrInst* gep = 
449       new GetElementPtrInst(ci->getOperand(1),idx,"",ci);
450
451     // We have enough information to now generate the memcpy call to
452     // do the concatenation for us.
453     std::vector<Value*> vals;
454     vals.push_back(gep); // destination
455     vals.push_back(ci->getOperand(2)); // source
456     vals.push_back(ConstantSInt::get(Type::IntTy,len)); // length
457     vals.push_back(ConstantSInt::get(Type::IntTy,1)); // alignment
458     CallInst* memcpy_inst = new CallInst(get_memcpy_func(M), vals, "", ci);
459
460     // Finally, substitute the first operand of the strcat call for the 
461     // strcat call itself since strcat returns its first operand; and, 
462     // kill the strcat CallInst.
463     ci->replaceAllUsesWith(ci->getOperand(1));
464     ci->eraseFromParent();
465     return true;
466   }
467 } StrCatOptimizer;
468
469 /// This CallOptimizer will simplify a call to the strlen library function by
470 /// replacing it with a constant value if the string provided to it is a 
471 /// constant array.
472 /// @brief Simplify the strlen library function.
473 struct StrLenOptimization : public CallOptimizer
474 {
475   StrLenOptimization() : CallOptimizer("strlen") {}
476   virtual ~StrLenOptimization() {}
477
478   /// @brief Make sure that the "strlen" function has the right prototype
479   virtual bool ValidateCalledFunction(const Function* f)
480   {
481     if (f->getReturnType() == Type::IntTy)
482       if (f->arg_size() == 1) 
483         if (Function::const_arg_iterator AI = f->arg_begin())
484           if (AI->getType() == PointerType::get(Type::SByteTy))
485             return true;
486     return false;
487   }
488
489   /// @brief Perform the strlen optimization
490   virtual bool OptimizeCall(CallInst* ci)
491   {
492     // Get the length of the string
493     uint64_t len = 0;
494     if (!getConstantStringLength(ci->getOperand(1),len))
495       return false;
496
497     ci->replaceAllUsesWith(ConstantInt::get(Type::IntTy,len));
498     ci->eraseFromParent();
499     return true;
500   }
501 } StrLenOptimizer;
502
503 /// This CallOptimizer will simplify a call to the memcpy library function by
504 /// expanding it out to a small set of stores if the copy source is a constant
505 /// array. 
506 /// @brief Simplify the memcpy library function.
507 struct MemCpyOptimization : public CallOptimizer
508 {
509   MemCpyOptimization() : CallOptimizer("llvm.memcpy") {}
510   virtual ~MemCpyOptimization() {}
511
512   /// @brief Make sure that the "memcpy" function has the right prototype
513   virtual bool ValidateCalledFunction(const Function* f)
514   {
515     if (f->getReturnType() == PointerType::get(Type::SByteTy))
516       if (f->arg_size() == 4) 
517       {
518         Function::const_arg_iterator AI = f->arg_begin();
519         if (AI++->getType() == PointerType::get(Type::SByteTy))
520           if (AI++->getType() == PointerType::get(Type::SByteTy))
521             if (AI++->getType() == Type::IntTy)
522               if (AI->getType() == Type::IntTy)
523             return true;
524       }
525     return false;
526   }
527
528   /// Because of alignment and instruction information that we don't have, we
529   /// leave the bulk of this to the code generators. The optimization here just
530   /// deals with a few degenerate cases where the length of the string and the
531   /// alignment match the sizes of our intrinsic types so we can do a load and
532   /// store instead of the memcpy call.
533   /// @brief Perform the memcpy optimization.
534   virtual bool OptimizeCall(CallInst* ci)
535   {
536     ConstantInt* CI = dyn_cast<ConstantInt>(ci->getOperand(3));
537     assert(CI && "Operand should be ConstantInt");
538     uint64_t len = CI->getRawValue();
539     CI = dyn_cast<ConstantInt>(ci->getOperand(4));
540     assert(CI && "Operand should be ConstantInt");
541     uint64_t alignment = CI->getRawValue();
542     if (len != alignment)
543       return false;
544
545     Value* dest = ci->getOperand(1);
546     Value* src = ci->getOperand(2);
547     LoadInst* LI = 0;
548     CastInst* SrcCast = 0;
549     CastInst* DestCast = 0;
550     switch (len)
551     {
552       case 1:
553         SrcCast = new CastInst(src,PointerType::get(Type::SByteTy),"",ci);
554         DestCast = new CastInst(dest,PointerType::get(Type::SByteTy),"",ci);
555         LI = new LoadInst(SrcCast,"",ci);
556         break;
557       case 2:
558         SrcCast = new CastInst(src,PointerType::get(Type::ShortTy),"",ci);
559         DestCast = new CastInst(dest,PointerType::get(Type::ShortTy),"",ci);
560         LI = new LoadInst(SrcCast,"",ci);
561         break;
562       case 4:
563         SrcCast = new CastInst(src,PointerType::get(Type::IntTy),"",ci);
564         DestCast = new CastInst(dest,PointerType::get(Type::IntTy),"",ci);
565         LI = new LoadInst(SrcCast,"",ci);
566         break;
567       case 8:
568         SrcCast = new CastInst(src,PointerType::get(Type::LongTy),"",ci);
569         DestCast = new CastInst(dest,PointerType::get(Type::LongTy),"",ci);
570         LI = new LoadInst(SrcCast,"",ci);
571         break;
572       default:
573         return false;
574     }
575     StoreInst* SI = new StoreInst(LI, DestCast, ci);
576     ci->replaceAllUsesWith(dest);
577     ci->eraseFromParent();
578     return true;
579   }
580 } MemCpyOptimizer;
581 }