diff --git a/src/DHTRoutingTableDeserializer.cc b/src/DHTRoutingTableDeserializer.cc index 640e5cb1..30c0c3c3 100644 --- a/src/DHTRoutingTableDeserializer.cc +++ b/src/DHTRoutingTableDeserializer.cc @@ -36,7 +36,7 @@ #include #include -#include +#include #include #include "DHTNode.h" @@ -56,27 +56,27 @@ DHTRoutingTableDeserializer::DHTRoutingTableDeserializer(int family): DHTRoutingTableDeserializer::~DHTRoutingTableDeserializer() {} -namespace { -void readBytes(unsigned char* buf, size_t buflen, - std::istream& in, size_t readlen) -{ - assert(readlen <= buflen); - in.read(reinterpret_cast(buf), readlen); -} -} // namespace - -#define CHECK_STREAM(in, length) \ - if(in.gcount() != length) { \ - throw DL_ABORT_EX \ - (fmt("Failed to load DHT routing table. cause:%s", \ - "Unexpected EOF")); \ - } \ - if(!in) { \ +#define FREAD_CHECK(ptr, count, fp) \ + if(fread((ptr), 1, (count), (fp)) != (count)) { \ throw DL_ABORT_EX("Failed to load DHT routing table."); \ } -void DHTRoutingTableDeserializer::deserialize(std::istream& in) +namespace { +void readBytes(unsigned char* buf, size_t buflen, + FILE* fp, size_t readlen) { + assert(readlen <= buflen); + FREAD_CHECK(buf, readlen, fp); +} +} // namespace + +void DHTRoutingTableDeserializer::deserialize(const std::string& filename) +{ + FILE* fp = a2fopen(utf8ToWChar(filename).c_str(), "rb"); + if(!fp) { + throw DL_ABORT_EX("Failed to load DHT routing table."); + } + auto_delete_r deleter(fp, fclose); char header[8]; memset(header, 0, sizeof(header)); // magic @@ -109,8 +109,7 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in) array_wrapper buf; // header - readBytes(buf, buf.size(), in, 8); - CHECK_STREAM(in, 8); + readBytes(buf, buf.size(), fp, 8); if(memcmp(header, buf, 8) == 0) { version = 3; } else if(memcmp(headerCompat, buf, 8) == 0) { @@ -125,37 +124,29 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in) uint64_t temp64; // time if(version == 2) { - in.read(reinterpret_cast(&temp32), sizeof(temp32)); - CHECK_STREAM(in, sizeof(temp32)); + FREAD_CHECK(&temp32, sizeof(temp32), fp); serializedTime_.setTimeInSec(ntohl(temp32)); // 4bytes reserved - readBytes(buf, buf.size(), in, 4); - CHECK_STREAM(in, 4); + readBytes(buf, buf.size(), fp, 4); } else { - in.read(reinterpret_cast(&temp64), sizeof(temp64)); - CHECK_STREAM(in, sizeof(temp64)); + FREAD_CHECK(&temp64, sizeof(temp64), fp); serializedTime_.setTimeInSec(ntoh64(temp64)); } // localnode // 8bytes reserved - readBytes(buf, buf.size(), in, 8); - CHECK_STREAM(in, 8); + readBytes(buf, buf.size(), fp, 8); // localnode ID - readBytes(buf, buf.size(), in, DHT_ID_LENGTH); - CHECK_STREAM(in, DHT_ID_LENGTH); + readBytes(buf, buf.size(), fp, DHT_ID_LENGTH); SharedHandle localNode(new DHTNode(buf)); // 4bytes reserved - readBytes(buf, buf.size(), in, 4); - CHECK_STREAM(in, 4); + readBytes(buf, buf.size(), fp, 4); // number of nodes - in.read(reinterpret_cast(&temp32), sizeof(temp32)); - CHECK_STREAM(in, sizeof(temp32)); + FREAD_CHECK(&temp32, sizeof(temp32), fp); uint32_t numNodes = ntohl(temp32); // 4bytes reserved - readBytes(buf, buf.size(), in, 4); - CHECK_STREAM(in, 4); + readBytes(buf, buf.size(), fp, 4); std::vector > nodes; // nodes @@ -163,45 +154,38 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in) for(size_t i = 0; i < numNodes; ++i) { // 1byte compact peer info length uint8_t peerInfoLen; - in >> peerInfoLen; + FREAD_CHECK(&peerInfoLen, sizeof(peerInfoLen), fp); if(peerInfoLen != compactlen) { // skip this entry - readBytes(buf, buf.size(), in, 7+48); - CHECK_STREAM(in, 7+48); + readBytes(buf, buf.size(), fp, 7+48); continue; } // 7bytes reserved - readBytes(buf, buf.size(), in, 7); - CHECK_STREAM(in, 7); + readBytes(buf, buf.size(), fp, 7); // compactlen bytes compact peer info - readBytes(buf, buf.size(), in, compactlen); - CHECK_STREAM(in, compactlen); + readBytes(buf, buf.size(), fp, compactlen); if(memcmp(zero, buf, compactlen) == 0) { // skip this entry - readBytes(buf, buf.size(), in, 48-compactlen); - CHECK_STREAM(in, 48-compactlen); + readBytes(buf, buf.size(), fp, 48-compactlen); continue; } std::pair peer = bittorrent::unpackcompact(buf, family_); if(peer.first.empty()) { // skip this entry - readBytes(buf, buf.size(), in, 48-compactlen); - CHECK_STREAM(in, 48-compactlen); + readBytes(buf, buf.size(), fp, 48-compactlen); continue; } // 24-compactlen bytes reserved - readBytes(buf, buf.size(), in, 24-compactlen); + readBytes(buf, buf.size(), fp, 24-compactlen); // node ID - readBytes(buf, buf.size(), in, DHT_ID_LENGTH); - CHECK_STREAM(in, DHT_ID_LENGTH); + readBytes(buf, buf.size(), fp, DHT_ID_LENGTH); SharedHandle node(new DHTNode(buf)); node->setIPAddress(peer.first); node->setPort(peer.second); // 4bytes reserved - readBytes(buf, buf.size(), in, 4); - CHECK_STREAM(in, 4); + readBytes(buf, buf.size(), fp, 4); nodes.push_back(node); } diff --git a/src/DHTRoutingTableDeserializer.h b/src/DHTRoutingTableDeserializer.h index a64b573e..a94acb94 100644 --- a/src/DHTRoutingTableDeserializer.h +++ b/src/DHTRoutingTableDeserializer.h @@ -38,7 +38,7 @@ #include "common.h" #include -#include +#include #include "SharedHandle.h" #include "TimeA2.h" @@ -76,7 +76,7 @@ public: return serializedTime_; } - void deserialize(std::istream& in); + void deserialize(const std::string& filename); }; } // namespace aria2 diff --git a/src/DHTSetup.cc b/src/DHTSetup.cc index 3d4ec816..01102095 100644 --- a/src/DHTSetup.cc +++ b/src/DHTSetup.cc @@ -99,11 +99,7 @@ void DHTSetup::setup e->getOption()->get(family == AF_INET?PREF_DHT_FILE_PATH: PREF_DHT_FILE_PATH6); try { - std::ifstream in(dhtFile.c_str(), std::ios::binary); - if(!in) { - throw DL_ABORT_EX("Could not open file"); - } - deserializer.deserialize(in); + deserializer.deserialize(dhtFile); localNode = deserializer.getLocalNode(); } catch(RecoverableException& e) { A2_LOG_ERROR_EX diff --git a/test/DHTRoutingTableDeserializerTest.cc b/test/DHTRoutingTableDeserializerTest.cc index 6cf74fdf..9fa5073a 100644 --- a/test/DHTRoutingTableDeserializerTest.cc +++ b/test/DHTRoutingTableDeserializerTest.cc @@ -54,11 +54,13 @@ void DHTRoutingTableDeserializerTest::testDeserialize() s.setLocalNode(localNode); s.setNodes(nodes); - std::stringstream ss; - s.serialize(ss); + std::string filename = A2_TEST_OUT_DIR"/aria2_DHTRoutingTableDeserializerTest_testDeserialize"; + std::ofstream outfile(filename.c_str(), std::ios::binary); + s.serialize(outfile); + outfile.close(); DHTRoutingTableDeserializer d(AF_INET); - d.deserialize(ss); + d.deserialize(filename); CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(), DHT_ID_LENGTH) == 0); @@ -93,11 +95,13 @@ void DHTRoutingTableDeserializerTest::testDeserialize6() s.setLocalNode(localNode); s.setNodes(nodes); - std::stringstream ss; - s.serialize(ss); + std::string filename = A2_TEST_OUT_DIR"/aria2_DHTRoutingTableDeserializerTest_testDeserialize6"; + std::ofstream outfile(filename.c_str(), std::ios::binary); + s.serialize(outfile); + outfile.close(); DHTRoutingTableDeserializer d(AF_INET6); - d.deserialize(ss); + d.deserialize(filename); CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(), DHT_ID_LENGTH) == 0);