IR: Allow vectors of halfs to be ConstantDataVectors
[oota-llvm.git] / unittests / IR / ConstantsTest.cpp
index 2df008a78ccd7a6da5e00c8771dc5dbcd8f281e1..8c33453d293dfb6432bf7fabaf7caf219a2c4bed 100644 (file)
@@ -348,7 +348,7 @@ TEST(ConstantsTest, GEPReplaceWithConstant) {
   std::unique_ptr<Module> M(new Module("MyModule", Context));
 
   Type *IntTy = Type::getInt32Ty(Context);
-  auto *PtrTy = PointerType::get(IntTy, 0);
+  Type *PtrTy = PointerType::get(IntTy, 0);
   auto *C1 = ConstantInt::get(IntTy, 1);
   auto *Placeholder = new GlobalVariable(
       *M, IntTy, false, GlobalValue::ExternalWeakLinkage, nullptr);
@@ -361,7 +361,7 @@ TEST(ConstantsTest, GEPReplaceWithConstant) {
 
   auto *Global = new GlobalVariable(*M, PtrTy, false,
                                     GlobalValue::ExternalLinkage, nullptr);
-  auto *Alias = GlobalAlias::create(PtrTy, GlobalValue::ExternalLinkage,
+  auto *Alias = GlobalAlias::create(IntTy, 0, GlobalValue::ExternalLinkage,
                                     "alias", Global, M.get());
   Placeholder->replaceAllUsesWith(Alias);
   ASSERT_EQ(GEP, Ref->getInitializer());
@@ -382,5 +382,33 @@ TEST(ConstantsTest, AliasCAPI) {
   ASSERT_EQ(unwrap<GlobalAlias>(AliasRef)->getAliasee(), Aliasee);
 }
 
+static std::string getNameOfType(Type *T) {
+  std::string S;
+  raw_string_ostream RSOS(S);
+  T->print(RSOS);
+  return S;
+}
+
+TEST(ConstantsTest, BuildConstantDataVectors) {
+  LLVMContext Context;
+  std::unique_ptr<Module> M(new Module("MyModule", Context));
+
+  for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context),
+                  Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) {
+    Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)};
+    Constant *CDV = ConstantVector::get(Vals);
+    ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr)
+        << " T = " << getNameOfType(T);
+  }
+
+  for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context),
+                  Type::getDoubleTy(Context)}) {
+    Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)};
+    Constant *CDV = ConstantVector::get(Vals);
+    ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr)
+        << " T = " << getNameOfType(T);
+  }
+}
+
 }  // end anonymous namespace
 }  // end namespace llvm