Use std::unique_ptr for StreamFilter instead of std::shared_ptr

pull/103/head
Tatsuhiro Tsujikawa 2013-07-04 00:39:11 +09:00
parent cb205a207c
commit 57f1902ee1
16 changed files with 91 additions and 92 deletions

View File

@ -65,12 +65,13 @@ enum {
} // namespace } // namespace
ChunkedDecodingStreamFilter::ChunkedDecodingStreamFilter ChunkedDecodingStreamFilter::ChunkedDecodingStreamFilter
(const std::shared_ptr<StreamFilter>& delegate): (std::unique_ptr<StreamFilter> delegate)
StreamFilter(delegate), : StreamFilter{std::move(delegate)},
state_(PREV_CHUNK_SIZE), state_{PREV_CHUNK_SIZE},
chunkSize_(0), chunkSize_{0},
chunkRemaining_(0), chunkRemaining_{0},
bytesProcessed_(0) {} bytesProcessed_{0}
{}
ChunkedDecodingStreamFilter::~ChunkedDecodingStreamFilter() {} ChunkedDecodingStreamFilter::~ChunkedDecodingStreamFilter() {}

View File

@ -47,7 +47,7 @@ private:
size_t bytesProcessed_; size_t bytesProcessed_;
public: public:
ChunkedDecodingStreamFilter ChunkedDecodingStreamFilter
(const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>()); (std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{});
virtual ~ChunkedDecodingStreamFilter(); virtual ~ChunkedDecodingStreamFilter();

View File

@ -108,9 +108,8 @@ DownloadCommand::DownloadCommand
peerStat_->downloadStart(); peerStat_->downloadStart();
getSegmentMan()->registerPeerStat(peerStat_); getSegmentMan()->registerPeerStat(peerStat_);
WrDiskCache* wrDiskCache = getPieceStorage()->getWrDiskCache(); streamFilter_ = make_unique<SinkStreamFilter>
streamFilter_.reset(new SinkStreamFilter(wrDiskCache, (getPieceStorage()->getWrDiskCache(), pieceHashValidationEnabled_);
pieceHashValidationEnabled_));
streamFilter_->init(); streamFilter_->init();
sinkFilterOnly_ = true; sinkFilterOnly_ = true;
checkSocketRecvBuffer(); checkSocketRecvBuffer();
@ -410,13 +409,13 @@ void DownloadCommand::completeSegment(cuid_t cuid,
} }
void DownloadCommand::installStreamFilter void DownloadCommand::installStreamFilter
(const std::shared_ptr<StreamFilter>& streamFilter) (std::unique_ptr<StreamFilter> streamFilter)
{ {
if(!streamFilter) { if(!streamFilter) {
return; return;
} }
streamFilter->installDelegate(streamFilter_); streamFilter->installDelegate(std::move(streamFilter_));
streamFilter_ = streamFilter; streamFilter_ = std::move(streamFilter);
const std::string& name = streamFilter_->getName(); const std::string& name = streamFilter_->getName();
sinkFilterOnly_ = util::endsWith(name, SinkStreamFilter::NAME); sinkFilterOnly_ = util::endsWith(name, SinkStreamFilter::NAME);
} }

View File

