From: Benjamin Kramer Date: Sat, 26 Apr 2014 23:09:49 +0000 (+0000) Subject: DAGCombiner: Simplify code a bit, make more transforms work with vectors. X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=eb3430cfbde7c28586a0d55cff8a88dbb3aa348f;p=oota-llvm.git DAGCombiner: Simplify code a bit, make more transforms work with vectors. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207338 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 0156fe1c0ec..2ca3f3e452c 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -644,8 +644,13 @@ static ConstantSDNode *isConstOrConstSplat(SDValue N) { if (ConstantSDNode *CN = dyn_cast(N)) return CN; - if (BuildVectorSDNode *BV = dyn_cast(N)) - return BV->getConstantSplatValue(); + if (BuildVectorSDNode *BV = dyn_cast(N)) { + ConstantSDNode *CN = BV->getConstantSplatValue(); + + // BuildVectors can truncate their operands. Ignore that case here. + if (CN && CN->getValueType(0) == N.getValueType().getScalarType()) + return CN; + } return nullptr; } @@ -1957,8 +1962,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { SDValue DAGCombiner::visitSDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast(N0.getNode()); - ConstantSDNode *N1C = dyn_cast(N1.getNode()); + ConstantSDNode *N0C = isConstOrConstSplat(N0); + ConstantSDNode *N1C = isConstOrConstSplat(N1); EVT VT = N->getValueType(0); // fold vector ops @@ -1985,25 +1990,15 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { N0, N1); } - const APInt *Divisor = nullptr; - if (N1C) { - Divisor = &N1C->getAPIntValue(); - } else if (N1.getValueType().isVector() && - N1->getOpcode() == ISD::BUILD_VECTOR) { - BuildVectorSDNode *BV = cast(N->getOperand(1)); - if (ConstantSDNode *C = BV->getConstantSplatValue()) - Divisor = &C->getAPIntValue(); - } - // fold (sdiv X, pow2) -> simple ops after legalize - if (Divisor && !!*Divisor && - (Divisor->isPowerOf2() || (-*Divisor).isPowerOf2())) { + if (N1C && !N1C->isNullValue() && (N1C->getAPIntValue().isPowerOf2() || + (-N1C->getAPIntValue()).isPowerOf2())) { // If dividing by powers of two is cheap, then don't perform the following // fold. if (TLI.isPow2DivCheap()) return SDValue(); - unsigned lg2 = Divisor->countTrailingZeros(); + unsigned lg2 = N1C->getAPIntValue().countTrailingZeros(); // Splat the sign bit into the register SDValue SGN = @@ -2025,7 +2020,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { // If we're dividing by a positive value, we're done. Otherwise, we must // negate the result. - if (Divisor->isNonNegative()) + if (N1C->getAPIntValue().isNonNegative()) return SRA; AddToWorkList(SRA.getNode()); @@ -2034,7 +2029,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { // if integer divide is expensive and we satisfy the requirements, emit an // alternate sequence. - if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) { + if (N1C && !TLI.isIntDivCheap()) { SDValue Op = BuildSDIV(N); if (Op.getNode()) return Op; } @@ -2052,8 +2047,8 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { SDValue DAGCombiner::visitUDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast(N0.getNode()); - ConstantSDNode *N1C = dyn_cast(N1.getNode()); + ConstantSDNode *N0C = isConstOrConstSplat(N0); + ConstantSDNode *N1C = isConstOrConstSplat(N1); EVT VT = N->getValueType(0); // fold vector ops @@ -2086,7 +2081,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { } } // fold (udiv x, c) -> alternate - if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) { + if (N1C && !TLI.isIntDivCheap()) { SDValue Op = BuildUDIV(N); if (Op.getNode()) return Op; } @@ -2104,8 +2099,8 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { SDValue DAGCombiner::visitSREM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast(N0); - ConstantSDNode *N1C = dyn_cast(N1); + ConstantSDNode *N0C = isConstOrConstSplat(N0); + ConstantSDNode *N1C = isConstOrConstSplat(N1); EVT VT = N->getValueType(0); // fold (srem c1, c2) -> c1%c2 @@ -2146,8 +2141,8 @@ SDValue DAGCombiner::visitSREM(SDNode *N) { SDValue DAGCombiner::visitUREM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast(N0); - ConstantSDNode *N1C = dyn_cast(N1); + ConstantSDNode *N0C = isConstOrConstSplat(N0); + ConstantSDNode *N1C = isConstOrConstSplat(N1); EVT VT = N->getValueType(0); // fold (urem c1, c2) -> c1%c2 @@ -11187,28 +11182,20 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, /// multiplying by a magic number. See: /// SDValue DAGCombiner::BuildSDIV(SDNode *N) { - const APInt *Divisor; - if (N->getValueType(0).isVector()) { - // Handle splat vectors. - BuildVectorSDNode *BV = cast(N->getOperand(1)); - if (ConstantSDNode *C = BV->getConstantSplatValue()) - Divisor = &C->getAPIntValue(); - else - return SDValue(); - } else { - Divisor = &cast(N->getOperand(1))->getAPIntValue(); - } + ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); + if (!C) + return SDValue(); // Avoid division by zero. - if (!*Divisor) + if (!C->getAPIntValue()) return SDValue(); std::vector Built; - SDValue S = TLI.BuildSDIV(N, *Divisor, DAG, LegalOperations, &Built); + SDValue S = + TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built); - for (std::vector::iterator ii = Built.begin(), ee = Built.end(); - ii != ee; ++ii) - AddToWorkList(*ii); + for (SDNode *N : Built) + AddToWorkList(N); return S; } @@ -11217,28 +11204,20 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) { /// multiplying by a magic number. See: /// SDValue DAGCombiner::BuildUDIV(SDNode *N) { - const APInt *Divisor; - if (N->getValueType(0).isVector()) { - // Handle splat vectors. - BuildVectorSDNode *BV = cast(N->getOperand(1)); - if (ConstantSDNode *C = BV->getConstantSplatValue()) - Divisor = &C->getAPIntValue(); - else - return SDValue(); - } else { - Divisor = &cast(N->getOperand(1))->getAPIntValue(); - } + ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); + if (!C) + return SDValue(); // Avoid division by zero. - if (!*Divisor) + if (!C->getAPIntValue()) return SDValue(); std::vector Built; - SDValue S = TLI.BuildUDIV(N, *Divisor, DAG, LegalOperations, &Built); + SDValue S = + TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built); - for (std::vector::iterator ii = Built.begin(), ee = Built.end(); - ii != ee; ++ii) - AddToWorkList(*ii); + for (SDNode *N : Built) + AddToWorkList(N); return S; } diff --git a/test/CodeGen/X86/vector-idiv.ll b/test/CodeGen/X86/vector-idiv.ll index 06af3434b1a..3b300f74061 100644 --- a/test/CodeGen/X86/vector-idiv.ll +++ b/test/CodeGen/X86/vector-idiv.ll @@ -151,3 +151,38 @@ define <8 x i32> @test9(<8 x i32> %a) { ; AVX: vpsrad $2 ; AVX: vpadd } + +define <8 x i32> @test10(<8 x i32> %a) { + %rem = urem <8 x i32> %a, + ret <8 x i32> %rem + +; AVX-LABEL: test10: +; AVX: vpermd +; AVX: vpmuludq +; AVX: vshufps $-35 +; AVX: vpmuludq +; AVX: vshufps $-35 +; AVX: vpsubd +; AVX: vpsrld $1 +; AVX: vpadd +; AVX: vpsrld $2 +; AVX: vpmulld +} + +define <8 x i32> @test11(<8 x i32> %a) { + %rem = srem <8 x i32> %a, + ret <8 x i32> %rem + +; AVX-LABEL: test11: +; AVX: vpermd +; AVX: vpmuldq +; AVX: vshufps $-35 +; AVX: vpmuldq +; AVX: vshufps $-35 +; AVX: vpshufd $-40 +; AVX: vpadd +; AVX: vpsrld $31 +; AVX: vpsrad $2 +; AVX: vpadd +; AVX: vpmulld +}