diff options
Diffstat (limited to 'mfbt')
-rw-r--r-- | mfbt/BufferList.h | 129 | ||||
-rw-r--r-- | mfbt/tests/TestBufferList.cpp | 33 |
2 files changed, 115 insertions, 47 deletions
diff --git a/mfbt/BufferList.h b/mfbt/BufferList.h index 26b335957f..aaff31454d 100644 --- a/mfbt/BufferList.h +++ b/mfbt/BufferList.h @@ -9,6 +9,7 @@ #include <algorithm> #include "mozilla/AllocPolicy.h" +#include "mozilla/Maybe.h" #include "mozilla/Move.h" #include "mozilla/ScopeExit.h" #include "mozilla/Types.h" @@ -456,61 +457,101 @@ BufferList<AllocPolicy>::Extract(IterImpl& aIter, size_t aSize, bool* aSuccess) MOZ_ASSERT(aSize % kSegmentAlignment == 0); MOZ_ASSERT(intptr_t(aIter.mData) % kSegmentAlignment == 0); - IterImpl iter = aIter; - size_t size = aSize; - size_t toCopy = std::min(size, aIter.RemainingInSegment()); - MOZ_ASSERT(toCopy % kSegmentAlignment == 0); + auto failure = [this, aSuccess]() { + *aSuccess = false; + return BufferList(0, 0, mStandardCapacity); + }; - BufferList result(0, toCopy, mStandardCapacity); - BufferList error(0, 0, mStandardCapacity); + // Number of segments we'll need to copy data from to satisfy the request. + size_t segmentsNeeded = 0; + // If this is None then the last segment is a full segment, otherwise we need + // to copy this many bytes. + Maybe<size_t> lastSegmentSize; + { + // Copy of the iterator to walk the BufferList and see how many segments we + // need to copy. + IterImpl iter = aIter; + size_t remaining = aSize; + while (!iter.Done() && remaining && + remaining >= iter.RemainingInSegment()) { + remaining -= iter.RemainingInSegment(); + iter.Advance(*this, iter.RemainingInSegment()); + segmentsNeeded++; + } - // Copy the head - if (!result.WriteBytes(aIter.mData, toCopy)) { - *aSuccess = false; - return error; + if (remaining) { + if (iter.Done()) { + // We reached the end of the BufferList and there wasn't enough data to + // satisfy the request. + return failure(); + } + lastSegmentSize.emplace(remaining); + // The last block also counts as a segment. This makes the conditionals + // on segmentsNeeded work in the rest of the function. + segmentsNeeded++; + } } - iter.Advance(*this, toCopy); - size -= toCopy; - // Move segments to result - auto resultGuard = MakeScopeExit([&] { - *aSuccess = false; - result.mSegments.erase(result.mSegments.begin()+1, result.mSegments.end()); - }); - - size_t movedSize = 0; - uintptr_t toRemoveStart = iter.mSegment; - uintptr_t toRemoveEnd = iter.mSegment; - while (!iter.Done() && - !iter.HasRoomFor(size)) { - if (!result.mSegments.append(Segment(mSegments[iter.mSegment].mData, - mSegments[iter.mSegment].mSize, - mSegments[iter.mSegment].mCapacity))) { - return error; - } - movedSize += iter.RemainingInSegment(); - size -= iter.RemainingInSegment(); - toRemoveEnd++; - iter.Advance(*this, iter.RemainingInSegment()); + BufferList result(0, 0, mStandardCapacity); + if (!result.mSegments.reserve(segmentsNeeded + lastSegmentSize.isSome())) { + return failure(); } - if (size) { - if (!iter.HasRoomFor(size) || - !result.WriteBytes(iter.Data(), size)) { - return error; + // Copy the first segment, it's special because we can't just steal the + // entire Segment struct from this->mSegments. + size_t firstSegmentSize = std::min(aSize, aIter.RemainingInSegment()); + if (!result.WriteBytes(aIter.Data(), firstSegmentSize)) { + return failure(); + } + aIter.Advance(*this, firstSegmentSize); + segmentsNeeded--; + + // The entirety of the request wasn't in the first segment, now copy the + // rest. + if (segmentsNeeded) { + char* finalSegment = nullptr; + // Pre-allocate the final segment so that if this fails, we return before + // we delete the elements from |this->mSegments|. + if (lastSegmentSize.isSome()) { + MOZ_RELEASE_ASSERT(mStandardCapacity >= *lastSegmentSize); + finalSegment = this->template pod_malloc<char>(mStandardCapacity); + if (!finalSegment) { + return failure(); + } + } + + size_t copyStart = aIter.mSegment; + // Copy segments from this over to the result and remove them from our + // storage. Not needed if the only segment we need to copy is the last + // partial one. + size_t segmentsToCopy = segmentsNeeded - lastSegmentSize.isSome(); + for (size_t i = 0; i < segmentsToCopy; ++i) { + result.mSegments.infallibleAppend( + Segment(mSegments[aIter.mSegment].mData, + mSegments[aIter.mSegment].mSize, + mSegments[aIter.mSegment].mCapacity)); + aIter.Advance(*this, aIter.RemainingInSegment()); + } + MOZ_RELEASE_ASSERT(aIter.mSegment == copyStart + segmentsToCopy); + mSegments.erase(mSegments.begin() + copyStart, + mSegments.begin() + copyStart + segmentsToCopy); + + // Reset the iter's position for what we just deleted. + aIter.mSegment -= segmentsToCopy; + + if (lastSegmentSize.isSome()) { + // We called reserve() on result.mSegments so infallibleAppend is safe. + result.mSegments.infallibleAppend( + Segment(finalSegment, 0, mStandardCapacity)); + bool r = result.WriteBytes(aIter.Data(), *lastSegmentSize); + MOZ_RELEASE_ASSERT(r); + aIter.Advance(*this, *lastSegmentSize); } - iter.Advance(*this, size); } - mSegments.erase(mSegments.begin() + toRemoveStart, mSegments.begin() + toRemoveEnd); - mSize -= movedSize; - aIter.mSegment = iter.mSegment - (toRemoveEnd - toRemoveStart); - aIter.mData = iter.mData; - aIter.mDataEnd = iter.mDataEnd; - MOZ_ASSERT(aIter.mDataEnd == mSegments[aIter.mSegment].End()); + mSize -= aSize; result.mSize = aSize; - resultGuard.release(); *aSuccess = true; return result; } diff --git a/mfbt/tests/TestBufferList.cpp b/mfbt/tests/TestBufferList.cpp index cccaac021b..207fa106f5 100644 --- a/mfbt/tests/TestBufferList.cpp +++ b/mfbt/tests/TestBufferList.cpp @@ -245,12 +245,39 @@ int main(void) BufferList bl3 = bl.Extract(iter, kExtractOverSize, &success); MOZ_RELEASE_ASSERT(!success); - MOZ_RELEASE_ASSERT(iter.AdvanceAcrossSegments(bl, kSmallWrite * 3 - kExtractSize - kExtractStart)); - MOZ_RELEASE_ASSERT(iter.Done()); - iter = bl2.Iter(); MOZ_RELEASE_ASSERT(iter.AdvanceAcrossSegments(bl2, kExtractSize)); MOZ_RELEASE_ASSERT(iter.Done()); + BufferList bl4(8, 8, 8); + bl4.WriteBytes("abcd1234", 8); + iter = bl4.Iter(); + iter.Advance(bl4, 8); + + BufferList bl5 = bl4.Extract(iter, kExtractSize, &success); + MOZ_RELEASE_ASSERT(!success); + + BufferList bl6(0, 0, 16); + bl6.WriteBytes("abcdefgh12345678", 16); + bl6.WriteBytes("ijklmnop87654321", 16); + iter = bl6.Iter(); + iter.Advance(bl6, 8); + BufferList bl7 = bl6.Extract(iter, 16, &success); + MOZ_RELEASE_ASSERT(success); + char data[16]; + MOZ_RELEASE_ASSERT(bl6.ReadBytes(iter, data, 8)); + MOZ_RELEASE_ASSERT(memcmp(data, "87654321", 8) == 0); + iter = bl7.Iter(); + MOZ_RELEASE_ASSERT(bl7.ReadBytes(iter, data, 16)); + MOZ_RELEASE_ASSERT(memcmp(data, "12345678ijklmnop", 16) == 0); + + BufferList bl8(0, 0, 16); + bl8.WriteBytes("abcdefgh12345678", 16); + iter = bl8.Iter(); + BufferList bl9 = bl8.Extract(iter, 8, &success); + MOZ_RELEASE_ASSERT(success); + MOZ_RELEASE_ASSERT(bl9.Size() == 8); + MOZ_RELEASE_ASSERT(!iter.Done()); + return 0; } |