Rewritten DHTRoutingTableDeserializer using stdio instead of stream.

pull/1/head
Tatsuhiro Tsujikawa 2011-08-05 20:17:19 +09:00
parent 5eb338ad87
commit f141cd4228
4 changed files with 49 additions and 65 deletions

View File

@ -36,7 +36,7 @@
#include <cstring>
#include <cassert>
#include <istream>
#include <cstdio>
#include <utility>
#include "DHTNode.h"
@ -56,27 +56,27 @@ DHTRoutingTableDeserializer::DHTRoutingTableDeserializer(int family):
DHTRoutingTableDeserializer::~DHTRoutingTableDeserializer() {}
namespace {
void readBytes(unsigned char* buf, size_t buflen,
std::istream& in, size_t readlen)
{
assert(readlen <= buflen);
in.read(reinterpret_cast<char*>(buf), readlen);
}
} // namespace
#define CHECK_STREAM(in, length) \
if(in.gcount() != length) { \
throw DL_ABORT_EX \
(fmt("Failed to load DHT routing table. cause:%s", \
"Unexpected EOF")); \
} \
if(!in) { \
#define FREAD_CHECK(ptr, count, fp) \
if(fread((ptr), 1, (count), (fp)) != (count)) { \
throw DL_ABORT_EX("Failed to load DHT routing table."); \
}
void DHTRoutingTableDeserializer::deserialize(std::istream& in)
namespace {
void readBytes(unsigned char* buf, size_t buflen,
FILE* fp, size_t readlen)
{
assert(readlen <= buflen);
FREAD_CHECK(buf, readlen, fp);
}
} // namespace
void DHTRoutingTableDeserializer::deserialize(const std::string& filename)
{
FILE* fp = a2fopen(utf8ToWChar(filename).c_str(), "rb");
if(!fp) {
throw DL_ABORT_EX("Failed to load DHT routing table.");
}
auto_delete_r<FILE*, int> deleter(fp, fclose);
char header[8];
memset(header, 0, sizeof(header));
// magic
@ -109,8 +109,7 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
array_wrapper<unsigned char, 255> buf;
// header
readBytes(buf, buf.size(), in, 8);
CHECK_STREAM(in, 8);
readBytes(buf, buf.size(), fp, 8);
if(memcmp(header, buf, 8) == 0) {
version = 3;
} else if(memcmp(headerCompat, buf, 8) == 0) {
@ -125,37 +124,29 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
uint64_t temp64;
// time
if(version == 2) {
in.read(reinterpret_cast<char*>(&temp32), sizeof(temp32));
CHECK_STREAM(in, sizeof(temp32));
FREAD_CHECK(&temp32, sizeof(temp32), fp);
serializedTime_.setTimeInSec(ntohl(temp32));
// 4bytes reserved
readBytes(buf, buf.size(), in, 4);
CHECK_STREAM(in, 4);
readBytes(buf, buf.size(), fp, 4);
} else {
in.read(reinterpret_cast<char*>(&temp64), sizeof(temp64));
CHECK_STREAM(in, sizeof(temp64));
FREAD_CHECK(&temp64, sizeof(temp64), fp);
serializedTime_.setTimeInSec(ntoh64(temp64));
}
// localnode
// 8bytes reserved
readBytes(buf, buf.size(), in, 8);
CHECK_STREAM(in, 8);
readBytes(buf, buf.size(), fp, 8);
// localnode ID
readBytes(buf, buf.size(), in, DHT_ID_LENGTH);
CHECK_STREAM(in, DHT_ID_LENGTH);
readBytes(buf, buf.size(), fp, DHT_ID_LENGTH);
SharedHandle<DHTNode> localNode(new DHTNode(buf));
// 4bytes reserved
readBytes(buf, buf.size(), in, 4);
CHECK_STREAM(in, 4);
readBytes(buf, buf.size(), fp, 4);
// number of nodes
in.read(reinterpret_cast<char*>(&temp32), sizeof(temp32));
CHECK_STREAM(in, sizeof(temp32));
FREAD_CHECK(&temp32, sizeof(temp32), fp);
uint32_t numNodes = ntohl(temp32);
// 4bytes reserved
readBytes(buf, buf.size(), in, 4);
CHECK_STREAM(in, 4);
readBytes(buf, buf.size(), fp, 4);
std::vector<SharedHandle<DHTNode> > nodes;
// nodes
@ -163,45 +154,38 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
for(size_t i = 0; i < numNodes; ++i) {
// 1byte compact peer info length
uint8_t peerInfoLen;
in >> peerInfoLen;
FREAD_CHECK(&peerInfoLen, sizeof(peerInfoLen), fp);
if(peerInfoLen != compactlen) {
// skip this entry
readBytes(buf, buf.size(), in, 7+48);
CHECK_STREAM(in, 7+48);
readBytes(buf, buf.size(), fp, 7+48);
continue;
}
// 7bytes reserved
readBytes(buf, buf.size(), in, 7);
CHECK_STREAM(in, 7);
readBytes(buf, buf.size(), fp, 7);
// compactlen bytes compact peer info
readBytes(buf, buf.size(), in, compactlen);
CHECK_STREAM(in, compactlen);
readBytes(buf, buf.size(), fp, compactlen);
if(memcmp(zero, buf, compactlen) == 0) {
// skip this entry
readBytes(buf, buf.size(), in, 48-compactlen);
CHECK_STREAM(in, 48-compactlen);
readBytes(buf, buf.size(), fp, 48-compactlen);
continue;
}
std::pair<std::string, uint16_t> peer =
bittorrent::unpackcompact(buf, family_);
if(peer.first.empty()) {
// skip this entry
readBytes(buf, buf.size(), in, 48-compactlen);
CHECK_STREAM(in, 48-compactlen);
readBytes(buf, buf.size(), fp, 48-compactlen);
continue;
}
// 24-compactlen bytes reserved
readBytes(buf, buf.size(), in, 24-compactlen);
readBytes(buf, buf.size(), fp, 24-compactlen);
// node ID
readBytes(buf, buf.size(), in, DHT_ID_LENGTH);
CHECK_STREAM(in, DHT_ID_LENGTH);
readBytes(buf, buf.size(), fp, DHT_ID_LENGTH);
SharedHandle<DHTNode> node(new DHTNode(buf));
node->setIPAddress(peer.first);
node->setPort(peer.second);
// 4bytes reserved
readBytes(buf, buf.size(), in, 4);
CHECK_STREAM(in, 4);
readBytes(buf, buf.size(), fp, 4);
nodes.push_back(node);
}

View File

@ -38,7 +38,7 @@
#include "common.h"
#include <vector>
#include <iosfwd>
#include <string>
#include "SharedHandle.h"
#include "TimeA2.h"
@ -76,7 +76,7 @@ public:
return serializedTime_;
}
void deserialize(std::istream& in);
void deserialize(const std::string& filename);
};
} // namespace aria2

