Fix a regression in InstCombine/xor.ll
[oota-llvm.git] / lib / CodeGen / VirtRegMap.cpp
index bb29ffd9548a26b1541ff5f687c794c638ce56da..572bace5944619151f0cf72a6261cb6c10d22a3d 100644 (file)
 #include "VirtRegMap.h"
 #include "llvm/Function.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetInstrInfo.h"
-#include "Support/Statistic.h"
+#include "Support/CommandLine.h"
 #include "Support/Debug.h"
+#include "Support/DenseMap.h"
+#include "Support/Statistic.h"
 #include "Support/STLExtras.h"
-#include <iostream>
 
 using namespace llvm;
 
@@ -33,76 +35,182 @@ namespace {
     Statistic<> numStores("spiller", "Number of stores added");
     Statistic<> numLoads ("spiller", "Number of loads added");
 
+    enum SpillerName { simple, local };
+
+    cl::opt<SpillerName>
+    SpillerOpt("spiller",
+               cl::desc("Spiller to use: (default: local)"),
+               cl::Prefix,
+               cl::values(clEnumVal(simple, "  simple spiller"),
+                          clEnumVal(local,  "  local spiller"),
+                          clEnumValEnd),
+               cl::init(local));
 }
 
 int VirtRegMap::assignVirt2StackSlot(unsigned virtReg)
 {
     assert(MRegisterInfo::isVirtualRegister(virtReg));
-    assert(v2ssMap_[toIndex(virtReg)] == NO_STACK_SLOT &&
+    assert(v2ssMap_[virtReg] == NO_STACK_SLOT &&
            "attempt to assign stack slot to already spilled register");
     const TargetRegisterClass* rc =
         mf_->getSSARegMap()->getRegClass(virtReg);
     int frameIndex = mf_->getFrameInfo()->CreateStackObject(rc);
-    v2ssMap_[toIndex(virtReg)] = frameIndex;
+    v2ssMap_[virtReg] = frameIndex;
     ++numSpills;
     return frameIndex;
 }
 
+void VirtRegMap::assignVirt2StackSlot(unsigned virtReg, int frameIndex)
+{
+    assert(MRegisterInfo::isVirtualRegister(virtReg));
+    assert(v2ssMap_[virtReg] == NO_STACK_SLOT &&
+           "attempt to assign stack slot to already spilled register");
+     v2ssMap_[virtReg] = frameIndex;
+}
+
+void VirtRegMap::virtFolded(unsigned virtReg,
+                            MachineInstr* oldMI,
+                            MachineInstr* newMI)
+{
+    // move previous memory references folded to new instruction
+    MI2VirtMap::iterator i, e;
+    std::vector<MI2VirtMap::mapped_type> regs;
+    for (tie(i, e) = mi2vMap_.equal_range(oldMI); i != e; ) {
+        regs.push_back(i->second);
+        mi2vMap_.erase(i++);
+    }
+    for (unsigned i = 0, e = regs.size(); i != e; ++i)
+        mi2vMap_.insert(std::make_pair(newMI, i));
+
+    // add new memory reference
+    mi2vMap_.insert(std::make_pair(newMI, virtReg));
+}
+
 std::ostream& llvm::operator<<(std::ostream& os, const VirtRegMap& vrm)
 {
     const MRegisterInfo* mri = vrm.mf_->getTarget().getRegisterInfo();
 
     std::cerr << "********** REGISTER MAP **********\n";
-    for (unsigned i = 0, e = vrm.v2pMap_.size(); i != e; ++i) {
+    for (unsigned i = MRegisterInfo::FirstVirtualRegister,
+             e = vrm.mf_->getSSARegMap()->getLastVirtReg(); i <= e; ++i) {
         if (vrm.v2pMap_[i] != VirtRegMap::NO_PHYS_REG)
-            std::cerr << "[reg" << VirtRegMap::fromIndex(i) << " -> "
+            std::cerr << "[reg" << i << " -> "
                       << mri->getName(vrm.v2pMap_[i]) << "]\n";
     }
-    for (unsigned i = 0, e = vrm.v2ssMap_.size(); i != e; ++i) {
+    for (unsigned i = MRegisterInfo::FirstVirtualRegister,
+             e = vrm.mf_->getSSARegMap()->getLastVirtReg(); i <= e; ++i) {
         if (vrm.v2ssMap_[i] != VirtRegMap::NO_STACK_SLOT)
-            std::cerr << "[reg" << VirtRegMap::fromIndex(i) << " -> fi#"
+            std::cerr << "[reg" << i << " -> fi#"
                       << vrm.v2ssMap_[i] << "]\n";
     }
     return std::cerr << '\n';
 }
 
+Spiller::~Spiller()
+{
+
+}
+
 namespace {
 
-    class Spiller {
+    class SimpleSpiller : public Spiller {
+    public:
+        bool runOnMachineFunction(MachineFunction& mf, const VirtRegMap& vrm) {
+            DEBUG(std::cerr << "********** REWRITE MACHINE CODE **********\n");
+            DEBUG(std::cerr << "********** Function: "
+              << mf.getFunction()->getName() << '\n');
+            const TargetMachine& tm = mf.getTarget();
+            const MRegisterInfo& mri = *tm.getRegisterInfo();
+
+            typedef DenseMap<bool, VirtReg2IndexFunctor> Loaded;
+            Loaded loaded;
+
+            for (MachineFunction::iterator mbbi = mf.begin(),
+                     mbbe = mf.end(); mbbi != mbbe; ++mbbi) {
+                DEBUG(std::cerr << mbbi->getBasicBlock()->getName() << ":\n");
+                for (MachineBasicBlock::iterator mii = mbbi->begin(),
+                         mie = mbbi->end(); mii != mie; ++mii) {
+                    loaded.grow(mf.getSSARegMap()->getLastVirtReg());
+                    for (unsigned i = 0,e = mii->getNumOperands(); i != e; ++i){
+                        MachineOperand& mop = mii->getOperand(i);
+                        if (mop.isRegister() && mop.getReg() &&
+                            MRegisterInfo::isVirtualRegister(mop.getReg())) {
+                            unsigned virtReg = mop.getReg();
+                            unsigned physReg = vrm.getPhys(virtReg);
+                            if (mop.isUse() &&
+                                vrm.hasStackSlot(mop.getReg()) &&
+                                !loaded[virtReg]) {
+                                mri.loadRegFromStackSlot(
+                                    *mbbi,
+                                    mii,
+                                    physReg,
+                                    vrm.getStackSlot(virtReg),
+                                    mf.getSSARegMap()->getRegClass(virtReg));
+                                loaded[virtReg] = true;
+                                DEBUG(std::cerr << '\t';
+                                      prior(mii)->print(std::cerr, &tm));
+                                ++numLoads;
+                            }
+                            if (mop.isDef() &&
+                                vrm.hasStackSlot(mop.getReg())) {
+                                mri.storeRegToStackSlot(
+                                    *mbbi,
+                                    next(mii),
+                                    physReg,
+                                    vrm.getStackSlot(virtReg),
+                                    mf.getSSARegMap()->getRegClass(virtReg));
+                                ++numStores;
+                            }
+                            mii->SetMachineOperandReg(i, physReg);
+                        }
+                    }
+                    DEBUG(std::cerr << '\t'; mii->print(std::cerr, &tm));
+                    loaded.clear();
+                }
+            }
+            return true;
+        }
+    };
+
+    class LocalSpiller : public Spiller {
         typedef std::vector<unsigned> Phys2VirtMap;
         typedef std::vector<bool> PhysFlag;
+        typedef DenseMap<MachineInstr*, VirtReg2IndexFunctor> Virt2MI;
 
-        MachineFunction& mf_;
-        const TargetMachine& tm_;
-        const TargetInstrInfo& tii_;
-        const MRegisterInfo& mri_;
-        const VirtRegMap& vrm_;
+        MachineFunction* mf_;
+        const TargetMachine* tm_;
+        const TargetInstrInfo* tii_;
+        const MRegisterInfo* mri_;
+        const VirtRegMap* vrm_;
         Phys2VirtMap p2vMap_;
         PhysFlag dirty_;
+        Virt2MI lastDef_;
 
     public:
-        Spiller(MachineFunction& mf, const VirtRegMap& vrm)
-            : mf_(mf),
-              tm_(mf_.getTarget()),
-              tii_(tm_.getInstrInfo()),
-              mri_(*tm_.getRegisterInfo()),
-              vrm_(vrm),
-              p2vMap_(mri_.getNumRegs()),
-              dirty_(mri_.getNumRegs()) {
+        bool runOnMachineFunction(MachineFunction& mf, const VirtRegMap& vrm) {
+            mf_ = &mf;
+            tm_ = &mf_->getTarget();
+            tii_ = tm_->getInstrInfo();
+            mri_ = tm_->getRegisterInfo();
+            vrm_ = &vrm;
+            p2vMap_.assign(mri_->getNumRegs(), 0);
+            dirty_.assign(mri_->getNumRegs(), false);
+
             DEBUG(std::cerr << "********** REWRITE MACHINE CODE **********\n");
             DEBUG(std::cerr << "********** Function: "
-                  << mf_.getFunction()->getName() << '\n');
-        }
+                  << mf_->getFunction()->getName() << '\n');
 
-        void eliminateVirtRegs() {
-            for (MachineFunction::iterator mbbi = mf_.begin(),
-                     mbbe = mf_.end(); mbbi != mbbe; ++mbbi) {
-                // clear map and dirty flag
-                p2vMap_.assign(p2vMap_.size(), 0);
-                dirty_.assign(dirty_.size(), false);
+            for (MachineFunction::iterator mbbi = mf_->begin(),
+                     mbbe = mf_->end(); mbbi != mbbe; ++mbbi) {
+                lastDef_.grow(mf_->getSSARegMap()->getLastVirtReg());
                 DEBUG(std::cerr << mbbi->getBasicBlock()->getName() << ":\n");
                 eliminateVirtRegsInMbb(*mbbi);
+                // clear map, dirty flag and last ref
+                p2vMap_.assign(p2vMap_.size(), 0);
+                dirty_.assign(dirty_.size(), false);
+                lastDef_.clear();
             }
+            return true;
         }
 
     private:
@@ -110,12 +218,22 @@ namespace {
                                MachineBasicBlock::iterator mii,
                                unsigned physReg) {
             unsigned virtReg = p2vMap_[physReg];
-            if (dirty_[physReg] && vrm_.hasStackSlot(virtReg)) {
-                mri_.storeRegToStackSlot(mbb, mii, physReg,
-                                         vrm_.getStackSlot(virtReg),
-                                         mri_.getRegClass(physReg));
+            if (dirty_[physReg] && vrm_->hasStackSlot(virtReg)) {
+                assert(lastDef_[virtReg] && "virtual register is mapped "
+                       "to a register and but was not defined!");
+                MachineBasicBlock::iterator lastDef = lastDef_[virtReg];
+                MachineBasicBlock::iterator nextLastRef = next(lastDef);
+                mri_->storeRegToStackSlot(*lastDef->getParent(),
+                                          nextLastRef,
+                                          physReg,
+                                          vrm_->getStackSlot(virtReg),
+                                          mri_->getRegClass(physReg));
                 ++numStores;
-                DEBUG(std::cerr << "*\t"; prior(mii)->print(std::cerr, tm_));
+                DEBUG(std::cerr << "added: ";
+                      prior(nextLastRef)->print(std::cerr, tm_);
+                      std::cerr << "after: ";
+                      lastDef->print(std::cerr, tm_));
+                lastDef_[virtReg] = 0;
             }
             p2vMap_[physReg] = 0;
             dirty_[physReg] = false;
@@ -125,7 +243,7 @@ namespace {
                            MachineBasicBlock::iterator mii,
                            unsigned physReg) {
             vacateJustPhysReg(mbb, mii, physReg);
-            for (const unsigned* as = mri_.getAliasSet(physReg); *as; ++as)
+            for (const unsigned* as = mri_->getAliasSet(physReg); *as; ++as)
                 vacateJustPhysReg(mbb, mii, *as);
         }
 
@@ -138,12 +256,14 @@ namespace {
                 vacatePhysReg(mbb, mii, physReg);
                 p2vMap_[physReg] = virtReg;
                 // load if necessary
-                if (vrm_.hasStackSlot(virtReg)) {
-                    mri_.loadRegFromStackSlot(mbb, mii, physReg,
-                                              vrm_.getStackSlot(virtReg),
-                                              mri_.getRegClass(physReg));
+                if (vrm_->hasStackSlot(virtReg)) {
+                    mri_->loadRegFromStackSlot(mbb, mii, physReg,
+                                               vrm_->getStackSlot(virtReg),
+                                               mri_->getRegClass(physReg));
                     ++numLoads;
-                    DEBUG(std::cerr << "*\t"; prior(mii)->print(std::cerr,tm_));
+                    DEBUG(std::cerr << "added: ";
+                          prior(mii)->print(std::cerr, tm_));
+                    lastDef_[virtReg] = mii;
                 }
             }
         }
@@ -158,38 +278,61 @@ namespace {
 
             p2vMap_[physReg] = virtReg;
             dirty_[physReg] = true;
+            lastDef_[virtReg] = mii;
         }
 
         void eliminateVirtRegsInMbb(MachineBasicBlock& mbb) {
             for (MachineBasicBlock::iterator mii = mbb.begin(),
                      mie = mbb.end(); mii != mie; ++mii) {
+
+                // if we have references to memory operands make sure
+                // we clear all physical registers that may contain
+                // the value of the spilled virtual register
+                VirtRegMap::MI2VirtMap::const_iterator i, e;
+                for (tie(i, e) = vrm_->getFoldedVirts(mii); i != e; ++i) {
+                    if (vrm_->hasPhys(i->second))
+                        vacateJustPhysReg(mbb, mii, vrm_->getPhys(i->second));
+                }
+
                 // rewrite all used operands
                 for (unsigned i = 0, e = mii->getNumOperands(); i != e; ++i) {
                     MachineOperand& op = mii->getOperand(i);
-                    if (op.isRegister() && op.isUse() &&
+                    if (op.isRegister() && op.getReg() && op.isUse() &&
                         MRegisterInfo::isVirtualRegister(op.getReg())) {
-                        unsigned physReg = vrm_.getPhys(op.getReg());
-                        handleUse(mbb, mii, op.getReg(), physReg);
+                        unsigned virtReg = op.getReg();
+                        unsigned physReg = vrm_->getPhys(virtReg);
+                        handleUse(mbb, mii, virtReg, physReg);
                         mii->SetMachineOperandReg(i, physReg);
                         // mark as dirty if this is def&use
-                        if (op.isDef()) dirty_[physReg] = true;
+                        if (op.isDef()) {
+                            dirty_[physReg] = true;
+                            lastDef_[virtReg] = mii;
+                        }
                     }
                 }
 
-                // spill implicit defs
-                const TargetInstrDescriptor& tid =tii_.get(mii->getOpcode());
+                // spill implicit physical register defs
+                const TargetInstrDescriptor& tid = tii_->get(mii->getOpcode());
                 for (const unsigned* id = tid.ImplicitDefs; *id; ++id)
                     vacatePhysReg(mbb, mii, *id);
 
+                // spill explicit physical register defs
+                for (unsigned i = 0, e = mii->getNumOperands(); i != e; ++i) {
+                    MachineOperand& op = mii->getOperand(i);
+                    if (op.isRegister() && op.getReg() && !op.isUse() &&
+                        MRegisterInfo::isPhysicalRegister(op.getReg()))
+                        vacatePhysReg(mbb, mii, op.getReg());
+                }
+
                 // rewrite def operands (def&use was handled with the
                 // uses so don't check for those here)
                 for (unsigned i = 0, e = mii->getNumOperands(); i != e; ++i) {
                     MachineOperand& op = mii->getOperand(i);
-                    if (op.isRegister() && !op.isUse())
+                    if (op.isRegister() && op.getReg() && !op.isUse())
                         if (MRegisterInfo::isPhysicalRegister(op.getReg()))
                             vacatePhysReg(mbb, mii, op.getReg());
                         else {
-                            unsigned physReg = vrm_.getPhys(op.getReg());
+                            unsigned physReg = vrm_->getPhys(op.getReg());
                             handleDef(mbb, mii, op.getReg(), physReg);
                             mii->SetMachineOperandReg(i, physReg);
                         }
@@ -205,8 +348,15 @@ namespace {
     };
 }
 
-
-void llvm::eliminateVirtRegs(MachineFunction& mf, const VirtRegMap& vrm)
+llvm::Spiller* llvm::createSpiller()
 {
-    Spiller(mf, vrm).eliminateVirtRegs();
+    switch (SpillerOpt) {
+    default:
+        std::cerr << "no spiller selected";
+        abort();
+    case local:
+        return new LocalSpiller();
+    case simple:
+        return new SimpleSpiller();
+    }
 }