diff --git a/src/DownloadContext.cc b/src/DownloadContext.cc index b73845fe..d2be45da 100644 --- a/src/DownloadContext.cc +++ b/src/DownloadContext.cc @@ -128,22 +128,31 @@ void DownloadContext::setFilePathWithIndex } } -void DownloadContext::setFileFilter(IntSequence seq) +void DownloadContext::setFileFilter(SegList& sgl) { - std::vector fileIndexes = seq.flush(); - std::sort(fileIndexes.begin(), fileIndexes.end()); - fileIndexes.erase(std::unique(fileIndexes.begin(), fileIndexes.end()), - fileIndexes.end()); - - bool selectAll = fileIndexes.empty() || fileEntries_.size() == 1; - - int32_t index = 1; - for(std::vector >::const_iterator i = - fileEntries_.begin(), eoi = fileEntries_.end(); - i != eoi; ++i, ++index) { - (*i)->setRequested - (selectAll || - std::binary_search(fileIndexes.begin(), fileIndexes.end(), index)); + sgl.normalize(); + if(!sgl.hasNext() || fileEntries_.size() == 1) { + std::for_each(fileEntries_.begin(), fileEntries_.end(), + std::bind2nd(mem_fun_sh(&FileEntry::setRequested), true)); + return; + } + assert(sgl.peek() >= 1); + size_t i = 0; + while(i < fileEntries_.size() && sgl.hasNext()) { + size_t idx = sgl.peek()-1; + if(i == idx) { + fileEntries_[i]->setRequested(true); + ++i; + sgl.next(); + } else if(i < idx) { + fileEntries_[i]->setRequested(false); + ++i; + } else { + sgl.next(); + } + } + for(; i < fileEntries_.size(); ++i) { + fileEntries_[i]->setRequested(false); } } diff --git a/src/DownloadContext.h b/src/DownloadContext.h index db10723f..f59ef72e 100644 --- a/src/DownloadContext.h +++ b/src/DownloadContext.h @@ -45,7 +45,7 @@ #include "TimerA2.h" #include "A2STR.h" #include "ValueBase.h" -#include "IntSequence.h" +#include "SegList.h" namespace aria2 { @@ -179,7 +179,7 @@ public: ownerRequestGroup_ = owner; } - void setFileFilter(IntSequence seq); + void setFileFilter(SegList& sgl); // Sets file path for specified index. index starts from 1. The // index is the same used in setFileFilter(). path is not escaped by diff --git a/src/SegList.h b/src/SegList.h index 8115d038..ef6923fc 100644 --- a/src/SegList.h +++ b/src/SegList.h @@ -49,6 +49,13 @@ public: : index_(0), val_(std::numeric_limits::min()) {} + void clear() + { + seg_.clear(); + index_ = 0; + val_ = std::numeric_limits::min(); + } + // Transforms list of segments so that they are sorted ascending // order of starting value and intersecting and touching segments // are all merged into one. This function resets current position. @@ -107,6 +114,19 @@ public: } return res; } + + // Returns next value. Current position is not advanced. If + // this fuction is called when hasNext() returns false, returns 0. + T peek() const + { + T res; + if(index_ < seg_.size()) { + res = val_; + } else { + res = 0; + } + return res; + } private: std::vector > seg_; size_t index_; diff --git a/src/download_helper.cc b/src/download_helper.cc index 1581aca8..8a71d1a7 100644 --- a/src/download_helper.cc +++ b/src/download_helper.cc @@ -62,6 +62,7 @@ #include "ByteArrayDiskWriterFactory.h" #include "MetadataInfo.h" #include "OptionParser.h" +#include "SegList.h" #ifdef ENABLE_BITTORRENT # include "bittorrent_helper.h" # include "BtConstants.h" @@ -183,7 +184,9 @@ createBtRequestGroup(const std::string& torrentFilePath, if(adjustAnnounceUri) { bittorrent::adjustAnnounceUri(bittorrent::getTorrentAttrs(dctx), option); } - dctx->setFileFilter(util::parseIntRange(option->get(PREF_SELECT_FILE))); + SegList sgl; + util::parseIntSegments(sgl, option->get(PREF_SELECT_FILE)); + dctx->setFileFilter(sgl); std::istringstream indexOutIn(option->get(PREF_INDEX_OUT)); std::map indexPathMap = util::createIndexPathMap(indexOutIn); diff --git a/src/util.cc b/src/util.cc index b8e8782c..52588932 100644 --- a/src/util.cc +++ b/src/util.cc @@ -774,6 +774,38 @@ IntSequence parseIntRange(const std::string& src) return values; } +void parseIntSegments(SegList& sgl, const std::string& src) +{ + for(std::string::const_iterator i = src.begin(), eoi = src.end(); i != eoi;) { + std::string::const_iterator j = i; + while(j != eoi && *j != ',') { + ++j; + } + if(j == i) { + ++i; + continue; + } + std::string::const_iterator p = i; + while(p != j && *p != '-') { + ++p; + } + if(p == j) { + int a = parseInt(std::string(i, j)); + sgl.add(a, a+1); + } else if(p == i || p+1 == j) { + throw DL_ABORT_EX(fmt(MSG_INCOMPLETE_RANGE, std::string(i, j).c_str())); + } else { + int a = parseInt(std::string(i, p)); + int b = parseInt(std::string(p+1, j)); + sgl.add(a, b+1); + } + if(j == eoi) { + break; + } + i = j+1; + } +} + namespace { void computeHeadPieces (std::vector& indexes, diff --git a/src/util.h b/src/util.h index 1fffd10a..c3aec6b2 100644 --- a/src/util.h +++ b/src/util.h @@ -55,6 +55,7 @@ #include "a2time.h" #include "a2netcompat.h" #include "a2functional.h" +#include "SegList.h" namespace aria2 { @@ -220,6 +221,8 @@ uint64_t parseULLInt(const std::string& s, int base = 10); IntSequence parseIntRange(const std::string& src); +void parseIntSegments(SegList& sgl, const std::string& src); + // Parses string which specifies the range of piece index for higher // priority and appends those indexes into result. The input string // src can contain 2 keywords "head" and "tail". To include both diff --git a/test/BittorrentHelperTest.cc b/test/BittorrentHelperTest.cc index 832e1997..c2ce9c9e 100644 --- a/test/BittorrentHelperTest.cc +++ b/test/BittorrentHelperTest.cc @@ -645,16 +645,20 @@ void BittorrentHelperTest::testSetFileFilter_single() load(A2_TEST_DIR"/single.torrent", dctx, option_); CPPUNIT_ASSERT(dctx->getFirstFileEntry()->isRequested()); - - dctx->setFileFilter(util::parseIntRange("")); + SegList sgl; + dctx->setFileFilter(sgl); CPPUNIT_ASSERT(dctx->getFirstFileEntry()->isRequested()); - dctx->setFileFilter(util::parseIntRange("1")); + sgl.clear(); + sgl.add(1, 2); + dctx->setFileFilter(sgl); CPPUNIT_ASSERT(dctx->getFirstFileEntry()->isRequested()); // For single file torrent, file is always selected whatever range // is passed. - dctx->setFileFilter(util::parseIntRange("2")); + sgl.clear(); + sgl.add(2, 3); + dctx->setFileFilter(sgl); CPPUNIT_ASSERT(dctx->getFirstFileEntry()->isRequested()); } @@ -666,19 +670,25 @@ void BittorrentHelperTest::testSetFileFilter_multi() CPPUNIT_ASSERT(dctx->getFileEntries()[0]->isRequested()); CPPUNIT_ASSERT(dctx->getFileEntries()[1]->isRequested()); - dctx->setFileFilter(util::parseIntRange("")); + SegList sgl; + dctx->setFileFilter(sgl); CPPUNIT_ASSERT(dctx->getFileEntries()[0]->isRequested()); CPPUNIT_ASSERT(dctx->getFileEntries()[1]->isRequested()); - dctx->setFileFilter(util::parseIntRange("2")); + sgl.add(2, 3); + dctx->setFileFilter(sgl); CPPUNIT_ASSERT(!dctx->getFileEntries()[0]->isRequested()); CPPUNIT_ASSERT(dctx->getFileEntries()[1]->isRequested()); - dctx->setFileFilter(util::parseIntRange("3")); + sgl.clear(); + sgl.add(3, 4); + dctx->setFileFilter(sgl); CPPUNIT_ASSERT(!dctx->getFileEntries()[0]->isRequested()); CPPUNIT_ASSERT(!dctx->getFileEntries()[1]->isRequested()); - dctx->setFileFilter(util::parseIntRange("1,2")); + sgl.clear(); + util::parseIntSegments(sgl, "1,2"); + dctx->setFileFilter(sgl); CPPUNIT_ASSERT(dctx->getFileEntries()[0]->isRequested()); CPPUNIT_ASSERT(dctx->getFileEntries()[1]->isRequested()); } diff --git a/test/DownloadContextTest.cc b/test/DownloadContextTest.cc index 384264a3..92bc74e9 100644 --- a/test/DownloadContextTest.cc +++ b/test/DownloadContextTest.cc @@ -14,12 +14,14 @@ class DownloadContextTest:public CppUnit::TestFixture { CPPUNIT_TEST(testGetPieceHash); CPPUNIT_TEST(testGetNumPieces); CPPUNIT_TEST(testGetBasePath); + CPPUNIT_TEST(testSetFileFilter); CPPUNIT_TEST_SUITE_END(); public: void testFindFileEntryByOffset(); void testGetPieceHash(); void testGetNumPieces(); void testGetBasePath(); + void testSetFileFilter(); }; @@ -73,4 +75,28 @@ void DownloadContextTest::testGetBasePath() CPPUNIT_ASSERT_EQUAL(std::string("aria2.tar.bz2"), ctx.getBasePath()); } +void DownloadContextTest::testSetFileFilter() +{ + DownloadContext ctx; + std::vector > files; + for(int i = 0; i < 10; ++i) { + files.push_back(SharedHandle(new FileEntry("file", 1, i))); + } + ctx.setFileEntries(files.begin(), files.end()); + SegList sgl; + util::parseIntSegments(sgl, "2-4,6-8"); + ctx.setFileFilter(sgl); + const std::vector >& res = ctx.getFileEntries(); + CPPUNIT_ASSERT(!res[0]->isRequested()); + CPPUNIT_ASSERT(res[1]->isRequested()); + CPPUNIT_ASSERT(res[2]->isRequested()); + CPPUNIT_ASSERT(res[3]->isRequested()); + CPPUNIT_ASSERT(!res[4]->isRequested()); + CPPUNIT_ASSERT(res[5]->isRequested()); + CPPUNIT_ASSERT(res[6]->isRequested()); + CPPUNIT_ASSERT(res[7]->isRequested()); + CPPUNIT_ASSERT(!res[8]->isRequested()); + CPPUNIT_ASSERT(!res[9]->isRequested()); +} + } // namespace aria2 diff --git a/test/SegListTest.cc b/test/SegListTest.cc index d5d3819b..aff3ee9e 100644 --- a/test/SegListTest.cc +++ b/test/SegListTest.cc @@ -8,10 +8,14 @@ class SegListTest:public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(SegListTest); CPPUNIT_TEST(testNext); + CPPUNIT_TEST(testPeek); + CPPUNIT_TEST(testClear); CPPUNIT_TEST(testNormalize); CPPUNIT_TEST_SUITE_END(); public: void testNext(); + void testPeek(); + void testClear(); void testNormalize(); }; @@ -40,6 +44,32 @@ void SegListTest::testNext() CPPUNIT_ASSERT_EQUAL(0, sgl.next()); } +void SegListTest::testPeek() +{ + SegList sgl; + sgl.add(1, 3); + sgl.add(4, 5); + CPPUNIT_ASSERT_EQUAL(1, sgl.peek()); + CPPUNIT_ASSERT_EQUAL(1, sgl.peek()); + CPPUNIT_ASSERT_EQUAL(1, sgl.next()); + CPPUNIT_ASSERT_EQUAL(2, sgl.peek()); + CPPUNIT_ASSERT_EQUAL(2, sgl.next()); + CPPUNIT_ASSERT_EQUAL(4, sgl.peek()); + CPPUNIT_ASSERT_EQUAL(4, sgl.next()); + CPPUNIT_ASSERT(!sgl.hasNext()); +} + +void SegListTest::testClear() +{ + SegList sgl; + sgl.add(1, 3); + CPPUNIT_ASSERT_EQUAL(1, sgl.next()); + sgl.clear(); + CPPUNIT_ASSERT(!sgl.hasNext()); + sgl.add(2, 3); + CPPUNIT_ASSERT_EQUAL(2, sgl.next()); +} + void SegListTest::testNormalize() { SegList sgl; diff --git a/test/UtilTest.cc b/test/UtilTest.cc index 141435c8..8a4fec75 100644 --- a/test/UtilTest.cc +++ b/test/UtilTest.cc @@ -47,6 +47,8 @@ class UtilTest:public CppUnit::TestFixture { CPPUNIT_TEST(testConvertBitfield); CPPUNIT_TEST(testParseIntRange); CPPUNIT_TEST(testParseIntRange_invalidRange); + CPPUNIT_TEST(testParseIntSegments); + CPPUNIT_TEST(testParseIntSegments_invalidRange); CPPUNIT_TEST(testParseInt); CPPUNIT_TEST(testParseUInt); CPPUNIT_TEST(testParseLLInt); @@ -107,6 +109,8 @@ public: void testConvertBitfield(); void testParseIntRange(); void testParseIntRange_invalidRange(); + void testParseIntSegments(); + void testParseIntSegments_invalidRange(); void testParseInt(); void testParseUInt(); void testParseLLInt(); @@ -742,6 +746,77 @@ void UtilTest::testParseIntRange_invalidRange() } } +void UtilTest::testParseIntSegments() +{ + SegList sgl; + util::parseIntSegments(sgl, "1,3-8,10"); + + CPPUNIT_ASSERT(sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(1, sgl.next()); + CPPUNIT_ASSERT(sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(3, sgl.next()); + CPPUNIT_ASSERT(sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(4, sgl.next()); + CPPUNIT_ASSERT(sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(5, sgl.next()); + CPPUNIT_ASSERT(sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(6, sgl.next()); + CPPUNIT_ASSERT(sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(7, sgl.next()); + CPPUNIT_ASSERT(sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(8, sgl.next()); + CPPUNIT_ASSERT(sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(10, sgl.next()); + CPPUNIT_ASSERT(!sgl.hasNext()); + CPPUNIT_ASSERT_EQUAL(0, sgl.next()); + + sgl.clear(); + util::parseIntSegments(sgl, ",,,1,,,3,,,"); + CPPUNIT_ASSERT_EQUAL(1, sgl.next()); + CPPUNIT_ASSERT_EQUAL(3, sgl.next()); + CPPUNIT_ASSERT(!sgl.hasNext()); +} + +void UtilTest::testParseIntSegments_invalidRange() +{ + try { + SegList sgl; + util::parseIntSegments(sgl, "-1"); + CPPUNIT_FAIL("exception must be thrown."); + } catch(Exception& e) { + } + try { + SegList sgl; + util::parseIntSegments(sgl, "1-"); + CPPUNIT_FAIL("exception must be thrown."); + } catch(Exception& e) { + } + try { + SegList sgl; + util::parseIntSegments(sgl, "2147483648"); + CPPUNIT_FAIL("exception must be thrown."); + } catch(Exception& e) { + } + try { + SegList sgl; + util::parseIntSegments(sgl, "2147483647-2147483648"); + CPPUNIT_FAIL("exception must be thrown."); + } catch(Exception& e) { + } + try { + SegList sgl; + util::parseIntSegments(sgl, "1-2x"); + CPPUNIT_FAIL("exception must be thrown."); + } catch(Exception& e) { + } + try { + SegList sgl; + util::parseIntSegments(sgl, "3x-4"); + CPPUNIT_FAIL("exception must be thrown."); + } catch(Exception& e) { + } +} + void UtilTest::testParseInt() { CPPUNIT_ASSERT_EQUAL(-1, util::parseInt(" -1 "));