From: Sanjay Patel Date: Sun, 21 Sep 2014 15:19:15 +0000 (+0000) Subject: Refactor reciprocal square root estimate into target-independent function; NFC. X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=3e05b40fd083c4c5b1428ebd975616c0f0221388;p=oota-llvm.git Refactor reciprocal square root estimate into target-independent function; NFC. This is purely a plumbing patch. No functional changes intended. The ultimate goal is to allow targets other than PowerPC (certainly X86 and Aarch64) to turn this: z = y / sqrt(x) into: z = y * rsqrte(x) using whatever HW magic they can use. See http://llvm.org/bugs/show_bug.cgi?id=20900 . The first step is to add a target hook for RSQRTE, take the already target-independent code selfishly hoarded by PPC, and put it into DAGCombiner. Next steps: The code in DAGCombiner::BuildRSQRTE() should be refactored further; tests that exercise that logic need to be added. Logic in PPCTargetLowering::BuildRSQRTE() should be hoisted into DAGCombiner. X86 and AArch64 overrides for TargetLowering.BuildRSQRTE() should be added. Differential Revision: http://reviews.llvm.org/D5425 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@218219 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/Target/TargetLowering.h b/include/llvm/Target/TargetLowering.h index e6c4634079c..0a7222599aa 100644 --- a/include/llvm/Target/TargetLowering.h +++ b/include/llvm/Target/TargetLowering.h @@ -2602,6 +2602,10 @@ public: return SDValue(); } + virtual SDValue BuildRSQRTE(SDValue Op, DAGCombinerInfo &DCI) const { + return SDValue(); + } + //===--------------------------------------------------------------------===// // Legalization utility functions // diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index aa2f2d1f2b1..30ac63570ff 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -326,6 +326,7 @@ namespace { SDValue BuildSDIV(SDNode *N); SDValue BuildSDIVPow2(SDNode *N); SDValue BuildUDIV(SDNode *N); + SDValue BuildRSQRTE(SDNode *N); SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, bool DemandHighBits = true); SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1); @@ -6987,23 +6988,29 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { if (N0CFP && N1CFP) return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1); - // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable. - if (N1CFP && Options.UnsafeFPMath) { - // Compute the reciprocal 1.0 / c2. - APFloat N1APF = N1CFP->getValueAPF(); - APFloat Recip(N1APF.getSemantics(), 1); // 1.0 - APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven); - // Only do the transform if the reciprocal is a legal fp immediate that - // isn't too nasty (eg NaN, denormal, ...). - if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty - (!LegalOperations || - // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM - // backend)... we should handle this gracefully after Legalize. - // TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT) || - TLI.isOperationLegal(llvm::ISD::ConstantFP, VT) || - TLI.isFPImmLegal(Recip, VT))) - return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0, - DAG.getConstantFP(Recip, VT)); + if (Options.UnsafeFPMath) { + // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable. + if (N1CFP) { + // Compute the reciprocal 1.0 / c2. + APFloat N1APF = N1CFP->getValueAPF(); + APFloat Recip(N1APF.getSemantics(), 1); // 1.0 + APFloat::opStatus st = Recip.divide(N1APF, APFloat::rmNearestTiesToEven); + // Only do the transform if the reciprocal is a legal fp immediate that + // isn't too nasty (eg NaN, denormal, ...). + if ((st == APFloat::opOK || st == APFloat::opInexact) && // Not too nasty + (!LegalOperations || + // FIXME: custom lowering of ConstantFP might fail (see e.g. ARM + // backend)... we should handle this gracefully after Legalize. + // TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT) || + TLI.isOperationLegal(llvm::ISD::ConstantFP, VT) || + TLI.isFPImmLegal(Recip, VT))) + return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0, + DAG.getConstantFP(Recip, VT)); + } + // If this FDIV is part of a reciprocal square root, it may be folded + // into a target-specific square root estimate instruction. + if (SDValue SqrtOp = BuildRSQRTE(N)) + return SqrtOp; } // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y) @@ -11695,6 +11702,44 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) { return S; } +/// Given an ISD::FDIV node with either a direct or indirect ISD::FSQRT operand, +/// generate a DAG expression using a reciprocal square root estimate op. +SDValue DAGCombiner::BuildRSQRTE(SDNode *N) { + // Expose the DAG combiner to the target combiner implementations. + TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this); + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue N1 = N->getOperand(1); + + if (N1.getOpcode() == ISD::FSQRT) { + SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0), DCI); + if (RV.getNode()) { + DCI.AddToWorklist(RV.getNode()); + return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV); + } + } else if (N1.getOpcode() == ISD::FP_EXTEND && + N1.getOperand(0).getOpcode() == ISD::FSQRT) { + SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0).getOperand(0), DCI); + if (RV.getNode()) { + DCI.AddToWorklist(RV.getNode()); + RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV); + DCI.AddToWorklist(RV.getNode()); + return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV); + } + } else if (N1.getOpcode() == ISD::FP_ROUND && + N1.getOperand(0).getOpcode() == ISD::FSQRT) { + SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0).getOperand(0), DCI); + if (RV.getNode()) { + DCI.AddToWorklist(RV.getNode()); + RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1)); + DCI.AddToWorklist(RV.getNode()); + return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV); + } + } + + return SDValue(); +} + /// Return true if base is a frame index, which is known not to alias with /// anything but itself. Provides base object and offset as results. static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset, diff --git a/lib/Target/PowerPC/PPCISelLowering.cpp b/lib/Target/PowerPC/PPCISelLowering.cpp index fd188fe37e2..d96cdab604f 100644 --- a/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/lib/Target/PowerPC/PPCISelLowering.cpp @@ -7489,8 +7489,7 @@ SDValue PPCTargetLowering::DAGCombineFastRecip(SDValue Op, return SDValue(); } -SDValue PPCTargetLowering::DAGCombineFastRecipFSQRT(SDValue Op, - DAGCombinerInfo &DCI) const { +SDValue PPCTargetLowering::BuildRSQRTE(SDValue Op, DAGCombinerInfo &DCI) const { if (DCI.isAfterLegalizeVectorOps()) return SDValue(); @@ -8289,43 +8288,6 @@ SDValue PPCTargetLowering::PerformDAGCombine(SDNode *N, assert(TM.Options.UnsafeFPMath && "Reciprocal estimates require UnsafeFPMath"); - if (N->getOperand(1).getOpcode() == ISD::FSQRT) { - SDValue RV = - DAGCombineFastRecipFSQRT(N->getOperand(1).getOperand(0), DCI); - if (RV.getNode()) { - DCI.AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, dl, N->getValueType(0), - N->getOperand(0), RV); - } - } else if (N->getOperand(1).getOpcode() == ISD::FP_EXTEND && - N->getOperand(1).getOperand(0).getOpcode() == ISD::FSQRT) { - SDValue RV = - DAGCombineFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0), - DCI); - if (RV.getNode()) { - DCI.AddToWorklist(RV.getNode()); - RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N->getOperand(1)), - N->getValueType(0), RV); - DCI.AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, dl, N->getValueType(0), - N->getOperand(0), RV); - } - } else if (N->getOperand(1).getOpcode() == ISD::FP_ROUND && - N->getOperand(1).getOperand(0).getOpcode() == ISD::FSQRT) { - SDValue RV = - DAGCombineFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0), - DCI); - if (RV.getNode()) { - DCI.AddToWorklist(RV.getNode()); - RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N->getOperand(1)), - N->getValueType(0), RV, - N->getOperand(1).getOperand(1)); - DCI.AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, dl, N->getValueType(0), - N->getOperand(0), RV); - } - } - SDValue RV = DAGCombineFastRecip(N->getOperand(1), DCI); if (RV.getNode()) { DCI.AddToWorklist(RV.getNode()); @@ -8341,7 +8303,7 @@ SDValue PPCTargetLowering::PerformDAGCombine(SDNode *N, // Compute this as 1/(1/sqrt(X)), which is the reciprocal of the // reciprocal sqrt. - SDValue RV = DAGCombineFastRecipFSQRT(N->getOperand(0), DCI); + SDValue RV = BuildRSQRTE(N->getOperand(0), DCI); if (RV.getNode()) { DCI.AddToWorklist(RV.getNode()); RV = DAGCombineFastRecip(RV, DCI); diff --git a/lib/Target/PowerPC/PPCISelLowering.h b/lib/Target/PowerPC/PPCISelLowering.h index c53dc83fa8a..5628bc79342 100644 --- a/lib/Target/PowerPC/PPCISelLowering.h +++ b/lib/Target/PowerPC/PPCISelLowering.h @@ -696,7 +696,7 @@ namespace llvm { SDValue DAGCombineExtBoolTrunc(SDNode *N, DAGCombinerInfo &DCI) const; SDValue DAGCombineTruncBoolExt(SDNode *N, DAGCombinerInfo &DCI) const; SDValue DAGCombineFastRecip(SDValue Op, DAGCombinerInfo &DCI) const; - SDValue DAGCombineFastRecipFSQRT(SDValue Op, DAGCombinerInfo &DCI) const; + SDValue BuildRSQRTE(SDValue Op, DAGCombinerInfo &DCI) const; CCAssignFn *useFastISelCCs(unsigned Flag) const; };