From 35b9b49fd17fe8600d5a6ac25d812694af7b1051 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 14 Apr 2007 01:17:48 +0000 Subject: [PATCH] Implement a few missing xforms: printf("foo\n") -> puts. printf("x") -> putchar printf("") -> noop. Still need to do the xforms for fprintf. This implements Transforms/SimplifyLibCalls/Printf.ll git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@35984 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/IPO/SimplifyLibCalls.cpp | 57 +++++++++++++++++++++---- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/lib/Transforms/IPO/SimplifyLibCalls.cpp b/lib/Transforms/IPO/SimplifyLibCalls.cpp index 87c1480fba7..b7b5bee1eb2 100644 --- a/lib/Transforms/IPO/SimplifyLibCalls.cpp +++ b/lib/Transforms/IPO/SimplifyLibCalls.cpp @@ -1154,22 +1154,61 @@ public: /// @brief Make sure that the "printf" function has the right prototype virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ - // Just make sure this has at least 1 arguments - return F->arg_size() >= 1; + // Just make sure this has at least 1 argument and returns an integer or + // void type. + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() >= 1 && + (isa(FT->getReturnType()) || + FT->getReturnType() == Type::VoidTy); } /// @brief Perform the printf optimization. virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { - // If the call has more than 2 operands, we can't optimize it - if (CI->getNumOperands() != 3) - return false; - // All the optimizations depend on the length of the first argument and the // fact that it is a constant string array. Check that now std::string FormatStr; if (!GetConstantStringInfo(CI->getOperand(1), FormatStr)) return false; + + // If this is a simple constant string with no format specifiers that ends + // with a \n, turn it into a puts call. + if (FormatStr.empty()) { + // Tolerate printf's declared void. + if (CI->use_empty()) return ReplaceCallWith(CI, 0); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0)); + } + + if (FormatStr.size() == 1) { + // Turn this into a putchar call, even if it is a %. + Value *V = ConstantInt::get(Type::Int32Ty, FormatStr[0]); + new CallInst(SLC.get_putchar(), V, "", CI); + if (CI->use_empty()) return ReplaceCallWith(CI, 0); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 1)); + } + // Check to see if the format str is something like "foo\n", in which case + // we convert it to a puts call. We don't allow it to contain any format + // characters. + if (FormatStr[FormatStr.size()-1] == '\n' && + FormatStr.find('%') == std::string::npos) { + // Create a string literal with no \n on it. We expect the constant merge + // pass to be run after this pass, to merge duplicate strings. + FormatStr.erase(FormatStr.end()-1); + Constant *Init = ConstantArray::get(FormatStr, true); + Constant *GV = new GlobalVariable(Init->getType(), true, + GlobalVariable::InternalLinkage, + Init, "str", + CI->getParent()->getParent()->getParent()); + // Cast GV to be a pointer to char. + GV = ConstantExpr::getBitCast(GV, PointerType::get(Type::Int8Ty)); + new CallInst(SLC.get_puts(), GV, "", CI); + + if (CI->use_empty()) return ReplaceCallWith(CI, 0); + return ReplaceCallWith(CI, + ConstantInt::get(CI->getType(), FormatStr.size())); + } + + // Only support %c or "%s\n" for now. if (FormatStr.size() < 2 || FormatStr[0] != '%') return false; @@ -1178,7 +1217,7 @@ public: switch (FormatStr[1]) { default: return false; case 's': - if (FormatStr != "%s\n" || + if (FormatStr != "%s\n" || CI->getNumOperands() < 3 || // TODO: could insert strlen call to compute string length. !CI->use_empty()) return false; @@ -1189,7 +1228,7 @@ public: return ReplaceCallWith(CI, 0); case 'c': { // printf("%c",c) -> putchar(c) - if (FormatStr.size() != 2) + if (FormatStr.size() != 2 || CI->getNumOperands() < 3) return false; Value *V = CI->getOperand(2); @@ -1917,7 +1956,7 @@ static Value *CastToCStr(Value *V, Instruction *IP) { // * pow(pow(x,y),z)-> pow(x,y*z) // // puts: -// * puts("") -> fputc("\n",stdout) (how do we get "stdout"?) +// * puts("") -> putchar("\n") // // round, roundf, roundl: // * round(cnst) -> cnst' -- 2.34.1