Extract the load/store type verification to a separate function.
[oota-llvm.git] / lib / Bitcode / Reader / BitcodeReader.cpp
index e0800916c8cd6d0c8caa5b57b2aaa8fd30959a96..86c61bdf66b33035527114c225588106b25de8d4 100644 (file)
@@ -401,6 +401,12 @@ static std::error_code Error(DiagnosticHandlerFunction DiagnosticHandler,
   return Error(DiagnosticHandler, EC, EC.message());
 }
 
+static std::error_code Error(DiagnosticHandlerFunction DiagnosticHandler,
+                             const Twine &Message) {
+  return Error(DiagnosticHandler,
+               make_error_code(BitcodeError::CorruptedBitcode), Message);
+}
+
 std::error_code BitcodeReader::Error(BitcodeError E, const Twine &Message) {
   return ::Error(DiagnosticHandler, make_error_code(E), Message);
 }
@@ -3290,6 +3296,20 @@ std::error_code BitcodeReader::ParseMetadataAttachment(Function &F) {
   }
 }
 
+static std::error_code TypeCheckLoadStoreInst(DiagnosticHandlerFunction DH,
+                                              Type *ValType, Type *PtrType) {
+  if (!isa<PointerType>(PtrType))
+    return Error(DH, "Load/Store operand is not a pointer type");
+  Type *ElemType = cast<PointerType>(PtrType)->getElementType();
+
+  if (ValType && ValType != ElemType)
+    return Error(DH, "Explicit load/store type does not match pointee type of "
+                     "pointer operand");
+  if (!PointerType::isLoadableOrStorableType(ElemType))
+    return Error(DH, "Cannot load/store from pointer");
+  return std::error_code();
+}
+
 /// ParseFunctionBody - Lazily parse the specified function body block.
 std::error_code BitcodeReader::ParseFunctionBody(Function *F) {
   if (Stream.EnterSubBlock(bitc::FUNCTION_BLOCK_ID))
@@ -4071,13 +4091,11 @@ std::error_code BitcodeReader::ParseFunctionBody(Function *F) {
       Type *Ty = nullptr;
       if (OpNum + 3 == Record.size())
         Ty = getTypeByID(Record[OpNum++]);
-      if (!isa<PointerType>(Op->getType()))
-        return Error("Load operand is not a pointer type");
+      if (std::error_code EC =
+              TypeCheckLoadStoreInst(DiagnosticHandler, Ty, Op->getType()))
+        return EC;
       if (!Ty)
         Ty = cast<PointerType>(Op->getType())->getElementType();
-      else if (Ty != cast<PointerType>(Op->getType())->getElementType())
-        return Error("Explicit load type does not match pointee type of "
-                     "pointer operand");
 
       unsigned Align;
       if (std::error_code EC = parseAlignmentValue(Record[OpNum], Align))
@@ -4098,6 +4116,11 @@ std::error_code BitcodeReader::ParseFunctionBody(Function *F) {
       Type *Ty = nullptr;
       if (OpNum + 5 == Record.size())
         Ty = getTypeByID(Record[OpNum++]);
+      if (std::error_code EC =
+              TypeCheckLoadStoreInst(DiagnosticHandler, Ty, Op->getType()))
+        return EC;
+      if (!Ty)
+        Ty = cast<PointerType>(Op->getType())->getElementType();
 
       AtomicOrdering Ordering = GetDecodedOrdering(Record[OpNum+2]);
       if (Ordering == NotAtomic || Ordering == Release ||
@@ -4112,10 +4135,6 @@ std::error_code BitcodeReader::ParseFunctionBody(Function *F) {
         return EC;
       I = new LoadInst(Op, "", Record[OpNum+1], Align, Ordering, SynchScope);
 
-      (void)Ty;
-      assert((!Ty || Ty == I->getType()) &&
-             "Explicit type doesn't match pointee type of the first operand");
-
       InstructionList.push_back(I);
       break;
     }
@@ -4131,6 +4150,10 @@ std::error_code BitcodeReader::ParseFunctionBody(Function *F) {
                           Val)) ||
           OpNum + 2 != Record.size())
         return Error("Invalid record");
+
+      if (std::error_code EC = TypeCheckLoadStoreInst(
+              DiagnosticHandler, Val->getType(), Ptr->getType()))
+        return EC;
       unsigned Align;
       if (std::error_code EC = parseAlignmentValue(Record[OpNum], Align))
         return EC;
@@ -4152,6 +4175,9 @@ std::error_code BitcodeReader::ParseFunctionBody(Function *F) {
           OpNum + 4 != Record.size())
         return Error("Invalid record");
 
+      if (std::error_code EC = TypeCheckLoadStoreInst(
+              DiagnosticHandler, Val->getType(), Ptr->getType()))
+        return EC;
       AtomicOrdering Ordering = GetDecodedOrdering(Record[OpNum+2]);
       if (Ordering == NotAtomic || Ordering == Acquire ||
           Ordering == AcquireRelease)
@@ -4187,6 +4213,9 @@ std::error_code BitcodeReader::ParseFunctionBody(Function *F) {
         return Error("Invalid record");
       SynchronizationScope SynchScope = GetDecodedSynchScope(Record[OpNum+2]);
 
+      if (std::error_code EC = TypeCheckLoadStoreInst(
+              DiagnosticHandler, Cmp->getType(), Ptr->getType()))
+        return EC;
       AtomicOrdering FailureOrdering;
       if (Record.size() < 7)
         FailureOrdering =