[IR] Make {extract,insert}element accept an index of any integer type.
authorMichael J. Spencer <bigcheesegs@gmail.com>
Thu, 1 May 2014 22:12:39 +0000 (22:12 +0000)
committerMichael J. Spencer <bigcheesegs@gmail.com>
Thu, 1 May 2014 22:12:39 +0000 (22:12 +0000)
Given the following C code llvm currently generates suboptimal code for
x86-64:

__m128 bss4( const __m128 *ptr, size_t i, size_t j )
{
    float f = ptr[i][j];
    return (__m128) { f, f, f, f };
}

=================================================

define <4 x float> @_Z4bss4PKDv4_fmm(<4 x float>* nocapture readonly %ptr, i64 %i, i64 %j) #0 {
  %a1 = getelementptr inbounds <4 x float>* %ptr, i64 %i
  %a2 = load <4 x float>* %a1, align 16, !tbaa !1
  %a3 = trunc i64 %j to i32
  %a4 = extractelement <4 x float> %a2, i32 %a3
  %a5 = insertelement <4 x float> undef, float %a4, i32 0
  %a6 = insertelement <4 x float> %a5, float %a4, i32 1
  %a7 = insertelement <4 x float> %a6, float %a4, i32 2
  %a8 = insertelement <4 x float> %a7, float %a4, i32 3
  ret <4 x float> %a8
}

=================================================

        shlq    $4, %rsi
        addq    %rdi, %rsi
        movslq  %edx, %rax
        vbroadcastss    (%rsi,%rax,4), %xmm0
        retq

=================================================

The movslq is uneeded, but is present because of the trunc to i32 and then
sext back to i64 that the backend adds for vbroadcastss.

We can't remove it because it changes the meaning. The IR that clang
generates is already suboptimal. What clang really should emit is:

  %a4 = extractelement <4 x float> %a2, i64 %j

This patch makes that legal. A separate patch will teach clang to do it.

Differential Revision: http://reviews.llvm.org/D3519

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207801 91177308-0d34-0410-b5e6-96231b3b80d8

docs/LangRef.rst
lib/Bitcode/Reader/BitcodeReader.cpp
lib/Bitcode/Writer/BitcodeWriter.cpp
lib/IR/Constants.cpp
lib/IR/Instructions.cpp
test/CodeGen/X86/vec_splat.ll
test/Feature/instructions.ll

index a6595f7fd068ea679848e9ca0256ddadc2d1541f..3d99a0e79b1223931ce298b3c0346c25917572dc 100644 (file)
@@ -4470,7 +4470,7 @@ Syntax:
 
 ::
 
-      <result> = extractelement <n x <ty>> <val>, i32 <idx>    ; yields <ty>
+      <result> = extractelement <n x <ty>> <val>, <ty2> <idx>  ; yields <ty>
 
 Overview:
 """""""""
@@ -4484,7 +4484,7 @@ Arguments:
 The first operand of an '``extractelement``' instruction is a value of
 :ref:`vector <t_vector>` type. The second operand is an index indicating
 the position from which to extract the element. The index may be a
-variable.
+variable of any integer type.
 
 Semantics:
 """"""""""
@@ -4510,7 +4510,7 @@ Syntax:
 
 ::
 
-      <result> = insertelement <n x <ty>> <val>, <ty> <elt>, i32 <idx>    ; yields <n x <ty>>
+      <result> = insertelement <n x <ty>> <val>, <ty> <elt>, <ty2> <idx>    ; yields <n x <ty>>
 
 Overview:
 """""""""
@@ -4525,7 +4525,7 @@ The first operand of an '``insertelement``' instruction is a value of
 :ref:`vector <t_vector>` type. The second operand is a scalar value whose
 type must equal the element type of the first operand. The third operand
 is an index indicating the position at which to insert the value. The
-index may be a variable.
+index may be a variable of any integer type.
 
 Semantics:
 """"""""""
index 74e8439141eb575e07c83f709b5f80be4a493bbb..1b2cf76f3e610834bbab17918e140f2651a70f0c 100644 (file)
@@ -1418,7 +1418,8 @@ error_code BitcodeReader::ParseConstants() {
                                   ValueList.getConstantFwdRef(Record[2],CurTy));
       break;
     }