@ -69,7 +69,7 @@ private:
void completeSegment(cuid_t cuid, const std::shared_ptr<Segment>& segment); void completeSegment(cuid_t cuid, const std::shared_ptr<Segment>& segment);
std::shared_ptr<StreamFilter> streamFilter_; std::unique_ptr<StreamFilter> streamFilter_;
bool sinkFilterOnly_; bool sinkFilterOnly_;
protected: protected:
@ -89,12 +89,12 @@ public:
const std::shared_ptr<SocketRecvBuffer>& socketRecvBuffer); const std::shared_ptr<SocketRecvBuffer>& socketRecvBuffer);
virtual ~DownloadCommand(); virtual ~DownloadCommand();
const std::shared_ptr<StreamFilter>& getStreamFilter() const const std::unique_ptr<StreamFilter>& getStreamFilter() const
{ {
return streamFilter_; return streamFilter_;
} }
void installStreamFilter(const std::shared_ptr<StreamFilter>& streamFilter); void installStreamFilter(std::unique_ptr<StreamFilter> streamFilter);
void setStartupIdleTime(time_t startupIdleTime) void setStartupIdleTime(time_t startupIdleTime)
{ {

View File

@ -44,8 +44,12 @@ namespace aria2 {
const std::string GZipDecodingStreamFilter::NAME("GZipDecodingStreamFilter"); const std::string GZipDecodingStreamFilter::NAME("GZipDecodingStreamFilter");
GZipDecodingStreamFilter::GZipDecodingStreamFilter GZipDecodingStreamFilter::GZipDecodingStreamFilter
(const std::shared_ptr<StreamFilter>& delegate): (std::unique_ptr<StreamFilter> delegate)
StreamFilter(delegate), strm_(0), finished_(false), bytesProcessed_(0) {} : StreamFilter{std::move(delegate)},
strm_{nullptr},
finished_{false},
bytesProcessed_{0}
{}
GZipDecodingStreamFilter::~GZipDecodingStreamFilter() GZipDecodingStreamFilter::~GZipDecodingStreamFilter()
{ {

View File

@ -52,7 +52,7 @@ private:
static const size_t OUTBUF_LENGTH = 16*1024; static const size_t OUTBUF_LENGTH = 16*1024;
public: public:
GZipDecodingStreamFilter GZipDecodingStreamFilter
(const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>()); (std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{});
virtual ~GZipDecodingStreamFilter(); virtual ~GZipDecodingStreamFilter();

View File

@ -189,17 +189,17 @@ const std::string& HttpResponse::getTransferEncoding() const
return httpHeader_->find(HttpHeader::TRANSFER_ENCODING); return httpHeader_->find(HttpHeader::TRANSFER_ENCODING);
} }
std::shared_ptr<StreamFilter> HttpResponse::getTransferEncodingStreamFilter() const std::unique_ptr<StreamFilter>
HttpResponse::getTransferEncodingStreamFilter() const
{ {
std::shared_ptr<StreamFilter> filter;
// TODO Transfer-Encoding header field can contains multiple tokens. We should // TODO Transfer-Encoding header field can contains multiple tokens. We should
// parse the field and retrieve each token. // parse the field and retrieve each token.
if(isTransferEncodingSpecified()) { if(isTransferEncodingSpecified()) {
if(util::strieq(getTransferEncoding(), "chunked")) { if(util::strieq(getTransferEncoding(), "chunked")) {
filter.reset(new ChunkedDecodingStreamFilter()); return make_unique<ChunkedDecodingStreamFilter>();
} }
} }
return filter; return std::unique_ptr<StreamFilter>{};
} }
bool HttpResponse::isContentEncodingSpecified() const bool HttpResponse::isContentEncodingSpecified() const
@ -212,16 +212,16 @@ const std::string& HttpResponse::getContentEncoding() const
return httpHeader_->find(HttpHeader::CONTENT_ENCODING); return httpHeader_->find(HttpHeader::CONTENT_ENCODING);
} }
std::shared_ptr<StreamFilter> HttpResponse::getContentEncodingStreamFilter() const std::unique_ptr<StreamFilter>
HttpResponse::getContentEncodingStreamFilter() const
{ {
std::shared_ptr<StreamFilter> filter;
#ifdef HAVE_ZLIB #ifdef HAVE_ZLIB
if(util::strieq(getContentEncoding(), "gzip") || if(util::strieq(getContentEncoding(), "gzip") ||
util::strieq(getContentEncoding(), "deflate")) { util::strieq(getContentEncoding(), "deflate")) {
filter.reset(new GZipDecodingStreamFilter()); return make_unique<GZipDecodingStreamFilter>();
} }
#endif // HAVE_ZLIB #endif // HAVE_ZLIB
return filter; return std::unique_ptr<StreamFilter>{};
} }
int64_t HttpResponse::getContentLength() const int64_t HttpResponse::getContentLength() const

View File

@ -86,13 +86,13 @@ public:
const std::string& getTransferEncoding() const; const std::string& getTransferEncoding() const;
std::shared_ptr<StreamFilter> getTransferEncodingStreamFilter() const; std::unique_ptr<StreamFilter> getTransferEncodingStreamFilter() const;
bool isContentEncodingSpecified() const; bool isContentEncodingSpecified() const;
const std::string& getContentEncoding() const; const std::string& getContentEncoding() const;
std::shared_ptr<StreamFilter> getContentEncodingStreamFilter() const; std::unique_ptr<StreamFilter> getContentEncodingStreamFilter() const;
int64_t getContentLength() const; int64_t getContentLength() const;

View File

@ -84,51 +84,44 @@
namespace aria2 { namespace aria2 {
namespace { namespace {
std::shared_ptr<StreamFilter> getTransferEncodingStreamFilter std::unique_ptr<StreamFilter> getTransferEncodingStreamFilter
(HttpResponse* httpResponse, (HttpResponse* httpResponse,
const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>()) std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{})
{ {
std::shared_ptr<StreamFilter> filter;
if(httpResponse->isTransferEncodingSpecified()) { if(httpResponse->isTransferEncodingSpecified()) {
filter = httpResponse->getTransferEncodingStreamFilter(); auto filter = httpResponse->getTransferEncodingStreamFilter();
if(!filter) { if(!filter) {
throw DL_ABORT_EX throw DL_ABORT_EX
(fmt(EX_TRANSFER_ENCODING_NOT_SUPPORTED, (fmt(EX_TRANSFER_ENCODING_NOT_SUPPORTED,
httpResponse->getTransferEncoding().c_str())); httpResponse->getTransferEncoding().c_str()));
} }
filter->init(); filter->init();
filter->installDelegate(delegate); filter->installDelegate(std::move(delegate));
return filter;
} }
if(!filter) { return delegate;
filter = delegate;
}
return filter;
} }
} // namespace } // namespace
namespace { namespace {
std::shared_ptr<StreamFilter> getContentEncodingStreamFilter std::unique_ptr<StreamFilter> getContentEncodingStreamFilter
(HttpResponse* httpResponse, (HttpResponse* httpResponse,
const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>()) std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{})
{ {
std::shared_ptr<StreamFilter> filter;
if(httpResponse->isContentEncodingSpecified()) { if(httpResponse->isContentEncodingSpecified()) {
filter = httpResponse->getContentEncodingStreamFilter(); auto filter = httpResponse->getContentEncodingStreamFilter();
if(!filter) { if(!filter) {
A2_LOG_INFO A2_LOG_INFO
(fmt("Content-Encoding %s is specified, but the current implementation" (fmt("Content-Encoding %s is specified, but the current implementation"
"doesn't support it. The decoding process is skipped and the" "doesn't support it. The decoding process is skipped and the"
"downloaded content will be still encoded.", "downloaded content will be still encoded.",
httpResponse->getContentEncoding().c_str())); httpResponse->getContentEncoding().c_str()));
} else {
filter->init();
filter->installDelegate(delegate);
} }
filter->init();
filter->installDelegate(std::move(delegate));
return filter;
} }
if(!filter) { return delegate;
filter = delegate;
}
return filter;
} }
} // namespace } // namespace
@ -311,11 +304,13 @@ bool HttpResponseCommand::executeInternal()
(httpResponse.get(), (httpResponse.get(),
getContentEncodingStreamFilter(httpResponse.get())); getContentEncodingStreamFilter(httpResponse.get()));
getDownloadEngine()->addCommand getDownloadEngine()->addCommand
(createHttpDownloadCommand(std::move(httpResponse), teFilter)); (createHttpDownloadCommand(std::move(httpResponse),
std::move(teFilter)));
} else { } else {
auto teFilter = getTransferEncodingStreamFilter(httpResponse.get()); auto teFilter = getTransferEncodingStreamFilter(httpResponse.get());
getDownloadEngine()->addCommand getDownloadEngine()->addCommand
(createHttpDownloadCommand(std::move(httpResponse), teFilter)); (createHttpDownloadCommand(std::move(httpResponse),
std::move(teFilter)));
} }
return true; return true;
} }
@ -375,7 +370,8 @@ bool HttpResponseCommand::handleDefaultEncoding
!getRequest()->isPipeliningEnabled()) { !getRequest()->isPipeliningEnabled()) {
auto teFilter = getTransferEncodingStreamFilter(httpResponse.get()); auto teFilter = getTransferEncodingStreamFilter(httpResponse.get());
checkEntry->pushNextCommand checkEntry->pushNextCommand
(createHttpDownloadCommand(std::move(httpResponse), teFilter)); (createHttpDownloadCommand(std::move(httpResponse),
std::move(teFilter)));
} else { } else {
getSegmentMan()->cancelSegment(getCuid()); getSegmentMan()->cancelSegment(getCuid());
getFileEntry()->poolRequest(getRequest()); getFileEntry()->poolRequest(getRequest());
@ -477,7 +473,8 @@ bool HttpResponseCommand::handleOtherEncoding
getSegmentMan()->getSegmentWithIndex(getCuid(), 0); getSegmentMan()->getSegmentWithIndex(getCuid(), 0);
getDownloadEngine()->addCommand getDownloadEngine()->addCommand
(createHttpDownloadCommand(std::move(httpResponse), streamFilter)); (createHttpDownloadCommand(std::move(httpResponse),
std::move(streamFilter)));
return true; return true;
} }
@ -492,7 +489,7 @@ bool HttpResponseCommand::skipResponseBody
(getCuid(), getRequest(), getFileEntry(), getRequestGroup(), (getCuid(), getRequest(), getFileEntry(), getRequestGroup(),
httpConnection_, std::move(httpResponse), httpConnection_, std::move(httpResponse),
getDownloadEngine(), getSocket()); getDownloadEngine(), getSocket());
command->installStreamFilter(filter); command->installStreamFilter(std::move(filter));
// If request method is HEAD or the response body is zero-length, // If request method is HEAD or the response body is zero-length,
// set command's status to real time so that avoid read check blocking // set command's status to real time so that avoid read check blocking
@ -510,11 +507,10 @@ bool HttpResponseCommand::skipResponseBody
} }
namespace { namespace {
bool decideFileAllocation bool decideFileAllocation(StreamFilter* filter)
(const std::shared_ptr<StreamFilter>& filter)
{ {
#ifdef HAVE_ZLIB #ifdef HAVE_ZLIB
for(std::shared_ptr<StreamFilter> f = filter; f; f = f->getDelegate()){ for(StreamFilter* f = filter; f; f = f->getDelegate().get()){
// Since the compressed file's length are returned in the response header // Since the compressed file's length are returned in the response header
// and the decompressed file size is unknown at this point, disable file // and the decompressed file size is unknown at this point, disable file
// allocation here. // allocation here.
@ -530,7 +526,7 @@ bool decideFileAllocation
std::unique_ptr<HttpDownloadCommand> std::unique_ptr<HttpDownloadCommand>
HttpResponseCommand::createHttpDownloadCommand HttpResponseCommand::createHttpDownloadCommand
(std::unique_ptr<HttpResponse> httpResponse, (std::unique_ptr<HttpResponse> httpResponse,
const std::shared_ptr<StreamFilter>& filter) std::unique_ptr<StreamFilter> filter)
{ {
auto command = make_unique<HttpDownloadCommand> auto command = make_unique<HttpDownloadCommand>
@ -541,11 +537,11 @@ HttpResponseCommand::createHttpDownloadCommand
command->setStartupIdleTime(getOption()->getAsInt(PREF_STARTUP_IDLE_TIME)); command->setStartupIdleTime(getOption()->getAsInt(PREF_STARTUP_IDLE_TIME));
command->setLowestDownloadSpeedLimit command->setLowestDownloadSpeedLimit
(getOption()->getAsInt(PREF_LOWEST_SPEED_LIMIT)); (getOption()->getAsInt(PREF_LOWEST_SPEED_LIMIT));
command->installStreamFilter(filter);
if(getRequestGroup()->isFileAllocationEnabled() && if(getRequestGroup()->isFileAllocationEnabled() &&
!decideFileAllocation(filter)) { !decideFileAllocation(filter.get())) {
getRequestGroup()->setFileAllocationEnabled(false); getRequestGroup()->setFileAllocationEnabled(false);
} }
command->installStreamFilter(std::move(filter));
getRequestGroup()->getURISelector()->tuneDownloadCommand getRequestGroup()->getURISelector()->tuneDownloadCommand
(getFileEntry()->getRemainingUris(), command.get()); (getFileEntry()->getRemainingUris(), command.get());

View File

@ -70,7 +70,7 @@ private:
std::unique_ptr<HttpDownloadCommand> std::unique_ptr<HttpDownloadCommand>
createHttpDownloadCommand createHttpDownloadCommand
(std::unique_ptr<HttpResponse> httpResponse, (std::unique_ptr<HttpResponse> httpResponse,
const std::shared_ptr<StreamFilter>& streamFilter); std::unique_ptr<StreamFilter> streamFilter);
void updateLastModifiedTime(const Time& lastModified); void updateLastModifiedTime(const Time& lastModified);

View File

@ -87,13 +87,13 @@ HttpSkipResponseCommand::HttpSkipResponseCommand
HttpSkipResponseCommand::~HttpSkipResponseCommand() {} HttpSkipResponseCommand::~HttpSkipResponseCommand() {}
void HttpSkipResponseCommand::installStreamFilter void HttpSkipResponseCommand::installStreamFilter
(const std::shared_ptr<StreamFilter>& streamFilter) (std::unique_ptr<StreamFilter> streamFilter)
{ {
if(!streamFilter) { if(!streamFilter) {
return; return;
} }
streamFilter->installDelegate(streamFilter_); streamFilter->installDelegate(std::move(streamFilter_));
streamFilter_ = streamFilter; streamFilter_ = std::move(streamFilter);
const std::string& name = streamFilter_->getName(); const std::string& name = streamFilter_->getName();
sinkFilterOnly_ = util::endsWith(name, SinkStreamFilter::NAME); sinkFilterOnly_ = util::endsWith(name, SinkStreamFilter::NAME);
} }

View File

@ -49,7 +49,7 @@ private:
std::unique_ptr<HttpResponse> httpResponse_; std::unique_ptr<HttpResponse> httpResponse_;
std::shared_ptr<StreamFilter> streamFilter_; std::unique_ptr<StreamFilter> streamFilter_;
bool sinkFilterOnly_; bool sinkFilterOnly_;
@ -75,7 +75,7 @@ public:
virtual ~HttpSkipResponseCommand(); virtual ~HttpSkipResponseCommand();
void installStreamFilter(const std::shared_ptr<StreamFilter>& streamFilter); void installStreamFilter(std::unique_ptr<StreamFilter> streamFilter);
void disableSocketCheck(); void disableSocketCheck();
}; };

View File

@ -36,19 +36,19 @@
namespace aria2 { namespace aria2 {
StreamFilter::StreamFilter StreamFilter::StreamFilter(std::unique_ptr<StreamFilter> delegate)
(const std::shared_ptr<StreamFilter>& delegate): : delegate_(std::move(delegate))
delegate_(delegate) {} {}
StreamFilter::~StreamFilter() {} StreamFilter::~StreamFilter() {}
bool StreamFilter::installDelegate(const std::shared_ptr<StreamFilter>& filter) bool StreamFilter::installDelegate(std::unique_ptr<StreamFilter> filter)
{ {
if(!delegate_) { if(!delegate_) {
delegate_ = filter; delegate_ = std::move(filter);
return true; return true;
} else { } else {
return delegate_->installDelegate(filter); return delegate_->installDelegate(std::move(filter));
} }
} }

View File

@ -48,10 +48,10 @@ class Segment;
// Interface for basic decoding functionality. // Interface for basic decoding functionality.
class StreamFilter { class StreamFilter {
private: private:
std::shared_ptr<StreamFilter> delegate_; std::unique_ptr<StreamFilter> delegate_;
public: public:
StreamFilter StreamFilter
(const std::shared_ptr<StreamFilter>& delegate = std::shared_ptr<StreamFilter>()); (std::unique_ptr<StreamFilter> delegate = std::unique_ptr<StreamFilter>{});
virtual ~StreamFilter(); virtual ~StreamFilter();
@ -75,9 +75,9 @@ public:
// tranfrom() invocation. // tranfrom() invocation.
virtual size_t getBytesProcessed() const = 0; virtual size_t getBytesProcessed() const = 0;
virtual bool installDelegate(const std::shared_ptr<StreamFilter>& filter); virtual bool installDelegate(std::unique_ptr<StreamFilter> filter);
std::shared_ptr<StreamFilter> getDelegate() const const std::unique_ptr<StreamFilter>& getDelegate() const
{ {
return delegate_; return delegate_;
} }

View File

@ -9,6 +9,7 @@
#include "ByteArrayDiskWriter.h" #include "ByteArrayDiskWriter.h"
#include "SinkStreamFilter.h" #include "SinkStreamFilter.h"
#include "MockSegment.h" #include "MockSegment.h"
#include "a2functional.h"
namespace aria2 { namespace aria2 {
@ -24,8 +25,7 @@ class ChunkedDecodingStreamFilterTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testGetName); CPPUNIT_TEST(testGetName);
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
std::shared_ptr<ChunkedDecodingStreamFilter> filter_; std::unique_ptr<ChunkedDecodingStreamFilter> filter_;
std::shared_ptr<SinkStreamFilter> sinkFilter_;
std::shared_ptr<ByteArrayDiskWriter> writer_; std::shared_ptr<ByteArrayDiskWriter> writer_;
std::shared_ptr<Segment> segment_; std::shared_ptr<Segment> segment_;
@ -36,12 +36,12 @@ class ChunkedDecodingStreamFilterTest:public CppUnit::TestFixture {
public: public:
void setUp() void setUp()
{ {
writer_.reset(new ByteArrayDiskWriter()); writer_ = std::make_shared<ByteArrayDiskWriter>();
sinkFilter_.reset(new SinkStreamFilter()); auto sinkFilter = make_unique<SinkStreamFilter>();
filter_.reset(new ChunkedDecodingStreamFilter(sinkFilter_)); sinkFilter->init();
sinkFilter_->init(); filter_ = make_unique<ChunkedDecodingStreamFilter>(std::move(sinkFilter));
filter_->init(); filter_->init();
segment_.reset(new MockSegment()); segment_ = std::make_shared<MockSegment>();
} }
void testTransform(); void testTransform();

View File

@ -30,30 +30,29 @@ class GZipDecodingStreamFilterTest:public CppUnit::TestFixture {
public: public:
MockSegment2():positionToWrite_(0) {} MockSegment2():positionToWrite_(0) {}
virtual void updateWrittenLength(int32_t bytes) virtual void updateWrittenLength(int32_t bytes) override
{ {
positionToWrite_ += bytes; positionToWrite_ += bytes;
} }
virtual int64_t getPositionToWrite() const virtual int64_t getPositionToWrite() const override
{ {
return positionToWrite_; return positionToWrite_;
} }
}; };
std::shared_ptr<GZipDecodingStreamFilter> filter_; std::unique_ptr<GZipDecodingStreamFilter> filter_;
std::shared_ptr<SinkStreamFilter> sinkFilter_;
std::shared_ptr<ByteArrayDiskWriter> writer_; std::shared_ptr<ByteArrayDiskWriter> writer_;
std::shared_ptr<MockSegment2> segment_; std::shared_ptr<MockSegment2> segment_;
public: public:
void setUp() void setUp()
{ {
writer_.reset(new ByteArrayDiskWriter()); writer_ = std::make_shared<ByteArrayDiskWriter>();
sinkFilter_.reset(new SinkStreamFilter()); auto sinkFilter = make_unique<SinkStreamFilter>();
filter_.reset(new GZipDecodingStreamFilter(sinkFilter_)); sinkFilter->init();
sinkFilter_->init(); filter_ = make_unique<GZipDecodingStreamFilter>(std::move(sinkFilter));
filter_->init(); filter_->init();
segment_.reset(new MockSegment2()); segment_ = std::make_shared<MockSegment2>();
} }
void testTransform(); void testTransform();