Merging r259381:
[oota-llvm.git] / lib / IR / Constants.cpp
index b4a07a1b6b4a0a54163a479f713d3179f576fc3d..0898bf645385131761b5b4d40823d99a91c4be2b 100644 (file)
@@ -899,7 +899,9 @@ static Constant *getSequenceIfElementsMatch(Constant *C,
     else if (CI->getType()->isIntegerTy(64))
       return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
   } else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
-    if (CFP->getType()->isFloatTy())
+    if (CFP->getType()->isHalfTy())
+      return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
+    else if (CFP->getType()->isFloatTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
     else if (CFP->getType()->isDoubleTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
@@ -1152,8 +1154,7 @@ ArrayRef<unsigned> ConstantExpr::getIndices() const {
 }
 
 unsigned ConstantExpr::getPredicate() const {
-  assert(isCompare());
-  return ((const CompareConstantExpr*)this)->predicate;
+  return cast<CompareConstantExpr>(this)->predicate;
 }
 
 /// getWithOperandReplaced - Return a constant expression identical to this
@@ -2365,7 +2366,7 @@ StringRef ConstantDataSequential::getRawDataValues() const {
 /// ConstantDataArray only works with normal float and int types that are
 /// stored densely in memory, not with things like i42 or x86_f80.
 bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
-  if (Ty->isFloatTy() || Ty->isDoubleTy()) return true;
+  if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true;
   if (auto *IT = dyn_cast<IntegerType>(Ty)) {
     switch (IT->getBitWidth()) {
     case 8:
@@ -2521,7 +2522,7 @@ Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<double> Elts) {
 /// object.
 Constant *ConstantDataArray::getFP(LLVMContext &Context,
                                    ArrayRef<uint16_t> Elts) {
-  Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+  Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 2), Ty);
 }
@@ -2637,6 +2638,11 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
   }
 
   if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
+    if (CFP->getType()->isHalfTy()) {
+      SmallVector<uint16_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getContext(), Elts);
+    }
     if (CFP->getType()->isFloatTy()) {
       SmallVector<uint32_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
@@ -2682,6 +2688,10 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
   switch (getElementType()->getTypeID()) {
   default:
     llvm_unreachable("Accessor can only be used when element is float/double!");
+  case Type::HalfTyID: {
+    auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
+    return APFloat(APFloat::IEEEhalf, APInt(16, EltVal));
+  }
   case Type::FloatTyID: {
     auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
     return APFloat(APFloat::IEEEsingle, APInt(32, EltVal));
@@ -2716,7 +2726,8 @@ double ConstantDataSequential::getElementAsDouble(unsigned Elt) const {
 /// Note that this has to compute a new constant to return, so it isn't as
 /// efficient as getElementAsInteger/Float/Double.
 Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
-  if (getElementType()->isFloatTy() || getElementType()->isDoubleTy())
+  if (getElementType()->isHalfTy() || getElementType()->isFloatTy() ||
+      getElementType()->isDoubleTy())
     return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
 
   return ConstantInt::get(getElementType(), getElementAsInteger(Elt));
@@ -3009,7 +3020,7 @@ Instruction *ConstantExpr::getAsInstruction() {
   case Instruction::ICmp:
   case Instruction::FCmp:
     return CmpInst::Create((Instruction::OtherOps)getOpcode(),
-                           getPredicate(), Ops[0], Ops[1]);
+                           (CmpInst::Predicate)getPredicate(), Ops[0], Ops[1]);
 
   default:
     assert(getNumOperands() == 2 && "Must be binary operator?");