Add support for memmove
[oota-llvm.git] / lib / Analysis / DataStructure / Local.cpp
index 641c2e0a1eb1705c0af68387cad5165aaf6482e9..0aca21a437ba90fcbb2aa3b5e1b2fd176b4da06a 100644 (file)
@@ -1,4 +1,11 @@
 //===- Local.cpp - Compute a local data structure graph for a function ----===//
+// 
+//                     The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+// 
+//===----------------------------------------------------------------------===//
 //
 // Compute the local version of the data structure graph for a function.  The
 // external interface to this file is the DSGraph constructor.
 
 #include "llvm/Analysis/DataStructure.h"
 #include "llvm/Analysis/DSGraph.h"
-#include "llvm/iMemory.h"
-#include "llvm/iTerminators.h"
-#include "llvm/iPHINode.h"
-#include "llvm/iOther.h"
 #include "llvm/Constants.h"
 #include "llvm/DerivedTypes.h"
-#include "llvm/Function.h"
-#include "llvm/GlobalVariable.h"
+#include "llvm/Instructions.h"
 #include "llvm/Support/InstVisitor.h"
 #include "llvm/Target/TargetData.h"
-#include "Support/Statistic.h"
-#include "Support/Timer.h"
 #include "Support/CommandLine.h"
+#include "Support/Debug.h"
+#include "Support/Timer.h"
 
 // FIXME: This should eventually be a FunctionPass that is automatically
 // aggregated into a Pass.
