Add generic fmad DAG node.
authorMatt Arsenault <Matthew.Arsenault@amd.com>
Fri, 20 Feb 2015 22:10:33 +0000 (22:10 +0000)
committerMatt Arsenault <Matthew.Arsenault@amd.com>
Fri, 20 Feb 2015 22:10:33 +0000 (22:10 +0000)
This allows sharing of FMA forming combines to work
with instructions that have the same semantics as a separate
multiply and add.

This is expand by default, and only formed post legalization
so it shouldn't have much impact on targets that do not want it.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@230070 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/ISDOpcodes.h
include/llvm/Target/TargetSelectionDAG.td
lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
lib/CodeGen/TargetLoweringBase.cpp

index 2d5d6c3393e4342fe3a9f5664d0a3bbececcb395..2d1c8cd6fdd9204551f8b3b135f2ec1fb8b78ca0 100644 (file)
@@ -229,7 +229,14 @@ namespace ISD {
     SMULO, UMULO,
 
     /// Simple binary floating point operators.
-    FADD, FSUB, FMUL, FMA, FDIV, FREM,
+    FADD, FSUB, FMUL, FDIV, FREM,
+
+    /// FMA - Perform a * b + c with no intermediate rounding step.
+    FMA,
+
+    /// FMAD - Perform a * b + c, while getting the same result as the
+    /// separately rounded operations.
+    FMAD,
 
     /// FCOPYSIGN(X, Y) - Return the value of X with the sign of Y.  NOTE: This
     /// DAG node does not require that X and Y have the same type, just that the
index 744a7b0c97a9788dc7d294fe0131995c5900a39f..2ecd900d34bb7940fb184c1119210fb5b8260c8d 100644 (file)
@@ -381,6 +381,7 @@ def fmul       : SDNode<"ISD::FMUL"       , SDTFPBinOp, [SDNPCommutative]>;
 def fdiv       : SDNode<"ISD::FDIV"       , SDTFPBinOp>;
 def frem       : SDNode<"ISD::FREM"       , SDTFPBinOp>;
 def fma        : SDNode<"ISD::FMA"        , SDTFPTernaryOp>;
+def fmad       : SDNode<"ISD::FMAD"       , SDTFPTernaryOp>;
 def fabs       : SDNode<"ISD::FABS"       , SDTFPUnaryOp>;
 def fminnum    : SDNode<"ISD::FMINNUM"    , SDTFPBinOp>;
 def fmaxnum    : SDNode<"ISD::FMAXNUM"    , SDTFPBinOp>;
index e8d1acf60ab1cbfef1e0ba1f49d475a95dbbdec8..53867b57bc394f624d37a59a0fb57e6764963e7d 100644 (file)
@@ -6938,6 +6938,133 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
   return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(BV), VT, Ops);
 }
 
