[InstCombine] Optimize icmp slt signum(x), 1 --> icmp slt x, 1
authorSanjoy Das <sanjoy@playingwithpointers.com>
Wed, 16 Sep 2015 20:41:29 +0000 (20:41 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Wed, 16 Sep 2015 20:41:29 +0000 (20:41 +0000)
Summary:
`signum(x)` is sometimes implemented as `(x >> 63) | (-x >>> 63)` (for
an `i64` `x`).  This change adds a matcher for that pattern, and an
instcombine rule to optimize `signum(x) s< 1`.

Later, we can also consider optimizing:

  icmp slt signum(x), 0 --> icmp slt x, 0
  icmp sle signum(x), 1 --> true

etc.

Reviewers: majnemer

Subscribers: llvm-commits

Differential Revision: http://reviews.llvm.org/D12703

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@247846 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/IR/PatternMatch.h
lib/Transforms/InstCombine/InstCombineCompares.cpp
test/Transforms/InstCombine/compare-signs.ll

index 41154e6441a93b9c3c1b42e9959eb2f84bb4a538..f4d7d8c4441628c35d137039eda8cae860a72337 100644 (file)
@@ -1272,6 +1272,46 @@ inline typename m_Intrinsic_Ty<Opnd0, Opnd1>::Ty m_FMax(const Opnd0 &Op0,
   return m_Intrinsic<Intrinsic::maxnum>(Op0, Op1);
 }
 
+template <typename Opnd_t> struct Signum_match {
+  Opnd_t Val;
+  Signum_match(const Opnd_t &V) : Val(V) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    unsigned TypeSize = V->getType()->getScalarSizeInBits();
+    if (TypeSize == 0)
+      return false;
+
+    unsigned ShiftWidth = TypeSize - 1;
+    Value *OpL = nullptr, *OpR = nullptr;
+
+    // This is the representation of signum we match:
+    //
+    //  signum(x) == (x >> 63) | (-x >>u 63)
+    //
+    // An i1 value is its own signum, so it's correct to match
+    //
+    //  signum(x) == (x >> 0)  | (-x >>u 0)
+    //
+    // for i1 values.
+
+    auto LHS = m_AShr(m_Value(OpL), m_SpecificInt(ShiftWidth));
+    auto RHS = m_LShr(m_Neg(m_Value(OpR)), m_SpecificInt(ShiftWidth));
+    auto Signum = m_Or(LHS, RHS);
+
+    return Signum.match(V) && OpL == OpR && Val.match(OpL);
+  }
+};
+
+/// \brief Matches a signum pattern.
+///
+/// signum(x) =
+///      x >  0  ->  1
+///      x == 0  ->  0
+///      x <  0  -> -1
+template <typename Val_t> inline Signum_match<Val_t> m_Signum(const Val_t &V) {
+  return Signum_match<Val_t>(V);
+}
+
 } // end namespace PatternMatch
 } // end namespace llvm
 
index f1a66816c4cb626d8b87f965fc236e3cf104a0cb..0d56ce5ebe81f3f6337d09ca488a0c977263e0e2 100644 (file)
@@ -1145,6 +1145,14 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
 
   switch (LHSI->getOpcode()) {
   case Instruction::Trunc:
+    if (RHS->isOne() && RHSV.getBitWidth() > 1) {
+      // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
+      Value *V = nullptr;
+      if (ICI.getPredicate() == ICmpInst::ICMP_SLT &&
+          match(LHSI->getOperand(0), m_Signum(m_Value(V))))
+        return new ICmpInst(ICmpInst::ICMP_SLT, V,
+                            ConstantInt::get(V->getType(), 1));
+    }
     if (ICI.isEquality() && LHSI->hasOneUse()) {
       // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all
       // of the high bits truncated out of x are known.
@@ -1467,6 +1475,15 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
     break;
 
   case Instruction::Or: {
+    if (RHS->isOne()) {
+      // icmp slt signum(V) 1 --> icmp slt V, 1
+      Value *V = nullptr;
+      if (ICI.getPredicate() == ICmpInst::ICMP_SLT &&
+          match(LHSI, m_Signum(m_Value(V))))
+        return new ICmpInst(ICmpInst::ICMP_SLT, V,
+                            ConstantInt::get(V->getType(), 1));
+    }
+
     if (!ICI.isEquality() || !RHS->isNullValue() || !LHSI->hasOneUse())
       break;
     Value *P, *Q;
index 62cd5b3f94d56ec924b42fb910dc6ce438389a13..0ed0ac7d8d9c9c8a7ca2971b41f505df793b7940 100644 (file)
@@ -56,3 +56,43 @@ entry:
 ; CHECK-NOT: zext
 ; CHECK: ret i32 %2
 }
+
+define i1 @test4a(i32 %a) {
+; CHECK-LABEL: @test4a(
+ entry:
+; CHECK: %c = icmp slt i32 %a, 1
+; CHECK-NEXT: ret i1 %c
+  %l = ashr i32 %a, 31
+  %na = sub i32 0, %a
+  %r = lshr i32 %na, 31
+  %signum = or i32 %l, %r
+  %c = icmp slt i32 %signum, 1
+  ret i1 %c
+}
+
+define i1 @test4b(i64 %a) {
+; CHECK-LABEL: @test4b(
+ entry:
+; CHECK: %c = icmp slt i64 %a, 1
+; CHECK-NEXT: ret i1 %c
+  %l = ashr i64 %a, 63
+  %na = sub i64 0, %a
+  %r = lshr i64 %na, 63
+  %signum = or i64 %l, %r
+  %c = icmp slt i64 %signum, 1
+  ret i1 %c
+}
+
+define i1 @test4c(i64 %a) {
+; CHECK-LABEL: @test4c(
+ entry:
+; CHECK: %c = icmp slt i64 %a, 1
+; CHECK-NEXT: ret i1 %c
+  %l = ashr i64 %a, 63
+  %na = sub i64 0, %a
+  %r = lshr i64 %na, 63
+  %signum = or i64 %l, %r
+  %signum.trunc = trunc i64 %signum to i32
+  %c = icmp slt i32 %signum.trunc, 1
+  ret i1 %c
+}