@@ -30,9 +32,6 @@ static RegisterAnalysis<LocalDataStructures>
 X("datastructure", "Local Data Structure Analysis");
 
 namespace DS {
-  // FIXME: Do something smarter with target data!
-  TargetData TD("temp-td");
-
   // isPointerType - Return true if this type is big enough to hold a pointer.
   bool isPointerType(const Type *Ty) {
     if (isa<PointerType>(Ty))
@@ -62,26 +61,32 @@ namespace {
   /// graph by performing a single pass over the function in question.
   ///
   class GraphBuilder : InstVisitor<GraphBuilder> {
-    Function &F;
     DSGraph &G;
-    DSNodeHandle &RetNode;               // Node that gets returned...
+    DSNodeHandle *RetNode;               // Node that gets returned...
     DSGraph::ScalarMapTy &ScalarMap;
-    std::vector<DSCallSite> &FunctionCalls;
+    std::vector<DSCallSite> *FunctionCalls;
 
   public:
     GraphBuilder(Function &f, DSGraph &g, DSNodeHandle &retNode, 
-                 DSGraph::ScalarMapTy &SM, std::vector<DSCallSite> &fc)
-      : F(f), G(g), RetNode(retNode), ScalarMap(SM),
-        FunctionCalls(fc) {
+                 std::vector<DSCallSite> &fc)
+      : G(g), RetNode(&retNode), ScalarMap(G.getScalarMap()),
+        FunctionCalls(&fc) {
 
       // Create scalar nodes for all pointer arguments...
-      for (Function::aiterator I = F.abegin(), E = F.aend(); I != E; ++I)
+      for (Function::aiterator I = f.abegin(), E = f.aend(); I != E; ++I)
         if (isPointerType(I->getType()))
           getValueDest(*I);
 
-      visit(F);  // Single pass over the function
+      visit(f);  // Single pass over the function
+    }
+
+    // GraphBuilder ctor for working on the globals graph
+    GraphBuilder(DSGraph &g)
+      : G(g), RetNode(0), ScalarMap(G.getScalarMap()), FunctionCalls(0) {
     }
 
+    void mergeInGlobalInitializer(GlobalVariable *GV);
+
   private:
     // Visitor functions, used to handle each instruction type we encounter...
     friend class InstVisitor<GraphBuilder>;
@@ -96,11 +101,15 @@ namespace {
     void visitLoadInst(LoadInst &LI);
     void visitStoreInst(StoreInst &SI);
     void visitCallInst(CallInst &CI);
+    void visitInvokeInst(InvokeInst &II);
     void visitSetCondInst(SetCondInst &SCI) {}  // SetEQ & friends are ignored
     void visitFreeInst(FreeInst &FI);
     void visitCastInst(CastInst &CI);
     void visitInstruction(Instruction &I);
 
+    void visitCallSite(CallSite CS);
+
+    void MergeConstantInitIntoNode(DSNodeHandle &NH, Constant *C);
   private:
     // Helper functions used to implement the visitation functions...
 
@@ -138,13 +147,14 @@ namespace {
 //===----------------------------------------------------------------------===//
 // DSGraph constructor - Simply use the GraphBuilder to construct the local
 // graph.
-DSGraph::DSGraph(Function &F, DSGraph *GG) : GlobalsGraph(GG) {
+DSGraph::DSGraph(const TargetData &td, Function &F, DSGraph *GG)
+  : GlobalsGraph(GG), TD(td) {
   PrintAuxCalls = false;
 
   DEBUG(std::cerr << "  [Loc] Calculating graph for: " << F.getName() << "\n");
 
   // Use the graph builder to construct the local version of the graph
-  GraphBuilder B(F, *this, ReturnNodes[&F], ScalarMap, FunctionCalls);
+  GraphBuilder B(F, *this, ReturnNodes[&F], FunctionCalls);
 #ifndef NDEBUG
   Timer::addPeakMemoryMeasurement();
 #endif
@@ -288,7 +298,6 @@ void GraphBuilder::visitGetElementPtrInst(User &GEP) {
   DSNodeHandle Value = getValueDest(*GEP.getOperand(0));
   if (Value.getNode() == 0) return;
 
-  unsigned Offset = 0;
   const PointerType *PTy = cast<PointerType>(GEP.getOperand(0)->getType());
   const Type *CurTy = PTy->getElementType();
 
@@ -298,6 +307,8 @@ void GraphBuilder::visitGetElementPtrInst(User &GEP) {
     return;
   }
 
+  const TargetData &TD = Value.getNode()->getTargetData();
+
 #if 0
   // Handle the pointer index specially...
   if (GEP.getNumOperands() > 1 &&
@@ -326,6 +337,7 @@ void GraphBuilder::visitGetElementPtrInst(User &GEP) {
 #endif
 
   // All of these subscripts are indexing INTO the elements we have...
+  unsigned Offset = 0;
   for (unsigned i = 2, e = GEP.getNumOperands(); i < e; ++i)
     if (GEP.getOperand(i)->getType() == Type::LongTy) {
       // Get the type indexing into...
@@ -347,7 +359,7 @@ void GraphBuilder::visitGetElementPtrInst(User &GEP) {
           unsigned RawOffset = Offset+Value.getOffset();
 
           // Loop over all of the elements of the array, merging them into the
-          // zero'th element.
+          // zeroth element.
           for (unsigned i = 1, e = ATy->getNumElements(); i != e; ++i)
             // Merge all of the byte components of this array element
             for (unsigned j = 0; j != ElSize; ++j)
@@ -391,7 +403,7 @@ void GraphBuilder::visitStoreInst(StoreInst &SI) {
   // Mark that the node is written to...
   Dest.getNode()->setModifiedMarker();
 
-  // Ensure a typerecord exists...
+  // Ensure a type-record exists...
   Dest.getNode()->mergeTypeInfo(StoredTy, Dest.getOffset());
 
   // Avoid adding edges from null, or processing non-"pointer" stores
@@ -401,33 +413,87 @@ void GraphBuilder::visitStoreInst(StoreInst &SI) {
 
 void GraphBuilder::visitReturnInst(ReturnInst &RI) {
   if (RI.getNumOperands() && isPointerType(RI.getOperand(0)->getType()))
-    RetNode.mergeWith(getValueDest(*RI.getOperand(0)));
+    RetNode->mergeWith(getValueDest(*RI.getOperand(0)));
 }
 
 void GraphBuilder::visitCallInst(CallInst &CI) {
+  visitCallSite(&CI);
+}
+
+void GraphBuilder::visitInvokeInst(InvokeInst &II) {
+  visitCallSite(&II);
+}
+
+void GraphBuilder::visitCallSite(CallSite CS) {
+  // Special case handling of certain libc allocation functions here.
+  if (Function *F = CS.getCalledFunction())
+    if (F->isExternal())
+      if (F->getName() == "calloc") {
+        setDestTo(*CS.getInstruction(),
+                  createNode()->setHeapNodeMarker()->setModifiedMarker());
+        return;
+      } else if (F->getName() == "realloc") {
+        DSNodeHandle RetNH = getValueDest(*CS.getInstruction());
+        RetNH.mergeWith(getValueDest(**CS.arg_begin()));
+        if (DSNode *N = RetNH.getNode())
+          N->setHeapNodeMarker()->setModifiedMarker()->setReadMarker();
+        return;
+      } else if (F->getName() == "memset") {
+        // Merge the first argument with the return value, and mark the memory
+        // modified.
+        DSNodeHandle RetNH = getValueDest(*CS.getInstruction());
+        RetNH.mergeWith(getValueDest(**CS.arg_begin()));
+        if (DSNode *N = RetNH.getNode())
+          N->setModifiedMarker();
+        return;
+      } else if (F->getName() == "memmove") {
+        // Merge the first & second arguments with the result, and mark the
+        // memory read and modified.
+        DSNodeHandle RetNH = getValueDest(*CS.getInstruction());
+        RetNH.mergeWith(getValueDest(**CS.arg_begin()));
+        RetNH.mergeWith(getValueDest(**(CS.arg_begin()+1)));
+        if (DSNode *N = RetNH.getNode())
+          N->setModifiedMarker()->setReadMarker();
+        return;
+      } else if (F->getName() == "bzero") {
+        // Mark the memory modified.
+        DSNodeHandle H = getValueDest(**CS.arg_begin());
+        if (DSNode *N = H.getNode())
+          N->setModifiedMarker();
+        return;
+      }
+
+
   // Set up the return value...
   DSNodeHandle RetVal;
-  if (isPointerType(CI.getType()))
-    RetVal = getValueDest(CI);
+  Instruction *I = CS.getInstruction();
+  if (isPointerType(I->getType()))
+    RetVal = getValueDest(*I);
 
   DSNode *Callee = 0;
-  if (DisableDirectCallOpt || !isa<Function>(CI.getOperand(0)))
-    Callee = getValueDest(*CI.getOperand(0)).getNode();
+  if (DisableDirectCallOpt || !isa<Function>(CS.getCalledValue())) {
+    Callee = getValueDest(*CS.getCalledValue()).getNode();
+    if (Callee == 0) {
+      std::cerr << "WARNING: Program is calling through a null pointer?\n"
+                << *I;
+      return;  // Calling a null pointer?
+    }
+  }
 
   std::vector<DSNodeHandle> Args;
-  Args.reserve(CI.getNumOperands()-1);
+  Args.reserve(CS.arg_end()-CS.arg_begin());
 
   // Calculate the arguments vector...
-  for (unsigned i = 1, e = CI.getNumOperands(); i != e; ++i)
-    if (isPointerType(CI.getOperand(i)->getType()))
-      Args.push_back(getValueDest(*CI.getOperand(i)));
+  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E; ++I)
+    if (isPointerType((*I)->getType()))
+      Args.push_back(getValueDest(**I));
 
   // Add a new function call entry...
   if (Callee)
-    FunctionCalls.push_back(DSCallSite(CI, RetVal, Callee, Args));
+    FunctionCalls->push_back(DSCallSite(CS, RetVal, Callee, Args));
   else
-    FunctionCalls.push_back(DSCallSite(CI, RetVal,
-                                       cast<Function>(CI.getOperand(0)), Args));
+    FunctionCalls->push_back(DSCallSite(CS, RetVal, CS.getCalledFunction(),
+                                        Args));
 }
 
 void GraphBuilder::visitFreeInst(FreeInst &FI) {
@@ -474,13 +540,63 @@ void GraphBuilder::visitInstruction(Instruction &Inst) {
 // LocalDataStructures Implementation
 //===----------------------------------------------------------------------===//
 
+// MergeConstantInitIntoNode - Merge the specified constant into the node
+// pointed to by NH.
+void GraphBuilder::MergeConstantInitIntoNode(DSNodeHandle &NH, Constant *C) {
+  // Ensure a type-record exists...
+  NH.getNode()->mergeTypeInfo(C->getType(), NH.getOffset());
+
+  if (C->getType()->isFirstClassType()) {
+    if (isPointerType(C->getType()))
+      // Avoid adding edges from null, or processing non-"pointer" stores
+      NH.addEdgeTo(getValueDest(*C));
+    return;
+  }
+
+  const TargetData &TD = NH.getNode()->getTargetData();
+
+  if (ConstantArray *CA = dyn_cast<ConstantArray>(C)) {
+    for (unsigned i = 0, e = CA->getNumOperands(); i != e; ++i)
+      // We don't currently do any indexing for arrays...
+      MergeConstantInitIntoNode(NH, cast<Constant>(CA->getOperand(i)));
+  } else if (ConstantStruct *CS = dyn_cast<ConstantStruct>(C)) {
+    const StructLayout *SL = TD.getStructLayout(CS->getType());
+    for (unsigned i = 0, e = CS->getNumOperands(); i != e; ++i) {
+      DSNodeHandle NewNH(NH.getNode(), NH.getOffset()+SL->MemberOffsets[i]);
+      MergeConstantInitIntoNode(NewNH, cast<Constant>(CS->getOperand(i)));
+    }
+  } else {
+    assert(0 && "Unknown constant type!");
+  }
+}
+
+void GraphBuilder::mergeInGlobalInitializer(GlobalVariable *GV) {
+  assert(!GV->isExternal() && "Cannot merge in external global!");
+  // Get a node handle to the global node and merge the initializer into it.
+  DSNodeHandle NH = getValueDest(*GV);
+  MergeConstantInitIntoNode(NH, GV->getInitializer());
+}
+
+
 bool LocalDataStructures::run(Module &M) {
-  GlobalsGraph = new DSGraph();
+  GlobalsGraph = new DSGraph(getAnalysis<TargetData>());
+
+  const TargetData &TD = getAnalysis<TargetData>();
 
   // Calculate all of the graphs...
   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
     if (!I->isExternal())
-      DSInfo.insert(std::make_pair(I, new DSGraph(*I, GlobalsGraph)));
+      DSInfo.insert(std::make_pair(I, new DSGraph(TD, *I, GlobalsGraph)));
+
+  GraphBuilder GGB(*GlobalsGraph);
+
+  // Add initializers for all of the globals to the globals graph...
+  for (Module::giterator I = M.gbegin(), E = M.gend(); I != E; ++I)
+    if (!I->isExternal())
+      GGB.mergeInGlobalInitializer(I);
+
+  GlobalsGraph->markIncompleteNodes(DSGraph::MarkFormalArgs);
+  GlobalsGraph->removeTriviallyDeadNodes();
   return false;
 }