MC: Modernize MCOperand API naming. NFC.
[oota-llvm.git] / lib / Target / R600 / SITypeRewriter.cpp
index 9da11e88eb8b8be43b4f8c3b05eb89469f81834d..591ce857cc7d645be4e58c9e3c3020536f29936c 100644 (file)
 ///      legal for some compute APIs, and we don't want to declare it as legal
 ///      in the backend, because we want the legalizer to expand all v16i8
 ///      operations.
+/// v1* => *
+///   - Having v1* types complicates the legalizer and we can easily replace
+///   - them with the element type.
 //===----------------------------------------------------------------------===//
 
 #include "AMDGPU.h"
-
 #include "llvm/IR/IRBuilder.h"
-#include "llvm/InstVisitor.h"
+#include "llvm/IR/InstVisitor.h"
 
 using namespace llvm;
 
@@ -33,13 +35,13 @@ class SITypeRewriter : public FunctionPass,
   static char ID;
   Module *Mod;
   Type *v16i8;
-  Type *i128;
+  Type *v4i32;
 
 public:
   SITypeRewriter() : FunctionPass(ID) { }
-  virtual bool doInitialization(Module &M);
-  virtual bool runOnFunction(Function &F);
-  virtual const char *getPassName() const {
+  bool doInitialization(Module &M) override;
+  bool runOnFunction(Function &F) override;
+  const char *getPassName() const override {
     return "SI Type Rewriter";
   }
   void visitLoadInst(LoadInst &I);
@@ -54,24 +56,23 @@ char SITypeRewriter::ID = 0;
 bool SITypeRewriter::doInitialization(Module &M) {
   Mod = &M;
   v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
-  i128 = Type::getIntNTy(M.getContext(), 128);
+  v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
   return false;
 }
 
 bool SITypeRewriter::runOnFunction(Function &F) {
-  AttributeSet Set = F.getAttributes();
-  Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
+  Attribute A = F.getFnAttribute("ShaderType");
 
   unsigned ShaderType = ShaderType::COMPUTE;
   if (A.isStringAttribute()) {
     StringRef Str = A.getValueAsString();
     Str.getAsInteger(0, ShaderType);
   }
-  if (ShaderType != ShaderType::COMPUTE) {
-    visit(F);
-  }
+  if (ShaderType == ShaderType::COMPUTE)
+    return false;
 
   visit(F);
+  visit(F);
 
   return false;
 }
@@ -82,9 +83,10 @@ void SITypeRewriter::visitLoadInst(LoadInst &I) {
   Type *ElemTy = PtrTy->getPointerElementType();
   IRBuilder<> Builder(&I);
   if (ElemTy == v16i8)  {
-    Value *BitCast = Builder.CreateBitCast(Ptr, Type::getIntNPtrTy(I.getContext(), 128, 2));
+    Value *BitCast = Builder.CreateBitCast(Ptr,
+        PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
     LoadInst *Load = Builder.CreateLoad(BitCast);
-    SmallVector <std::pair<unsigned, MDNode*>, 8> MD;
+    SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
     I.getAllMetadataOtherThanDebugLoc(MD);
     for (unsigned i = 0, e = MD.size(); i != e; ++i) {
       Load->setMetadata(MD[i].first, MD[i].second);
@@ -97,18 +99,31 @@ void SITypeRewriter::visitLoadInst(LoadInst &I) {
 
 void SITypeRewriter::visitCallInst(CallInst &I) {
   IRBuilder<> Builder(&I);
+
   SmallVector <Value*, 8> Args;
   SmallVector <Type*, 8> Types;
   bool NeedToReplace = false;
   Function *F = I.getCalledFunction();
-  std::string Name = F->getName().str();
+  std::string Name = F->getName();
   for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
     Value *Arg = I.getArgOperand(i);
     if (Arg->getType() == v16i8) {
-      Args.push_back(Builder.CreateBitCast(Arg, i128));
-      Types.push_back(i128);
+      Args.push_back(Builder.CreateBitCast(Arg, v4i32));
+      Types.push_back(v4i32);
+      NeedToReplace = true;
+      Name = Name + ".v4i32";
+    } else if (Arg->getType()->isVectorTy() &&
+               Arg->getType()->getVectorNumElements() == 1 &&
+               Arg->getType()->getVectorElementType() ==
+                                              Type::getInt32Ty(I.getContext())){
+      Type *ElementTy = Arg->getType()->getVectorElementType();
+      std::string TypeName = "i32";
+      InsertElementInst *Def = cast<InsertElementInst>(Arg);
+      Args.push_back(Def->getOperand(1));
+      Types.push_back(ElementTy);
+      std::string VecTypeName = "v1" + TypeName;
+      Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
       NeedToReplace = true;
-      Name = Name + ".i128";
     } else {
       Args.push_back(Arg);
       Types.push_back(Arg->getType());
@@ -129,12 +144,12 @@ void SITypeRewriter::visitCallInst(CallInst &I) {
 
 void SITypeRewriter::visitBitCast(BitCastInst &I) {
   IRBuilder<> Builder(&I);
-  if (I.getDestTy() != i128) {
+  if (I.getDestTy() != v4i32) {
     return;
   }
 
   if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
-    if (Op->getSrcTy() == i128) {
+    if (Op->getSrcTy() == v4i32) {
       I.replaceAllUsesWith(Op->getOperand(0));
       I.eraseFromParent();
     }