PTX: Attempt to cleanup/unify the handling of FP rounding modes. This requires
[oota-llvm.git] / lib / Target / PTX / PTXFPRoundingModePass.cpp
1 //===-- PTXFPRoundingModePass.cpp - Assign rounding modes pass ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines a machine function pass that sets appropriate FP rounding
11 // modes for all relevant instructions.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #define DEBUG_TYPE "ptx-fp-rounding-mode"
16
17 #include "PTX.h"
18 #include "PTXTargetMachine.h"
19 #include "llvm/CodeGen/MachineFunctionPass.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/ErrorHandling.h"
23 #include "llvm/Support/raw_ostream.h"
24
25 // NOTE: PTXFPRoundingModePass should be executed just before emission.
26
27 namespace llvm {
28   /// PTXFPRoundingModePass - Pass to assign appropriate FP rounding modes to
29   /// all FP instructions. Essentially, this pass just looks for all FP
30   /// instructions that have a rounding mode set to RndDefault, and sets an
31   /// appropriate rounding mode based on the target device.
32   ///
33   class PTXFPRoundingModePass : public MachineFunctionPass {
34     private:
35       static char ID;
36       PTXTargetMachine& TargetMachine;
37
38     public:
39       PTXFPRoundingModePass(PTXTargetMachine &TM, CodeGenOpt::Level OptLevel)
40         : MachineFunctionPass(ID),
41           TargetMachine(TM) {}
42
43       virtual bool runOnMachineFunction(MachineFunction &MF);
44
45       virtual const char *getPassName() const {
46         return "PTX FP Rounding Mode Pass";
47       }
48
49     private:
50
51       void processInstruction(MachineInstr &MI);
52   }; // class PTXFPRoundingModePass
53 } // namespace llvm
54
55 using namespace llvm;
56
57 char PTXFPRoundingModePass::ID = 0;
58
59 bool PTXFPRoundingModePass::runOnMachineFunction(MachineFunction &MF) {
60
61   // Look at each basic block
62   for (MachineFunction::iterator bbi = MF.begin(), bbe = MF.end(); bbi != bbe;
63        ++bbi) {
64     MachineBasicBlock &MBB = *bbi;
65     // Look at each instruction
66     for (MachineBasicBlock::iterator ii = MBB.begin(), ie = MBB.end();
67          ii != ie; ++ii) {
68       MachineInstr &MI = *ii;
69       processInstruction(MI);
70     }
71   }
72   return false;
73 }
74
75 void PTXFPRoundingModePass::processInstruction(MachineInstr &MI) {
76   // If the instruction has a rounding mode set to RndDefault, then assign an
77   // appropriate rounding mode based on the target device.
78   const PTXSubtarget& ST = TargetMachine.getSubtarget<PTXSubtarget>();
79   switch (MI.getOpcode()) {
80   case PTX::FADDrr32:
81   case PTX::FADDri32:
82   case PTX::FADDrr64:
83   case PTX::FADDri64:
84   case PTX::FSUBrr32:
85   case PTX::FSUBri32:
86   case PTX::FSUBrr64:
87   case PTX::FSUBri64:
88   case PTX::FMULrr32:
89   case PTX::FMULri32:
90   case PTX::FMULrr64:
91   case PTX::FMULri64:
92     if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
93       MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven);
94     }
95     break;
96   case PTX::FNEGrr32:
97   case PTX::FNEGri32:
98   case PTX::FNEGrr64:
99   case PTX::FNEGri64:
100     if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
101       MI.getOperand(1).setImm(PTXRoundingMode::RndNone);
102     }
103     break;
104   case PTX::FDIVrr32:
105   case PTX::FDIVri32:
106   case PTX::FDIVrr64:
107   case PTX::FDIVri64:
108     if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
109       if (ST.fdivNeedsRoundingMode())
110         MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven);
111       else
112         MI.getOperand(1).setImm(PTXRoundingMode::RndNone);
113     }
114     break;
115   case PTX::FMADrrr32:
116   case PTX::FMADrri32:
117   case PTX::FMADrii32:
118   case PTX::FMADrrr64:
119   case PTX::FMADrri64:
120   case PTX::FMADrii64:
121     if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
122       if (ST.fmadNeedsRoundingMode())
123         MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven);
124       else
125         MI.getOperand(1).setImm(PTXRoundingMode::RndNone);
126     }
127     break;
128   case PTX::FSQRTrr32:
129   case PTX::FSQRTri32:
130   case PTX::FSQRTrr64:
131   case PTX::FSQRTri64:
132     if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
133       MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven);
134     }
135     break;
136   case PTX::FSINrr32:
137   case PTX::FSINri32:
138   case PTX::FSINrr64:
139   case PTX::FSINri64:
140   case PTX::FCOSrr32:
141   case PTX::FCOSri32:
142   case PTX::FCOSrr64:
143   case PTX::FCOSri64:
144     if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) {
145       MI.getOperand(1).setImm(PTXRoundingMode::RndApprox);
146     }
147     break;
148   }
149 }
150
151 FunctionPass *llvm::createPTXFPRoundingModePass(PTXTargetMachine &TM,
152                                                 CodeGenOpt::Level OptLevel) {
153   return new PTXFPRoundingModePass(TM, OptLevel);
154 }
155