+// Attempt different variants of (fadd (fmul a, b), c) -> fma or fmad
+static SDValue performFaddFmulCombines(unsigned FusedOpcode,
+                                       bool Aggressive,
+                                       SDNode *N,
+                                       const TargetLowering &TLI,
+                                       SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+
+  // fold (fadd (fmul x, y), z) -> (fma x, y, z)
+  if (N0.getOpcode() == ISD::FMUL &&
+      (Aggressive || N0->hasOneUse())) {
+    return DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                       N0.getOperand(0), N0.getOperand(1), N1);
+  }
+
+  // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
+  // Note: Commutes FADD operands.
+  if (N1.getOpcode() == ISD::FMUL &&
+      (Aggressive || N1->hasOneUse())) {
+    return DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                       N1.getOperand(0), N1.getOperand(1), N0);
+  }
+
+  // More folding opportunities when target permits.
+  if (Aggressive) {
+    // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z))
+    if (N0.getOpcode() == ISD::FMA &&
+        N0.getOperand(2).getOpcode() == ISD::FMUL) {
+      return DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                         N0.getOperand(0), N0.getOperand(1),
+                         DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                                     N0.getOperand(2).getOperand(0),
+                                     N0.getOperand(2).getOperand(1),
+                                     N1));
+    }
+
+    // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x))
+    if (N1->getOpcode() == ISD::FMA &&
+        N1.getOperand(2).getOpcode() == ISD::FMUL) {
+      return DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                         N1.getOperand(0), N1.getOperand(1),
+                         DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                                     N1.getOperand(2).getOperand(0),
+                                     N1.getOperand(2).getOperand(1),
+                                     N0));
+    }
+  }
+
+  return SDValue();
+}
+
+static SDValue performFsubFmulCombines(unsigned FusedOpcode,
+                                       bool Aggressive,
+                                       SDNode *N,
+                                       const TargetLowering &TLI,
+                                       SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+
+  SDLoc SL(N);
+
+  // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
+  if (N0.getOpcode() == ISD::FMUL &&
+      (Aggressive || N0->hasOneUse())) {
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       N0.getOperand(0), N0.getOperand(1),
+                       DAG.getNode(ISD::FNEG, SL, VT, N1));
+  }
+
+  // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
+  // Note: Commutes FSUB operands.
+  if (N1.getOpcode() == ISD::FMUL &&
+      (Aggressive || N1->hasOneUse()))
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       DAG.getNode(ISD::FNEG, SL, VT,
+                                   N1.getOperand(0)),
+                       N1.getOperand(1), N0);
+
+  // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
+  if (N0.getOpcode() == ISD::FNEG &&
+      N0.getOperand(0).getOpcode() == ISD::FMUL &&
+      (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
+    SDValue N00 = N0.getOperand(0).getOperand(0);
+    SDValue N01 = N0.getOperand(0).getOperand(1);
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
+                       DAG.getNode(ISD::FNEG, SL, VT, N1));
+  }
+
+  // More folding opportunities when target permits.
+  if (Aggressive) {
+    // fold (fsub (fma x, y, (fmul u, v)), z)
+    //   -> (fma x, y (fma u, v, (fneg z)))
+    if (N0.getOpcode() == FusedOpcode &&
+        N0.getOperand(2).getOpcode() == ISD::FMUL) {
+      return DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                         N0.getOperand(0), N0.getOperand(1),
+                         DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                                     N0.getOperand(2).getOperand(0),
+                                     N0.getOperand(2).getOperand(1),
+                                     DAG.getNode(ISD::FNEG, SDLoc(N), VT,
+                                                 N1)));
+    }
+
+    // fold (fsub x, (fma y, z, (fmul u, v)))
+    //   -> (fma (fneg y), z, (fma (fneg u), v, x))
+    if (N1.getOpcode() == FusedOpcode &&
+        N1.getOperand(2).getOpcode() == ISD::FMUL) {
+      SDValue N20 = N1.getOperand(2).getOperand(0);
+      SDValue N21 = N1.getOperand(2).getOperand(1);
+      return DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                         DAG.getNode(ISD::FNEG, SDLoc(N), VT,
+                                     N1.getOperand(0)),
+                         N1.getOperand(1),
+                         DAG.getNode(FusedOpcode, SDLoc(N), VT,
+                                     DAG.getNode(ISD::FNEG, SDLoc(N),  VT,
+                                                 N20),
+                                     N21, N0));
+    }
+  }
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitFADD(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -7077,23 +7204,27 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
     }
   } // enable-unsafe-fp-math
 
