Reapply r141870, SCEV expansion of post-inc.
[oota-llvm.git] / lib / Target / PTX / PTXISelDAGToDAG.cpp
index b3c85da7b4461662a967dff06d872f1a1b68bcbb..5c7ee298f31abd1e420f4370d2127f0d890b8e70 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "PTX.h"
+#include "PTXMachineFunctionInfo.h"
 #include "PTXTargetMachine.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/CodeGen/SelectionDAGISel.h"
 #include "llvm/DerivedTypes.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
@@ -36,17 +39,20 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
     bool SelectADDRrr(SDValue &Addr, SDValue &R1, SDValue &R2);
     bool SelectADDRri(SDValue &Addr, SDValue &Base, SDValue &Offset);
     bool SelectADDRii(SDValue &Addr, SDValue &Base, SDValue &Offset);
+    bool SelectADDRlocal(SDValue &Addr, SDValue &Base, SDValue &Offset);
 
     // Include the pieces auto'gened from the target description
 #include "PTXGenDAGISel.inc"
 
   private:
-    SDNode *SelectREAD_PARAM(SDNode *Node);
-
     // We need this only because we can't match intruction BRAdp
     // pattern (PTXbrcond bb:$d, ...) in PTXInstrInfo.td
     SDNode *SelectBRCOND(SDNode *Node);
 
+    SDNode *SelectREADPARAM(SDNode *Node);
+    SDNode *SelectWRITEPARAM(SDNode *Node);
+    SDNode *SelectFrameIndex(SDNode *Node);
+
     bool isImm(const SDValue &operand);
     bool SelectImm(const SDValue &operand, SDValue &imm);
 
@@ -67,53 +73,26 @@ PTXDAGToDAGISel::PTXDAGToDAGISel(PTXTargetMachine &TM,
 
 SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
   switch (Node->getOpcode()) {
-    case PTXISD::READ_PARAM:
-      return SelectREAD_PARAM(Node);
     case ISD::BRCOND:
       return SelectBRCOND(Node);
+    case PTXISD::READ_PARAM:
+      return SelectREADPARAM(Node);
+    case PTXISD::WRITE_PARAM:
+      return SelectWRITEPARAM(Node);
+    case ISD::FrameIndex:
+      return SelectFrameIndex(Node);
     default:
       return SelectCode(Node);
   }
 }
 
-SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) {
-  SDValue  index = Node->getOperand(1);
-  DebugLoc dl    = Node->getDebugLoc();
-  unsigned opcode;
-
-  if (index.getOpcode() != ISD::TargetConstant)
-    llvm_unreachable("READ_PARAM: index is not ISD::TargetConstant");
-
-  if (Node->getValueType(0) == MVT::i16) {
-    opcode = PTX::LDpiU16;
-  }
-  else if (Node->getValueType(0) == MVT::i32) {
-    opcode = PTX::LDpiU32;
-  }
-  else if (Node->getValueType(0) == MVT::i64) {
-    opcode = PTX::LDpiU64;
-  }
-  else if (Node->getValueType(0) == MVT::f32) {
-    opcode = PTX::LDpiF32;
-  }
-  else if (Node->getValueType(0) == MVT::f64) {
-    opcode = PTX::LDpiF64;
-  }
-  else {
-    llvm_unreachable("Unknown parameter type for ld.param");
-  }
-
-  return PTXInstrInfo::
-    GetPTXMachineNode(CurDAG, opcode, dl, Node->getValueType(0), index);
-}
-
 SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) {
   assert(Node->getNumOperands() >= 3);
 
   SDValue Chain  = Node->getOperand(0);
   SDValue Pred   = Node->getOperand(1);
   SDValue Target = Node->getOperand(2); // branch target
-  SDValue PredOp = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32);
+  SDValue PredOp = CurDAG->getTargetConstant(PTXPredicate::Normal, MVT::i32);
   DebugLoc dl = Node->getDebugLoc();
 
   assert(Target.getOpcode()  == ISD::BasicBlock);
@@ -124,6 +103,97 @@ SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) {
   return CurDAG->getMachineNode(PTX::BRAdp, dl, MVT::Other, Ops, 4);
 }
 
