From 493bc7e30090915290a8f08fbeba02f3d912498e Mon Sep 17 00:00:00 2001 From: Derek Schuff Date: Thu, 21 May 2015 19:40:19 +0000 Subject: [PATCH] Fix StreamingMemoryObject to respect known object size. The existing code for method StreamingMemoryObject.fetchToPos does not respect the corresonding call to setKnownObjectSize(). As a result, it allows the StreamingMemoryObject to read bytes past the object size. This patch provides a test case, and code to fix the problem. Patch by Karl Schimpf Differential Revision: http://reviews.llvm.org/D8931 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@237939 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Support/StreamingMemoryObject.h | 23 ++++++++++---------- lib/Support/StreamingMemoryObject.cpp | 17 ++++++++++----- unittests/Support/StreamingMemoryObject.cpp | 9 ++++++++ 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/include/llvm/Support/StreamingMemoryObject.h b/include/llvm/Support/StreamingMemoryObject.h index 9d1d607005f..932e635cd07 100644 --- a/include/llvm/Support/StreamingMemoryObject.h +++ b/include/llvm/Support/StreamingMemoryObject.h @@ -59,26 +59,27 @@ private: mutable size_t ObjectSize; // 0 if unknown, set if wrapper seen or EOF reached mutable bool EOFReached; - // Fetch enough bytes such that Pos can be read or EOF is reached - // (i.e. BytesRead > Pos). Return true if Pos can be read. - // Unlike most of the functions in BitcodeReader, returns true on success. - // Most of the requests will be small, but we fetch at kChunkSize bytes - // at a time to avoid making too many potentially expensive GetBytes calls + // Fetch enough bytes such that Pos can be read (i.e. BytesRead > + // Pos). Returns true if Pos can be read. Unlike most of the + // functions in BitcodeReader, returns true on success. Most of the + // requests will be small, but we fetch at kChunkSize bytes at a + // time to avoid making too many potentially expensive GetBytes + // calls. bool fetchToPos(size_t Pos) const { - if (EOFReached) - return Pos < ObjectSize; while (Pos >= BytesRead) { + if (EOFReached) + return false; Bytes.resize(BytesRead + BytesSkipped + kChunkSize); size_t bytes = Streamer->GetBytes(&Bytes[BytesRead + BytesSkipped], kChunkSize); BytesRead += bytes; - if (bytes != kChunkSize) { // reached EOF/ran out of bytes - ObjectSize = BytesRead; + if (bytes == 0) { // reached EOF/ran out of bytes + if (ObjectSize == 0) + ObjectSize = BytesRead; EOFReached = true; - break; } } - return Pos < BytesRead; + return !ObjectSize || Pos < ObjectSize; } StreamingMemoryObject(const StreamingMemoryObject&) = delete; diff --git a/lib/Support/StreamingMemoryObject.cpp b/lib/Support/StreamingMemoryObject.cpp index 90f3ed80d12..6c5652af04c 100644 --- a/lib/Support/StreamingMemoryObject.cpp +++ b/lib/Support/StreamingMemoryObject.cpp @@ -73,7 +73,7 @@ namespace llvm { // block until we actually want to read it. bool StreamingMemoryObject::isValidAddress(uint64_t address) const { if (ObjectSize && address < ObjectSize) return true; - return fetchToPos(address); + return fetchToPos(address); } uint64_t StreamingMemoryObject::getExtent() const { @@ -87,13 +87,18 @@ uint64_t StreamingMemoryObject::getExtent() const { uint64_t StreamingMemoryObject::readBytes(uint8_t *Buf, uint64_t Size, uint64_t Address) const { fetchToPos(Address + Size - 1); - if (Address >= BytesRead) + // Note: For wrapped bitcode files will set ObjectSize after the + // first call to fetchToPos. In such cases, ObjectSize can be + // smaller than BytesRead. + size_t MaxAddress = + (ObjectSize && ObjectSize < BytesRead) ? ObjectSize : BytesRead; + if (Address >= MaxAddress) return 0; uint64_t End = Address + Size; - if (End > BytesRead) - End = BytesRead; - assert(static_cast(End - Address) >= 0); + if (End > MaxAddress) + End = MaxAddress; + assert(End >= Address); Size = End - Address; memcpy(Buf, &Bytes[Address + BytesSkipped], Size); return Size; @@ -109,6 +114,8 @@ bool StreamingMemoryObject::dropLeadingBytes(size_t s) { void StreamingMemoryObject::setKnownObjectSize(size_t size) { ObjectSize = size; Bytes.reserve(size); + if (ObjectSize <= BytesRead) + EOFReached = true; } MemoryObject *getNonStreamedMemoryObject(const unsigned char *Start, diff --git a/unittests/Support/StreamingMemoryObject.cpp b/unittests/Support/StreamingMemoryObject.cpp index 20136491ab8..c043efbb5e4 100644 --- a/unittests/Support/StreamingMemoryObject.cpp +++ b/unittests/Support/StreamingMemoryObject.cpp @@ -27,3 +27,12 @@ TEST(StreamingMemoryObject, Test) { StreamingMemoryObject O(DS); EXPECT_TRUE(O.isValidAddress(32 * 1024)); } + +TEST(StreamingMemoryObject, TestSetKnownObjectSize) { + auto *DS = new NullDataStreamer(); + StreamingMemoryObject O(DS); + uint8_t Buf[32]; + EXPECT_EQ((uint64_t) 16, O.readBytes(Buf, 16, 0)); + O.setKnownObjectSize(24); + EXPECT_EQ((uint64_t) 8, O.readBytes(Buf, 16, 16)); +} -- 2.34.1