+  if (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)) {
+    // Assume if there is an fmad instruction that it should be aggressively
+    // used.
+    if (SDValue Fused = performFaddFmulCombines(ISD::FMAD, true, N, TLI, DAG))
+      return Fused;
+  }
+
   // FADD -> FMA combines:
   if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
       TLI.isFMAFasterThanFMulAndFAdd(VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT))) {
 
-    // fold (fadd (fmul x, y), z) -> (fma x, y, z)
-    if (N0.getOpcode() == ISD::FMUL &&
-        (N0->hasOneUse() || TLI.enableAggressiveFMAFusion(VT)))
-      return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                         N0.getOperand(0), N0.getOperand(1), N1);
-
-    // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
-    // Note: Commutes FADD operands.
-    if (N1.getOpcode() == ISD::FMUL &&
-        (N1->hasOneUse() || TLI.enableAggressiveFMAFusion(VT)))
-      return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                         N1.getOperand(0), N1.getOperand(1), N0);
+    if (!TLI.isOperationLegal(ISD::FMAD, VT)) {
+      // Don't form FMA if we are preferring FMAD.
+      if (SDValue Fused
+          = performFaddFmulCombines(ISD::FMA,
+                                    TLI.enableAggressiveFMAFusion(VT),
+                                    N, TLI, DAG)) {
+        return Fused;
+      }
+    }
 
     // When FP_EXTEND nodes are free on the target, and there is an opportunity
     // to combine into FMA, arrange such nodes accordingly.
@@ -7122,30 +7253,6 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
                                          N10.getOperand(1)), N0);
       }
     }
-
-    // More folding opportunities when target permits.
-    if (TLI.enableAggressiveFMAFusion(VT)) {
-
-      // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z))
-      if (N0.getOpcode() == ISD::FMA &&
-          N0.getOperand(2).getOpcode() == ISD::FMUL)
-        return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                           N0.getOperand(0), N0.getOperand(1),
-                           DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                                       N0.getOperand(2).getOperand(0),
-                                       N0.getOperand(2).getOperand(1),
-                                       N1));
-
-      // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x))
-      if (N1->getOpcode() == ISD::FMA &&
-          N1.getOperand(2).getOpcode() == ISD::FMUL)
-        return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                           N1.getOperand(0), N1.getOperand(1),
-                           DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                                       N1.getOperand(2).getOperand(0),
-                                       N1.getOperand(2).getOperand(1),
-                                       N0));
-    }
   }
 
   return SDValue();
@@ -7207,43 +7314,32 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
     }
   }
 