+SDNode *PTXDAGToDAGISel::SelectREADPARAM(SDNode *Node) {
+  SDValue Chain = Node->getOperand(0);
+  SDValue Index = Node->getOperand(1);
+
+  int OpCode;
+
+  // Get the type of parameter we are reading
+  EVT VT = Node->getValueType(0);
+  assert(VT.isSimple() && "READ_PARAM only implemented for MVT types");
+
+  MVT Type = VT.getSimpleVT();
+
+  if (Type == MVT::i1)
+    OpCode = PTX::READPARAMPRED;
+  else if (Type == MVT::i16)
+    OpCode = PTX::READPARAMI16;
+  else if (Type == MVT::i32)
+    OpCode = PTX::READPARAMI32;
+  else if (Type == MVT::i64)
+    OpCode = PTX::READPARAMI64;
+  else if (Type == MVT::f32)
+    OpCode = PTX::READPARAMF32;
+  else {
+    assert(Type == MVT::f64 && "Unexpected type!");
+    OpCode = PTX::READPARAMF64;
+  }
+
+  SDValue Pred = CurDAG->getRegister(PTX::NoRegister, MVT::i1);
+  SDValue PredOp = CurDAG->getTargetConstant(PTXPredicate::None, MVT::i32);
+  DebugLoc dl = Node->getDebugLoc();
+
+  SDValue Ops[] = { Index, Pred, PredOp, Chain };
+  return CurDAG->getMachineNode(OpCode, dl, VT, Ops, 4);
+}
+
+SDNode *PTXDAGToDAGISel::SelectWRITEPARAM(SDNode *Node) {
+
+  SDValue Chain = Node->getOperand(0);
+  SDValue Value = Node->getOperand(1);
+
+  int OpCode;
+
+  //Node->dumpr(CurDAG);
+
+  // Get the type of parameter we are writing
+  EVT VT = Value->getValueType(0);
+  assert(VT.isSimple() && "WRITE_PARAM only implemented for MVT types");
+
+  MVT Type = VT.getSimpleVT();
+
+  if (Type == MVT::i1)
+    OpCode = PTX::WRITEPARAMPRED;
+  else if (Type == MVT::i16)
+    OpCode = PTX::WRITEPARAMI16;
+  else if (Type == MVT::i32)
+    OpCode = PTX::WRITEPARAMI32;
+  else if (Type == MVT::i64)
+    OpCode = PTX::WRITEPARAMI64;
+  else if (Type == MVT::f32)
+    OpCode = PTX::WRITEPARAMF32;
+  else if (Type == MVT::f64)
+    OpCode = PTX::WRITEPARAMF64;
+  else
+    llvm_unreachable("Invalid type in SelectWRITEPARAM");
+
+  SDValue Pred = CurDAG->getRegister(PTX::NoRegister, MVT::i1);
+  SDValue PredOp = CurDAG->getTargetConstant(PTXPredicate::None, MVT::i32);
+  DebugLoc dl = Node->getDebugLoc();
+
+  SDValue Ops[] = { Value, Pred, PredOp, Chain };
+  SDNode* Ret = CurDAG->getMachineNode(OpCode, dl, MVT::Other, Ops, 4);
+
+  //dbgs() << "SelectWRITEPARAM produced:\n\t";
+  //Ret->dumpr(CurDAG);
+
+  return Ret;
+}
+
+SDNode *PTXDAGToDAGISel::SelectFrameIndex(SDNode *Node) {
+  int FI = cast<FrameIndexSDNode>(Node)->getIndex();
+  //dbgs() << "Selecting FrameIndex at index " << FI << "\n";
+  //SDValue TFI = CurDAG->getTargetFrameIndex(FI, Node->getValueType(0));
+
+  PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
+
+  SDValue FrameSymbol = CurDAG->getTargetExternalSymbol(MFI->getFrameSymbol(FI),
+                                                        Node->getValueType(0));
+
+  return FrameSymbol.getNode();
+}
+
 // Match memory operand of the form [reg+reg]
 bool PTXDAGToDAGISel::SelectADDRrr(SDValue &Addr, SDValue &R1, SDValue &R2) {
   if (Addr.getOpcode() != ISD::ADD || Addr.getNumOperands() < 2 ||
@@ -141,14 +211,54 @@ bool PTXDAGToDAGISel::SelectADDRrr(SDValue &Addr, SDValue &R1, SDValue &R2) {
 // Match memory operand of the form [reg], [imm+reg], and [reg+imm]
 bool PTXDAGToDAGISel::SelectADDRri(SDValue &Addr, SDValue &Base,
                                    SDValue &Offset) {
-  if (Addr.getOpcode() != ISD::ADD) {
+  // FrameIndex addresses are handled separately
+  //errs() << "SelectADDRri: ";
+  //Addr.getNode()->dumpr();
+  if (isa<FrameIndexSDNode>(Addr)) {
+    //errs() << "Failure\n";
+    return false;
+  }
+
+  if (CurDAG->isBaseWithConstantOffset(Addr)) {
+    Base = Addr.getOperand(0);
+    if (isa<FrameIndexSDNode>(Base)) {
+      //errs() << "Failure\n";
+      return false;
+    }
+    ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1));
+    Offset = CurDAG->getTargetConstant(CN->getZExtValue(), MVT::i32);
+    //errs() << "Success\n";
+    return true;
+  }
+
+  /*if (Addr.getNumOperands() == 1) {
+    Base = Addr;
+    Offset = CurDAG->getTargetConstant(0, Addr.getValueType().getSimpleVT());
+    errs() << "Success\n";
+    return true;
+  }*/
+
+  //errs() << "SelectADDRri fails on: ";
+  //Addr.getNode()->dumpr();
+
+  if (isImm(Addr)) {
+    //errs() << "Failure\n";
+    return false;
+  }
+
+  Base = Addr;
+  Offset = CurDAG->getTargetConstant(0, Addr.getValueType().getSimpleVT());
+
+  //errs() << "Success\n";
+  return true;
+
+  /*if (Addr.getOpcode() != ISD::ADD) {
     // let SelectADDRii handle the [imm] case
     if (isImm(Addr))
       return false;
     // it is [reg]
 
     assert(Addr.getValueType().isSimple() && "Type must be simple");
-
     Base = Addr;
     Offset = CurDAG->getTargetConstant(0, Addr.getValueType().getSimpleVT());
 
@@ -170,7 +280,7 @@ bool PTXDAGToDAGISel::SelectADDRri(SDValue &Addr, SDValue &Base,
     }
 
   // neither [reg+imm] nor [imm+reg]
-  return false;
+  return false;*/
 }
 
 // Match memory operand of the form [imm+imm] and [imm]
@@ -194,6 +304,36 @@ bool PTXDAGToDAGISel::SelectADDRii(SDValue &Addr, SDValue &Base,
   return false;
 }
 
+// Match memory operand of the form [reg], [imm+reg], and [reg+imm]
+bool PTXDAGToDAGISel::SelectADDRlocal(SDValue &Addr, SDValue &Base,
+                                      SDValue &Offset) {
+  //errs() << "SelectADDRlocal: ";
+  //Addr.getNode()->dumpr();
+  if (isa<FrameIndexSDNode>(Addr)) {
+    Base = Addr;
+    Offset = CurDAG->getTargetConstant(0, Addr.getValueType().getSimpleVT());
+    //errs() << "Success\n";
+    return true;
+  }
+
+  if (CurDAG->isBaseWithConstantOffset(Addr)) {
+    Base = Addr.getOperand(0);
+    if (!isa<FrameIndexSDNode>(Base)) {
+      //errs() << "Failure\n";
+      return false;
+    }
+    ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1));
+    Offset = CurDAG->getTargetConstant(CN->getZExtValue(), MVT::i32);
+    //errs() << "Offset: ";
+    //Offset.getNode()->dumpr();
+    //errs() << "Success\n";
+    return true;
+  }
+
+  //errs() << "Failure\n";
+  return false;
+}
+
 bool PTXDAGToDAGISel::isImm(const SDValue &operand) {
   return ConstantSDNode::classof(operand.getNode());
 }