-    case bitc::CST_CODE_CE_EXTRACTELT: { // CE_EXTRACTELT: [opty, opval, opval]
+    case bitc::CST_CODE_CE_EXTRACTELT
+        : { // CE_EXTRACTELT: [opty, opval, opty, opval]
       if (Record.size() < 3)
         return Error(InvalidRecord);
       VectorType *OpTy =
@@ -1426,20 +1427,37 @@ error_code BitcodeReader::ParseConstants() {
       if (!OpTy)
         return Error(InvalidRecord);
       Constant *Op0 = ValueList.getConstantFwdRef(Record[1], OpTy);
-      Constant *Op1 = ValueList.getConstantFwdRef(Record[2],
-                                                  Type::getInt32Ty(Context));
+      Constant *Op1 = nullptr;
+      if (Record.size() == 4) {
+        Type *IdxTy = getTypeByID(Record[2]);
+        if (!IdxTy)
+          return Error(InvalidRecord);
+        Op1 = ValueList.getConstantFwdRef(Record[3], IdxTy);
+      } else // TODO: Remove with llvm 4.0
+        Op1 = ValueList.getConstantFwdRef(Record[2], Type::getInt32Ty(Context));
+      if (!Op1)
+        return Error(InvalidRecord);
       V = ConstantExpr::getExtractElement(Op0, Op1);
       break;
     }
-    case bitc::CST_CODE_CE_INSERTELT: { // CE_INSERTELT: [opval, opval, opval]
+    case bitc::CST_CODE_CE_INSERTELT
+        : { // CE_INSERTELT: [opval, opval, opty, opval]
       VectorType *OpTy = dyn_cast<VectorType>(CurTy);
       if (Record.size() < 3 || !OpTy)
         return Error(InvalidRecord);
       Constant *Op0 = ValueList.getConstantFwdRef(Record[0], OpTy);
       Constant *Op1 = ValueList.getConstantFwdRef(Record[1],
                                                   OpTy->getElementType());
-      Constant *Op2 = ValueList.getConstantFwdRef(Record[2],
-                                                  Type::getInt32Ty(Context));
+      Constant *Op2 = nullptr;
+      if (Record.size() == 4) {
+        Type *IdxTy = getTypeByID(Record[2]);
+        if (!IdxTy)
+          return Error(InvalidRecord);
+        Op2 = ValueList.getConstantFwdRef(Record[3], IdxTy);
+      } else // TODO: Remove with llvm 4.0
+        Op2 = ValueList.getConstantFwdRef(Record[2], Type::getInt32Ty(Context));
+      if (!Op2)
+        return Error(InvalidRecord);
       V = ConstantExpr::getInsertElement(Op0, Op1, Op2);
       break;
     }
@@ -2460,7 +2478,7 @@ error_code BitcodeReader::ParseFunctionBody(Function *F) {
       unsigned OpNum = 0;
       Value *Vec, *Idx;
       if (getValueTypePair(Record, OpNum, NextValueNo, Vec) ||
-          popValue(Record, OpNum, NextValueNo, Type::getInt32Ty(Context), Idx))
+          getValueTypePair(Record, OpNum, NextValueNo, Idx))
         return Error(InvalidRecord);
       I = ExtractElementInst::Create(Vec, Idx);
       InstructionList.push_back(I);
@@ -2473,7 +2491,7 @@ error_code BitcodeReader::ParseFunctionBody(Function *F) {
       if (getValueTypePair(Record, OpNum, NextValueNo, Vec) ||
           popValue(Record, OpNum, NextValueNo,
                    cast<VectorType>(Vec->getType())->getElementType(), Elt) ||
-          popValue(Record, OpNum, NextValueNo, Type::getInt32Ty(Context), Idx))
+          getValueTypePair(Record, OpNum, NextValueNo, Idx))
         return Error(InvalidRecord);
       I = InsertElementInst::Create(Vec, Elt, Idx);
       InstructionList.push_back(I);
index 92965fa7e445a94f5f7dcccada8c5095d184e659..23374872df13885cefe560c5059973b3c6f34a59 100644 (file)
@@ -1087,12 +1087,14 @@ static void WriteConstants(unsigned FirstVal, unsigned LastVal,
         Code = bitc::CST_CODE_CE_EXTRACTELT;
         Record.push_back(VE.getTypeID(C->getOperand(0)->getType()));
         Record.push_back(VE.getValueID(C->getOperand(0)));
+        Record.push_back(VE.getTypeID(C->getOperand(1)->getType()));
         Record.push_back(VE.getValueID(C->getOperand(1)));
         break;
       case Instruction::InsertElement:
         Code = bitc::CST_CODE_CE_INSERTELT;
         Record.push_back(VE.getValueID(C->getOperand(0)));
         Record.push_back(VE.getValueID(C->getOperand(1)));
+        Record.push_back(VE.getTypeID(C->getOperand(2)->getType()));
         Record.push_back(VE.getValueID(C->getOperand(2)));
         break;
       case Instruction::ShuffleVector:
@@ -1253,13 +1255,13 @@ static void WriteInstruction(const Instruction &I, unsigned InstID,
   case Instruction::ExtractElement:
     Code = bitc::FUNC_CODE_INST_EXTRACTELT;
     PushValueAndType(I.getOperand(0), InstID, Vals, VE);
-    pushValue(I.getOperand(1), InstID, Vals, VE);
+    PushValueAndType(I.getOperand(1), InstID, Vals, VE);
     break;
   case Instruction::InsertElement:
     Code = bitc::FUNC_CODE_INST_INSERTELT;
     PushValueAndType(I.getOperand(0), InstID, Vals, VE);
     pushValue(I.getOperand(1), InstID, Vals, VE);
-    pushValue(I.getOperand(2), InstID, Vals, VE);
+    PushValueAndType(I.getOperand(2), InstID, Vals, VE);
     break;
   case Instruction::ShuffleVector:
     Code = bitc::FUNC_CODE_INST_SHUFFLEVEC;
index 54be9802920dee35c83e796f42aa3c980762c47e..bb8d60b234f320d8f942a2194f3ceea6ad58cc17 100644 (file)
@@ -1937,8 +1937,8 @@ ConstantExpr::getFCmp(unsigned short pred, Constant *LHS, Constant *RHS) {
 Constant *ConstantExpr::getExtractElement(Constant *Val, Constant *Idx) {
   assert(Val->getType()->isVectorTy() &&
          "Tried to create extractelement operation on non-vector type!");
-  assert(Idx->getType()->isIntegerTy(32) &&
-         "Extractelement index must be i32 type!");
+  assert(Idx->getType()->isIntegerTy() &&
+         "Extractelement index must be an integer type!");
 
   if (Constant *FC = ConstantFoldExtractElementInstruction(Val, Idx))
     return FC;          // Fold a few common cases.
@@ -1958,7 +1958,7 @@ Constant *ConstantExpr::getInsertElement(Constant *Val, Constant *Elt,
          "Tried to create insertelement operation on non-vector type!");
   assert(Elt->getType() == Val->getType()->getVectorElementType() &&
          "Insertelement types must match!");
-  assert(Idx->getType()->isIntegerTy(32) &&
+  assert(Idx->getType()->isIntegerTy() &&
          "Insertelement index must be i32 type!");
 
   if (Constant *FC = ConstantFoldInsertElementInstruction(Val, Elt, Idx))
index e8bcddbd9921a8deed4f11a8275ed67566f5f2b2..8aebb8c7c2c60b3f94dc26012e116b73a11f06f1 100644 (file)
@@ -1479,7 +1479,7 @@ ExtractElementInst::ExtractElementInst(Value *Val, Value *Index,
 
 
 bool ExtractElementInst::isValidOperands(const Value *Val, const Value *Index) {
-  if (!Val->getType()->isVectorTy() || !Index->getType()->isIntegerTy(32))
+  if (!Val->getType()->isVectorTy() || !Index->getType()->isIntegerTy())
     return false;
   return true;
 }
@@ -1526,7 +1526,7 @@ bool InsertElementInst::isValidOperands(const Value *Vec, const Value *Elt,
   if (Elt->getType() != cast<VectorType>(Vec->getType())->getElementType())
     return false;// Second operand of insertelement must be vector element type.
     
-  if (!Index->getType()->isIntegerTy(32))
+  if (!Index->getType()->isIntegerTy())
     return false;  // Third operand of insertelement must be i32.
   return true;
 }
index 543c96ef3d451200c600d414ed93890d5c01e049..a02e3836078c6779ef5be476cdad90df7fb494f0 100644 (file)
@@ -32,3 +32,19 @@ define void @test_v2sd(<2 x double>* %P, <2 x double>* %Q, double %X) nounwind {
 ; SSE3-LABEL: test_v2sd:
 ; SSE3: movddup
 }
+
+; Fold extract of a load into the load's address computation. This avoids spilling to the stack.
+define <4 x float> @load_extract_splat(<4 x float>* nocapture readonly %ptr, i64 %i, i64 %j) nounwind {
+  %1 = getelementptr inbounds <4 x float>* %ptr, i64 %i
+  %2 = load <4 x float>* %1, align 16
+  %3 = extractelement <4 x float> %2, i64 %j
+  %4 = insertelement <4 x float> undef, float %3, i32 0
+  %5 = insertelement <4 x float> %4, float %3, i32 1
+  %6 = insertelement <4 x float> %5, float %3, i32 2
+  %7 = insertelement <4 x float> %6, float %3, i32 3
+  ret <4 x float> %7
+  
+; AVX-LABEL: load_extract_splat
+; AVX-NOT: movs
+; AVX: vbroadcastss
+}
index d0c303d71914c17baa7fa5d6d22db61944e0b7b3..aa962948a16876894f88cf4fa75c5947f7e6e839 100644 (file)
@@ -4,11 +4,13 @@
 
 define i32 @test_extractelement(<4 x i32> %V) {
         %R = extractelement <4 x i32> %V, i32 1         ; <i32> [#uses=1]
+               %S = extractelement <4 x i32> %V, i64 1         ; <i32> [#uses=0]
         ret i32 %R
 }
 
 define <4 x i32> @test_insertelement(<4 x i32> %V) {
         %R = insertelement <4 x i32> %V, i32 0, i32 0           ; <<4 x i32>> [#uses=1]
+               %S = insertelement <4 x i32> %V, i32 0, i64 0           ; <<4 x i32>> [#uses=0]
         ret <4 x i32> %R
 }