/* */
#include "DHTBucketTree.h"

#include <cstring>
#include <algorithm>
#include <memory>
#include <vector>

#include "DHTBucket.h"
#include "DHTNode.h"
#include "a2functional.h"

namespace aria2 {

// Builds an interior node that owns the two given subtrees. bucket_ is left
// empty; the node's ID range and the children's parent links are set up by
// resetRelation(). Both subtrees must be non-null.
DHTBucketTreeNode::DHTBucketTreeNode
(std::unique_ptr<DHTBucketTreeNode> left,
 std::unique_ptr<DHTBucketTreeNode> right)
  : parent_(nullptr),
    left_(std::move(left)),
    right_(std::move(right))
{
  resetRelation();
}

// Builds a leaf node wrapping bucket. The node's ID range is copied from the
// bucket's getMinID()/getMaxID(). bucket must be non-null.
DHTBucketTreeNode::DHTBucketTreeNode(std::shared_ptr<DHTBucket> bucket)
  : parent_(nullptr),
    bucket_(std::move(bucket))
{
  std::memcpy(minId_, bucket_->getMinID(), DHT_ID_LENGTH);
  std::memcpy(maxId_, bucket_->getMaxID(), DHT_ID_LENGTH);
}

DHTBucketTreeNode::~DHTBucketTreeNode() = default;

// Points both children's parent link at this node and recomputes this node's
// ID range as [left_->getMinId(), right_->getMaxId()].
void DHTBucketTreeNode::resetRelation()
{
  left_->setParent(this);
  right_->setParent(this);
  std::memcpy(minId_, left_->getMinId(), DHT_ID_LENGTH);
  std::memcpy(maxId_, right_->getMaxId(), DHT_ID_LENGTH);
}

// Returns the child to descend into for key: left_ if key lies within the
// left child's range, otherwise right_. Returns nullptr on a leaf.
DHTBucketTreeNode* DHTBucketTreeNode::dig(const unsigned char* key)
{
  if(leaf()) {
    return nullptr;
  }
  return left_->isInRange(key) ? left_.get() : right_.get();
}

// True when minId_ <= key <= maxId_, comparing the DHT_ID_LENGTH bytes
// lexicographically (i.e. as big-endian unsigned numbers).
bool DHTBucketTreeNode::isInRange(const unsigned char* key) const
{
  return
    !std::lexicographical_compare(&key[0], &key[DHT_ID_LENGTH],
                                  &minId_[0], &minId_[DHT_ID_LENGTH]) &&
    !std::lexicographical_compare(&maxId_[0], &maxId_[DHT_ID_LENGTH],
                                  &key[0], &key[DHT_ID_LENGTH]);
}

// Turns this leaf into an interior node. The bucket returned by
// bucket_->split() becomes the left child, and the bucket split() was called
// on becomes the right child. NOTE(review): presumably split() returns the
// lower half and shrinks bucket_ to the upper half; confirm in DHTBucket.
// Afterwards bucket_ is empty and the ID range is rebuilt from the children.
void DHTBucketTreeNode::split()
{
  left_ = make_unique<DHTBucketTreeNode>(bucket_->split());
  right_ = make_unique<DHTBucketTreeNode>(bucket_);
  bucket_.reset();
  resetRelation();
}

namespace dht {

// Descends from root, following dig(), until it reaches the leaf whose range
// should contain key, and returns that leaf.
DHTBucketTreeNode* findTreeNodeFor
(DHTBucketTreeNode* root, const unsigned char* key)
{
  DHTBucketTreeNode* node = root;
  while(!node->leaf()) {
    node = node->dig(key);
  }
  return node;
}

// Returns the bucket held by the leaf that findTreeNodeFor() selects for key.
std::shared_ptr<DHTBucket> findBucketFor
(DHTBucketTreeNode* root, const unsigned char* key)
{
  return findTreeNodeFor(root, key)->getBucket();
}

namespace {
// Appends the good nodes of bucket (as reported by getGoodNodes) to nodes.
// bucket must be non-null.
void collectNodes
(std::vector<std::shared_ptr<DHTNode> >& nodes,
 const std::shared_ptr<DHTBucket>& bucket)
{
  std::vector<std::shared_ptr<DHTNode> > goodNodes;
  bucket->getGoodNodes(goodNodes);
  nodes.insert(nodes.end(), goodNodes.begin(), goodNodes.end());
}
} // namespace

namespace {
// Collects good nodes from the subtree rooted at tnode, left child first.
// The right child is visited only while fewer than K nodes have been gathered.
void collectDownwardLeftFirst
(std::vector<std::shared_ptr<DHTNode> >& nodes, DHTBucketTreeNode* tnode)
{
  if(tnode->leaf()) {
    collectNodes(nodes, tnode->getBucket());
  } else {
    collectDownwardLeftFirst(nodes, tnode->getLeft());
    if(nodes.size() < DHTBucket::K) {
      collectDownwardLeftFirst(nodes, tnode->getRight());
    }
  }
}
} //namespace

namespace {
// Mirror of collectDownwardLeftFirst: right child first, and the left child
// only while fewer than K nodes have been gathered.
void collectDownwardRightFirst
(std::vector<std::shared_ptr<DHTNode> >& nodes, DHTBucketTreeNode* tnode)
{
  if(tnode->leaf()) {
    collectNodes(nodes, tnode->getBucket());
  } else {
    collectDownwardRightFirst(nodes, tnode->getRight());
    if(nodes.size() < DHTBucket::K) {
      collectDownwardRightFirst(nodes, tnode->getLeft());
    }
  }
}
} //namespace

namespace {
// Walks from `from` toward the root. At each ancestor it gathers nodes from the
// other child, stopping once nodes holds at least K entries or the root has
// been processed.
// The sibling subtree is walked starting from the side adjacent to `from`,
// matching the ordering used by findClosestKNodes. The sibling may be an
// interior node, whose bucket_ is empty because split() resets it. So it is
// descended into instead of passing its (possibly null) bucket to
// collectNodes. For a leaf sibling this collects exactly the same nodes as
// reading its bucket directly.
void collectUpward
(std::vector<std::shared_ptr<DHTNode> >& nodes, DHTBucketTreeNode* from)
{
  while(DHTBucketTreeNode* parent = from->getParent()) {
    if(parent->getLeft() == from) {
      collectDownwardLeftFirst(nodes, parent->getRight());
    } else {
      collectDownwardRightFirst(nodes, parent->getLeft());
    }
    from = parent;
    if(DHTBucket::K <= nodes.size()) {
      break;
    }
  }
}
} // namespace

// Appends good nodes near key to nodes, then truncates nodes to at most
// DHTBucket::K entries. If nodes already holds K or more entries on entry, it
// is left untouched.
// The search starts at the leaf selected for key. It collects that leaf's
// parent subtree beginning with the leaf's side, then widens toward the root
// via collectUpward while fewer than K nodes have been found.
void findClosestKNodes
(std::vector<std::shared_ptr<DHTNode> >& nodes,
 DHTBucketTreeNode* root,
 const unsigned char* key)
{
  if(DHTBucket::K <= nodes.size()) {
    return;
  }
  DHTBucketTreeNode* leaf = findTreeNodeFor(root, key);
  if(leaf == root) {
    collectNodes(nodes, leaf->getBucket());
  } else {
    DHTBucketTreeNode* parent = leaf->getParent();
    if(parent->getLeft() == leaf) {
      collectDownwardLeftFirst(nodes, parent);
    } else {
      collectDownwardRightFirst(nodes, parent);
    }
    if(nodes.size() < DHTBucket::K) {
      collectUpward(nodes, parent);
    }
  }
  if(DHTBucket::K < nodes.size()) {
    nodes.erase(nodes.begin()+DHTBucket::K, nodes.end());
  }
}

// Appends every leaf bucket under root to buckets, in left-to-right order.
void enumerateBucket
(std::vector<std::shared_ptr<DHTBucket> >& buckets, DHTBucketTreeNode* root)
{
  if(root->leaf()) {
    buckets.push_back(root->getBucket());
  } else {
    enumerateBucket(buckets, root->getLeft());
    enumerateBucket(buckets, root->getRight());
  }
}

} // namespace dht

} // namespace aria2