From 87070fe1073b1e95748d987af0810d02aac43603 Mon Sep 17 00:00:00 2001 From: Elena Demikhovsky Date: Wed, 26 Jun 2013 10:55:03 +0000 Subject: [PATCH] Optimized integer vector multiplication operation by replacing it with shift/xor/sub when it is possible. Fixed a bug in SDIV, where the const operand is not a splat constant vector. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@184931 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 79 ++++++++++++++++++------ lib/Target/X86/X86ISelLowering.cpp | 6 +- test/CodeGen/X86/avx-shift.ll | 3 +- test/CodeGen/X86/avx2-arith.ll | 73 ++++++++++++++++++++++ test/CodeGen/X86/vec_sdiv_to_shift.ll | 8 +++ test/CodeGen/X86/widen_arith-4.ll | 2 +- test/CodeGen/X86/widen_arith-5.ll | 2 +- 7 files changed, 148 insertions(+), 25 deletions(-) diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index cb9778bbd61..5af4aa01b8f 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -326,7 +326,10 @@ namespace { /// getShiftAmountTy - Returns a type large enough to hold any valid /// shift amount - before type legalization these can be huge. EVT getShiftAmountTy(EVT LHSTy) { - return LegalTypes ? TLI.getShiftAmountTy(LHSTy) : TLI.getPointerTy(); + assert(LHSTy.isInteger() && "Shift amount is not an integer type!"); + if (LHSTy.isVector()) + return LHSTy; + return LegalTypes ? TLI.getScalarShiftAmountTy(LHSTy) : TLI.getPointerTy(); } /// isTypeLegal - This method returns true if we are running before type @@ -1762,43 +1765,73 @@ SDValue DAGCombiner::visitSUBE(SDNode *N) { return SDValue(); } +/// isSplatVector - Returns true if N is a BUILD_VECTOR node whose elements are +/// all the same or undefined. +static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) { + BuildVectorSDNode *C = dyn_cast(N); + if (!C) + return false; + + APInt SplatUndef; + unsigned SplatBitSize; + bool HasAnyUndefs; + EVT EltVT = N->getValueType(0).getVectorElementType(); + return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, + HasAnyUndefs) && + EltVT.getSizeInBits() >= SplatBitSize); +} + SDValue DAGCombiner::visitMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast(N0); - ConstantSDNode *N1C = dyn_cast(N1); EVT VT = N0.getValueType(); + // fold (mul x, undef) -> 0 + if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF) + return DAG.getConstant(0, VT); + + bool N0IsConst = false; + bool N1IsConst = false; + APInt ConstValue0, ConstValue1; // fold vector ops if (VT.isVector()) { SDValue FoldedVOp = SimplifyVBinOp(N); if (FoldedVOp.getNode()) return FoldedVOp; + + N0IsConst = isConstantSplatVector(N0.getNode(), ConstValue0); + N1IsConst = isConstantSplatVector(N1.getNode(), ConstValue1); + } else { + N0IsConst = dyn_cast(N0) != 0; + ConstValue0 = N0IsConst? (dyn_cast(N0))->getAPIntValue() : APInt(); + N1IsConst = dyn_cast(N1) != 0; + ConstValue1 = N1IsConst? (dyn_cast(N1))->getAPIntValue() : APInt(); } - // fold (mul x, undef) -> 0 - if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF) - return DAG.getConstant(0, VT); // fold (mul c1, c2) -> c1*c2 - if (N0C && N1C) - return DAG.FoldConstantArithmetic(ISD::MUL, VT, N0C, N1C); + if (N0IsConst && N1IsConst) + return DAG.FoldConstantArithmetic(ISD::MUL, VT, N0.getNode(), N1.getNode()); + // canonicalize constant to RHS - if (N0C && !N1C) + if (N0IsConst && !N1IsConst) return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0); // fold (mul x, 0) -> 0 - if (N1C && N1C->isNullValue()) + if (N1IsConst && ConstValue1 == 0) return N1; + // fold (mul x, 1) -> x + if (N1IsConst && ConstValue1 == 1) + return N0; // fold (mul x, -1) -> 0-x - if (N1C && N1C->isAllOnesValue()) + if (N1IsConst && ConstValue1.isAllOnesValue()) return DAG.getNode(ISD::SUB, SDLoc(N), VT, DAG.getConstant(0, VT), N0); // fold (mul x, (1 << c)) -> x << c - if (N1C && N1C->getAPIntValue().isPowerOf2()) + if (N1IsConst && ConstValue1.isPowerOf2()) return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, - DAG.getConstant(N1C->getAPIntValue().logBase2(), + DAG.getConstant(ConstValue1.logBase2(), getShiftAmountTy(N0.getValueType()))); // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c - if (N1C && (-N1C->getAPIntValue()).isPowerOf2()) { - unsigned Log2Val = (-N1C->getAPIntValue()).logBase2(); + if (N1IsConst && (-ConstValue1).isPowerOf2()) { + unsigned Log2Val = (-ConstValue1).logBase2(); // FIXME: If the input is something that is easily negated (e.g. a // single-use add), we should put the negate there. return DAG.getNode(ISD::SUB, SDLoc(N), VT, @@ -1807,9 +1840,12 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { DAG.getConstant(Log2Val, getShiftAmountTy(N0.getValueType())))); } + + APInt Val; // (mul (shl X, c1), c2) -> (mul X, c2 << c1) - if (N1C && N0.getOpcode() == ISD::SHL && - isa(N0.getOperand(1))) { + if (N1IsConst && N0.getOpcode() == ISD::SHL && + (isConstantSplatVector(N0.getOperand(1).getNode(), Val) || + isa(N0.getOperand(1)))) { SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1)); AddToWorkList(C3.getNode()); @@ -1822,7 +1858,9 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { { SDValue Sh(0,0), Y(0,0); // Check for both (mul (shl X, C), Y) and (mul Y, (shl X, C)). - if (N0.getOpcode() == ISD::SHL && isa(N0.getOperand(1)) && + if (N0.getOpcode() == ISD::SHL && + (isConstantSplatVector(N0.getOperand(1).getNode(), Val) || + isa(N0.getOperand(1))) && N0.getNode()->hasOneUse()) { Sh = N0; Y = N1; } else if (N1.getOpcode() == ISD::SHL && @@ -1840,8 +1878,9 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { } // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2) - if (N1C && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() && - isa(N0.getOperand(1))) + if (N1IsConst && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() && + (isConstantSplatVector(N0.getOperand(1).getNode(), Val) || + isa(N0.getOperand(1)))) return DAG.getNode(ISD::ADD, SDLoc(N), VT, DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1), diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 7db1e474a5c..954790b4bac 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -11560,9 +11560,11 @@ SDValue X86TargetLowering::LowerSDIV(SDValue Op, SelectionDAG &DAG) const { return SDValue(); APInt SplatValue, SplatUndef; - unsigned MinSplatBits; + unsigned SplatBitSize; bool HasAnyUndefs; - if (!C->isConstantSplat(SplatValue, SplatUndef, MinSplatBits, HasAnyUndefs)) + if (!C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, + HasAnyUndefs) || + EltTy.getSizeInBits() < SplatBitSize) return SDValue(); if ((SplatValue != 0) && diff --git a/test/CodeGen/X86/avx-shift.ll b/test/CodeGen/X86/avx-shift.ll index 01eb7361e29..d79dfcc076b 100644 --- a/test/CodeGen/X86/avx-shift.ll +++ b/test/CodeGen/X86/avx-shift.ll @@ -103,9 +103,10 @@ define <32 x i8> @vshift12(<32 x i8> %a) nounwind readnone { ;;; Support variable shifts ; CHECK: _vshift08 -; CHECK: vextractf128 $1 ; CHECK: vpslld $23 +; CHECK: vextractf128 $1 ; CHECK: vpslld $23 +; CHECK: ret define <8 x i32> @vshift08(<8 x i32> %a) nounwind { %bitop = shl <8 x i32> , %a ret <8 x i32> %bitop diff --git a/test/CodeGen/X86/avx2-arith.ll b/test/CodeGen/X86/avx2-arith.ll index 09f95383582..2c0b6685e56 100644 --- a/test/CodeGen/X86/avx2-arith.ll +++ b/test/CodeGen/X86/avx2-arith.ll @@ -74,3 +74,76 @@ define <4 x i64> @mul-v4i64(<4 x i64> %i, <4 x i64> %j) nounwind readnone { ret <4 x i64> %x } +; CHECK: mul_const1 +; CHECK: vpaddd +; CHECK: ret +define <8 x i32> @mul_const1(<8 x i32> %x) { + %y = mul <8 x i32> %x, + ret <8 x i32> %y +} + +; CHECK: mul_const2 +; CHECK: vpsllq $2 +; CHECK: ret +define <4 x i64> @mul_const2(<4 x i64> %x) { + %y = mul <4 x i64> %x, + ret <4 x i64> %y +} + +; CHECK: mul_const3 +; CHECK: vpsllw $3 +; CHECK: ret +define <16 x i16> @mul_const3(<16 x i16> %x) { + %y = mul <16 x i16> %x, + ret <16 x i16> %y +} + +; CHECK: mul_const4 +; CHECK: vpxor +; CHECK: vpsubq +; CHECK: ret +define <4 x i64> @mul_const4(<4 x i64> %x) { + %y = mul <4 x i64> %x, + ret <4 x i64> %y +} + +; CHECK: mul_const5 +; CHECK: vxorps +; CHECK-NEXT: ret +define <8 x i32> @mul_const5(<8 x i32> %x) { + %y = mul <8 x i32> %x, + ret <8 x i32> %y +} + +; CHECK: mul_const6 +; CHECK: vpmulld +; CHECK: ret +define <8 x i32> @mul_const6(<8 x i32> %x) { + %y = mul <8 x i32> %x, + ret <8 x i32> %y +} + +; CHECK: mul_const7 +; CHECK: vpaddq +; CHECK: vpaddq +; CHECK: ret +define <8 x i64> @mul_const7(<8 x i64> %x) { + %y = mul <8 x i64> %x, + ret <8 x i64> %y +} + +; CHECK: mul_const8 +; CHECK: vpsllw $3 +; CHECK: ret +define <8 x i16> @mul_const8(<8 x i16> %x) { + %y = mul <8 x i16> %x, + ret <8 x i16> %y +} + +; CHECK: mul_const9 +; CHECK: vpmulld +; CHECK: ret +define <8 x i32> @mul_const9(<8 x i32> %x) { + %y = mul <8 x i32> %x, + ret <8 x i32> %y +} \ No newline at end of file diff --git a/test/CodeGen/X86/vec_sdiv_to_shift.ll b/test/CodeGen/X86/vec_sdiv_to_shift.ll index 349868a87f5..59ceb2eb36d 100644 --- a/test/CodeGen/X86/vec_sdiv_to_shift.ll +++ b/test/CodeGen/X86/vec_sdiv_to_shift.ll @@ -70,3 +70,11 @@ entry: %a0 = sdiv <16 x i16> %var, ret <16 x i16> %a0 } + +; CHECK: sdiv_non_splat +; CHECK: idivl +; CHECK: ret +define <4 x i32> @sdiv_non_splat(<4 x i32> %x) { + %y = sdiv <4 x i32> %x, + ret <4 x i32> %y +} \ No newline at end of file diff --git a/test/CodeGen/X86/widen_arith-4.ll b/test/CodeGen/X86/widen_arith-4.ll index 5931d639f19..63c8d0e52e5 100644 --- a/test/CodeGen/X86/widen_arith-4.ll +++ b/test/CodeGen/X86/widen_arith-4.ll @@ -33,7 +33,7 @@ forbody: ; preds = %forcond %arrayidx6 = getelementptr <5 x i16>* %tmp5, i32 %tmp4 ; <<5 x i16>*> [#uses=1] %tmp7 = load <5 x i16>* %arrayidx6 ; <<5 x i16>> [#uses=1] %sub = sub <5 x i16> %tmp7, < i16 271, i16 271, i16 271, i16 271, i16 271 > ; <<5 x i16>> [#uses=1] - %mul = mul <5 x i16> %sub, < i16 2, i16 2, i16 2, i16 2, i16 2 > ; <<5 x i16>> [#uses=1] + %mul = mul <5 x i16> %sub, < i16 2, i16 4, i16 2, i16 2, i16 2 > ; <<5 x i16>> [#uses=1] store <5 x i16> %mul, <5 x i16>* %arrayidx br label %forinc diff --git a/test/CodeGen/X86/widen_arith-5.ll b/test/CodeGen/X86/widen_arith-5.ll index 7f2eff09f47..41df0e493b1 100644 --- a/test/CodeGen/X86/widen_arith-5.ll +++ b/test/CodeGen/X86/widen_arith-5.ll @@ -1,6 +1,6 @@ ; RUN: llc < %s -march=x86-64 -mattr=+sse42 | FileCheck %s ; CHECK: movdqa -; CHECK: pmulld +; CHECK: pslld $2 ; CHECK: psubd ; widen a v3i32 to v4i32 to do a vector multiple and a subtraction -- 2.34.1