View File

@ -99,11 +99,7 @@ void DHTSetup::setup
e->getOption()->get(family == AF_INET?PREF_DHT_FILE_PATH:
PREF_DHT_FILE_PATH6);
try {
std::ifstream in(dhtFile.c_str(), std::ios::binary);
if(!in) {
throw DL_ABORT_EX("Could not open file");
}
deserializer.deserialize(in);
deserializer.deserialize(dhtFile);
localNode = deserializer.getLocalNode();
} catch(RecoverableException& e) {
A2_LOG_ERROR_EX

View File

@ -54,11 +54,13 @@ void DHTRoutingTableDeserializerTest::testDeserialize()
s.setLocalNode(localNode);
s.setNodes(nodes);
std::stringstream ss;
s.serialize(ss);
std::string filename = A2_TEST_OUT_DIR"/aria2_DHTRoutingTableDeserializerTest_testDeserialize";
std::ofstream outfile(filename.c_str(), std::ios::binary);
s.serialize(outfile);
outfile.close();
DHTRoutingTableDeserializer d(AF_INET);
d.deserialize(ss);
d.deserialize(filename);
CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(),
DHT_ID_LENGTH) == 0);
@ -93,11 +95,13 @@ void DHTRoutingTableDeserializerTest::testDeserialize6()
s.setLocalNode(localNode);
s.setNodes(nodes);
std::stringstream ss;
s.serialize(ss);
std::string filename = A2_TEST_OUT_DIR"/aria2_DHTRoutingTableDeserializerTest_testDeserialize6";
std::ofstream outfile(filename.c_str(), std::ios::binary);
s.serialize(outfile);
outfile.close();
DHTRoutingTableDeserializer d(AF_INET6);
d.deserialize(ss);
d.deserialize(filename);
CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(),
DHT_ID_LENGTH) == 0);