/* */ #include "PeerMessageUtil.h" #include "DlAbortEx.h" #include "Util.h" #include PeerMessage* PeerMessageUtil::createPeerMessage(const char* msg, int len) { PeerMessage* peerMessage; if(len == 0) { // keep-alive peerMessage = new PeerMessage(); peerMessage->setId(PeerMessage::KEEP_ALIVE); return peerMessage; } int id = getId(msg); switch(id) { case PeerMessage::CHOKE: case PeerMessage::UNCHOKE: case PeerMessage::INTERESTED: case PeerMessage::NOT_INTERESTED: peerMessage = createBasicMessage(id, msg, len); break; case PeerMessage::HAVE: peerMessage = createHaveMessage(id, msg, len); break; case PeerMessage::BITFIELD: peerMessage = createBitfieldMessage(id, msg, len); break; case PeerMessage::REQUEST: case PeerMessage::CANCEL: peerMessage = createRequestCancelMessage(id, msg, len); break; case PeerMessage::PIECE: peerMessage = createPieceMessage(id, msg, len); break; default: throw new DlAbortEx("invalid message id. id = %d", id); } return peerMessage; } PeerMessage* PeerMessageUtil::createBasicMessage(int id, const char* msg, int len) { if(len != 1) { throw new DlAbortEx("invalid payload size for ID%d, size = %d. It should be %d", id, len, 1); } PeerMessage* peerMessage = new PeerMessage(); peerMessage->setId(id); return peerMessage; } PeerMessage* PeerMessageUtil::createHaveMessage(int id, const char* msg, int len) { if(len != 5) { throw new DlAbortEx("invalid payload size for ID%d, size = %d. It should be %d", id, len, 5); } PeerMessage* peerMessage = new PeerMessage(); peerMessage->setId(id); peerMessage->setIndex(getIntParam(msg, 1)); return peerMessage; } PeerMessage* PeerMessageUtil::createBitfieldMessage(int id, const char* msg, int len) { if(len <= 1) { throw new DlAbortEx("invalid payload size for ID%d, size = %d. It should be greater than %d", id, len, 1); } PeerMessage* peerMessage = new PeerMessage(); peerMessage->setId(id); peerMessage->setBitfield((unsigned char*)msg+1, len-1); return peerMessage; } PeerMessage* PeerMessageUtil::createRequestCancelMessage(int id, const char* msg, int len) { if(len != 13) { throw new DlAbortEx("invalid payload size for ID%d, size = %d. It should be %d", id, len, 13); } PeerMessage* peerMessage = new PeerMessage(); peerMessage->setId(id); peerMessage->setIndex(getIntParam(msg, 1)); peerMessage->setBegin(getIntParam(msg, 5)); peerMessage->setLength(getIntParam(msg, 9)); return peerMessage; } PeerMessage* PeerMessageUtil::createPieceMessage(int id, const char* msg, int len) { if(len <= 9) { throw new DlAbortEx("invalid payload size for ID%d, size = %d. It should be greater than %d", id, len, 9); } PeerMessage* peerMessage = new PeerMessage(); peerMessage->setId(id); peerMessage->setIndex(getIntParam(msg, 1)); peerMessage->setBegin(getIntParam(msg, 5)); peerMessage->setBlock(msg+9, len-9); return peerMessage; } int PeerMessageUtil::getId(const char* msg) { return (int)msg[0]; } int PeerMessageUtil::getIntParam(const char* msg, int offset) { int nParam; memcpy(&nParam, msg+offset, 4); return ntohl(nParam); } void PeerMessageUtil::checkIndex(const PeerMessage* message, int pieces) { if(!(0 <= message->getIndex() && message->getIndex() < pieces)) { throw new DlAbortEx("invalid index = %d", message->getIndex()); } } void PeerMessageUtil::checkBegin(const PeerMessage* message, int pieceLength) { if(!(0 <= message->getBegin() && message->getBegin() < pieceLength)) { throw new DlAbortEx("invalid begin = %d", message->getBegin()); } } void PeerMessageUtil::checkPieceOffset(const PeerMessage* message, int pieceLength, int pieces, long long int totalSize) { if(!(0 <= message->getBegin() && 0 < message->getLength())) { throw new DlAbortEx("invalid offset, begin = %d, length = %d", message->getBegin(), message->getLength()); } int offset = message->getBegin()+message->getLength(); int currentPieceLength; if(message->getIndex()+1 == pieces) { currentPieceLength = pieceLength-(pieces*pieceLength-totalSize); } else { currentPieceLength = pieceLength; } if(!(0 < offset && offset <= currentPieceLength)) { throw new DlAbortEx("invalid offset, begin = %d, length = %d", message->getBegin(), message->getLength()); } } void PeerMessageUtil::checkLength(const PeerMessage* message) { if(message->getLength() > 128*1024) { throw new DlAbortEx("too large length %d > 128KB", message->getLength()); } } void PeerMessageUtil::checkBitfield(const PeerMessage* message, int pieces) { if(!(message->getBitfieldLength() == BITFIELD_LEN_FROM_PIECES(pieces))) { throw new DlAbortEx("invalid bitfield length = %d", message->getBitfieldLength()); } char lastbyte = message->getBitfield()[message->getBitfieldLength()-1]; for(int i = 0; i < 8-pieces%8 && pieces%8 != 0; i++) { if(!(((lastbyte >> i) & 1) == 0)) { throw new DlAbortEx("invalid bitfield"); } } } void PeerMessageUtil::checkIntegrity(const PeerMessage* message, int pieceLength, int pieces, long long int totalSize) { // 0 <= index < pieces // 0 <= begin < pieceLength // 0 < begin+length <= pieceLength // len of bitfield == pieces/8+(pieces%8 ? 1 : 0) // for(int i = 0; i < 8-pieces%8; i++) { ((lastbyteofbitfield >> i) & 1) == 0 } switch(message->getId()) { case PeerMessage::KEEP_ALIVE: case PeerMessage::CHOKE: case PeerMessage::UNCHOKE: case PeerMessage::INTERESTED: case PeerMessage::NOT_INTERESTED: break; case PeerMessage::HAVE: checkIndex(message, pieces); break; case PeerMessage::BITFIELD: checkBitfield(message, pieces); break; case PeerMessage::REQUEST: case PeerMessage::CANCEL: checkIndex(message, pieces); checkBegin(message, pieceLength); checkLength(message); checkPieceOffset(message, pieceLength, pieces, totalSize); break; case PeerMessage::PIECE: checkIndex(message, pieces); checkBegin(message, pieceLength); break; default: throw new DlAbortEx("invalid message id. id = %d", message->getId()); } } HandshakeMessage* PeerMessageUtil::createHandshakeMessage(const char* msg) { HandshakeMessage* message = new HandshakeMessage(); message->pstrlen = msg[0]; char pstr[20]; memcpy(pstr, &msg[1], sizeof(pstr)-1); pstr[sizeof(pstr)-1] = '\0'; message->pstr = pstr; memcpy(message->infoHash, &msg[28], 20); memcpy(message->peerId, &msg[48], 20); return message; } void PeerMessageUtil::checkHandshake(const HandshakeMessage* message, const unsigned char* infoHash) { if(message->pstrlen != 19) { throw new DlAbortEx("invalid handshake pstrlen = %d", (int)message->pstrlen); } if(message->pstr != PSTR) { throw new DlAbortEx("invalid handshake pstr"); } string myInfoHash = Util::toHex(infoHash, 20); string peerInfoHash = Util::toHex(message->infoHash, 20); if(myInfoHash != peerInfoHash) { throw new DlAbortEx("invalid handshake info hash: expected:%s, actual:%s", myInfoHash.c_str(), peerInfoHash.c_str()); } }