+  if (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)) {
+    // Assume if there is an fmad instruction that it should be aggressively
+    // used.
+    if (SDValue Fused = performFsubFmulCombines(ISD::FMAD, true, N, TLI, DAG))
+      return Fused;
+  }
+
   // FSUB -> FMA combines:
   if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
       TLI.isFMAFasterThanFMulAndFAdd(VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT))) {
 
-    // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
-    if (N0.getOpcode() == ISD::FMUL &&
-        (N0->hasOneUse() || TLI.enableAggressiveFMAFusion(VT)))
-      return DAG.getNode(ISD::FMA, dl, VT,
-                         N0.getOperand(0), N0.getOperand(1),
-                         DAG.getNode(ISD::FNEG, dl, VT, N1));
-
-    // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
-    // Note: Commutes FSUB operands.
-    if (N1.getOpcode() == ISD::FMUL &&
-        (N1->hasOneUse() || TLI.enableAggressiveFMAFusion(VT)))
-      return DAG.getNode(ISD::FMA, dl, VT,
-                         DAG.getNode(ISD::FNEG, dl, VT,
-                         N1.getOperand(0)),
-                         N1.getOperand(1), N0);
-
-    // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
-    if (N0.getOpcode() == ISD::FNEG &&
-        N0.getOperand(0).getOpcode() == ISD::FMUL &&
-        ((N0->hasOneUse() && N0.getOperand(0).hasOneUse()) ||
-            TLI.enableAggressiveFMAFusion(VT))) {
-      SDValue N00 = N0.getOperand(0).getOperand(0);
-      SDValue N01 = N0.getOperand(0).getOperand(1);
-      return DAG.getNode(ISD::FMA, dl, VT,
-                         DAG.getNode(ISD::FNEG, dl, VT, N00), N01,
-                         DAG.getNode(ISD::FNEG, dl, VT, N1));
+    if (!TLI.isOperationLegal(ISD::FMAD, VT)) {
+      // Don't form FMA if we are preferring FMAD.
+
+      if (SDValue Fused
+          = performFsubFmulCombines(ISD::FMA,
+                                    TLI.enableAggressiveFMAFusion(VT),
+                                    N, TLI, DAG)) {
+        return Fused;
+      }
     }
 
     // When FP_EXTEND nodes are free on the target, and there is an opportunity
     // to combine into FMA, arrange such nodes accordingly.
     if (TLI.isFPExtFree(VT)) {
-
       // fold (fsub (fpext (fmul x, y)), z)
       //   -> (fma (fpext x), (fpext y), (fneg z))
       if (N0.getOpcode() == ISD::FP_EXTEND) {
@@ -7308,38 +7404,6 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
         }
       }
     }
-
-    // More folding opportunities when target permits.
-    if (TLI.enableAggressiveFMAFusion(VT)) {
-
-      // fold (fsub (fma x, y, (fmul u, v)), z)
-      //   -> (fma x, y (fma u, v, (fneg z)))
-      if (N0.getOpcode() == ISD::FMA &&
-          N0.getOperand(2).getOpcode() == ISD::FMUL)
-        return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                           N0.getOperand(0), N0.getOperand(1),
-                           DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                                       N0.getOperand(2).getOperand(0),
-                                       N0.getOperand(2).getOperand(1),
-                                       DAG.getNode(ISD::FNEG, SDLoc(N), VT,
-                                                   N1)));
-
-      // fold (fsub x, (fma y, z, (fmul u, v)))
-      //   -> (fma (fneg y), z, (fma (fneg u), v, x))
-      if (N1.getOpcode() == ISD::FMA &&
-          N1.getOperand(2).getOpcode() == ISD::FMUL) {
-        SDValue N20 = N1.getOperand(2).getOperand(0);
-        SDValue N21 = N1.getOperand(2).getOperand(1);
-        return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                           DAG.getNode(ISD::FNEG, SDLoc(N), VT,
-                                       N1.getOperand(0)),
-                           N1.getOperand(1),
-                           DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                                       DAG.getNode(ISD::FNEG, SDLoc(N),  VT,
-                                                   N20),
-                                       N21, N0));
-      }
-    }
   }
 
   return SDValue();
index e5473e35caed44563236c207d83c1dc892b7d0fd..ed337eb96486c15832cdec3eecd227fee59945cd 100644 (file)
@@ -3519,6 +3519,9 @@ void SelectionDAGLegalize::ExpandNode(SDNode *Node) {
                                       RTLIB::FMA_F80, RTLIB::FMA_F128,
                                       RTLIB::FMA_PPCF128));
     break;
+  case ISD::FMAD:
+    llvm_unreachable("Illegal fmad should never be formed");
+
   case ISD::FADD:
     Results.push_back(ExpandFPLibCall(Node, RTLIB::ADD_F32, RTLIB::ADD_F64,
                                       RTLIB::ADD_F80, RTLIB::ADD_F128,
index e8577d898c2d08b7351a0651d0aeca2514da9f28..17eff944c6228ed3c60ea53d833fe8771b30efae 100644 (file)
@@ -187,6 +187,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::FMUL:                       return "fmul";
   case ISD::FDIV:                       return "fdiv";
   case ISD::FMA:                        return "fma";
+  case ISD::FMAD:                       return "fmad";
   case ISD::FREM:                       return "frem";
   case ISD::FCOPYSIGN:                  return "fcopysign";
   case ISD::FGETSIGN:                   return "fgetsign";
index 630c3313772ee0d8da0eec1d678a4e8ecc1b4ffd..459969b58b95fd975f3c789d765eacbab769e965 100644 (file)
@@ -765,6 +765,7 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::CONCAT_VECTORS, VT, Expand);
     setOperationAction(ISD::FMINNUM, VT, Expand);
     setOperationAction(ISD::FMAXNUM, VT, Expand);
+    setOperationAction(ISD::FMAD, VT, Expand);
 
     // These library functions default to expand.
     setOperationAction(ISD::FROUND, VT, Expand);