pull/432/head
Darien Raymond 2017-04-13 22:17:58 +02:00
parent 94c6acea43
commit f57260c358
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
8 changed files with 110 additions and 110 deletions

View File

@ -95,13 +95,13 @@ func NewHeaderWriter(header *buf.Buffer) *HeaderWriter {
} }
} }
func (v *HeaderWriter) Write(writer io.Writer) error { func (w *HeaderWriter) Write(writer io.Writer) error {
if v.header == nil { if w.header == nil {
return nil return nil
} }
_, err := writer.Write(v.header.Bytes()) _, err := writer.Write(w.header.Bytes())
v.header.Release() w.header.Release()
v.header = nil w.header = nil
return err return err
} }
@ -123,49 +123,49 @@ func NewHttpConn(conn net.Conn, reader Reader, writer Writer, errorWriter Writer
} }
} }
func (v *HttpConn) Read(b []byte) (int, error) { func (c *HttpConn) Read(b []byte) (int, error) {
if v.oneTimeReader != nil { if c.oneTimeReader != nil {
buffer, err := v.oneTimeReader.Read(v.Conn) buffer, err := c.oneTimeReader.Read(c.Conn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
v.readBuffer = buffer c.readBuffer = buffer
v.oneTimeReader = nil c.oneTimeReader = nil
} }
if v.readBuffer.Len() > 0 { if c.readBuffer.Len() > 0 {
nBytes, err := v.readBuffer.Read(b) nBytes, err := c.readBuffer.Read(b)
if nBytes == v.readBuffer.Len() { if nBytes == c.readBuffer.Len() {
v.readBuffer.Release() c.readBuffer.Release()
v.readBuffer = nil c.readBuffer = nil
} }
return nBytes, err return nBytes, err
} }
return v.Conn.Read(b) return c.Conn.Read(b)
} }
func (v *HttpConn) Write(b []byte) (int, error) { func (c *HttpConn) Write(b []byte) (int, error) {
if v.oneTimeWriter != nil { if c.oneTimeWriter != nil {
err := v.oneTimeWriter.Write(v.Conn) err := c.oneTimeWriter.Write(c.Conn)
v.oneTimeWriter = nil c.oneTimeWriter = nil
if err != nil { if err != nil {
return 0, err return 0, err
} }
} }
return v.Conn.Write(b) return c.Conn.Write(b)
} }
// Close implements net.Conn.Close(). // Close implements net.Conn.Close().
func (v *HttpConn) Close() error { func (c *HttpConn) Close() error {
if v.oneTimeWriter != nil && v.errorWriter != nil { if c.oneTimeWriter != nil && c.errorWriter != nil {
// Connection is being closed but header wasn't sent. This means the client request // Connection is being closed but header wasn't sent. This means the client request
// is probably not valid. Sending back a server error header in this case. // is probably not valid. Sending back a server error header in this case.
v.errorWriter.Write(v.Conn) c.errorWriter.Write(c.Conn)
} }
return v.Conn.Close() return c.Conn.Close()
} }
func formResponseHeader(config *ResponseConfig) *HeaderWriter { func formResponseHeader(config *ResponseConfig) *HeaderWriter {
@ -193,9 +193,9 @@ type HttpAuthenticator struct {
config *Config config *Config
} }
func (v HttpAuthenticator) GetClientWriter() *HeaderWriter { func (a HttpAuthenticator) GetClientWriter() *HeaderWriter {
header := buf.NewSmall() header := buf.NewSmall()
config := v.config.Request config := a.config.Request
header.AppendSupplier(serial.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickUri(), config.GetFullVersion()}, " "))) header.AppendSupplier(serial.WriteString(strings.Join([]string{config.GetMethodValue(), config.PickUri(), config.GetFullVersion()}, " ")))
header.AppendSupplier(writeCRLF) header.AppendSupplier(writeCRLF)
@ -210,31 +210,31 @@ func (v HttpAuthenticator) GetClientWriter() *HeaderWriter {
} }
} }
func (v HttpAuthenticator) GetServerWriter() *HeaderWriter { func (a HttpAuthenticator) GetServerWriter() *HeaderWriter {
return formResponseHeader(v.config.Response) return formResponseHeader(a.config.Response)
} }
func (v HttpAuthenticator) Client(conn net.Conn) net.Conn { func (a HttpAuthenticator) Client(conn net.Conn) net.Conn {
if v.config.Request == nil && v.config.Response == nil { if a.config.Request == nil && a.config.Response == nil {
return conn return conn
} }
var reader Reader = new(NoOpReader) var reader Reader = new(NoOpReader)
if v.config.Request != nil { if a.config.Request != nil {
reader = new(HeaderReader) reader = new(HeaderReader)
} }
var writer Writer = new(NoOpWriter) var writer Writer = new(NoOpWriter)
if v.config.Response != nil { if a.config.Response != nil {
writer = v.GetClientWriter() writer = a.GetClientWriter()
} }
return NewHttpConn(conn, reader, writer, new(NoOpWriter)) return NewHttpConn(conn, reader, writer, new(NoOpWriter))
} }
func (v HttpAuthenticator) Server(conn net.Conn) net.Conn { func (a HttpAuthenticator) Server(conn net.Conn) net.Conn {
if v.config.Request == nil && v.config.Response == nil { if a.config.Request == nil && a.config.Response == nil {
return conn return conn
} }
return NewHttpConn(conn, new(HeaderReader), v.GetServerWriter(), formResponseHeader(&ResponseConfig{ return NewHttpConn(conn, new(HeaderReader), a.GetServerWriter(), formResponseHeader(&ResponseConfig{
Version: &Version{ Version: &Version{
Value: "1.1", Value: "1.1",
}, },

View File

@ -9,10 +9,10 @@ import (
type NoOpHeader struct{} type NoOpHeader struct{}
func (v NoOpHeader) Size() int { func (NoOpHeader) Size() int {
return 0 return 0
} }
func (v NoOpHeader) Write([]byte) (int, error) { func (NoOpHeader) Write([]byte) (int, error) {
return 0, nil return 0, nil
} }

View File

@ -13,14 +13,14 @@ type SRTP struct {
number uint16 number uint16
} }
func (v *SRTP) Size() int { func (*SRTP) Size() int {
return 4 return 4
} }
func (v *SRTP) Write(b []byte) (int, error) { func (s *SRTP) Write(b []byte) (int, error) {
v.number++ s.number++
serial.Uint16ToBytes(v.number, b[:0]) serial.Uint16ToBytes(s.number, b[:0])
serial.Uint16ToBytes(v.number, b[:2]) serial.Uint16ToBytes(s.number, b[:2])
return 4, nil return 4, nil
} }

View File

@ -14,17 +14,18 @@ type UTP struct {
connectionId uint16 connectionId uint16
} }
func (v *UTP) Size() int { func (*UTP) Size() int {
return 4 return 4
} }
func (v *UTP) Write(b []byte) (int, error) { func (u *UTP) Write(b []byte) (int, error) {
serial.Uint16ToBytes(v.connectionId, b[:0]) serial.Uint16ToBytes(u.connectionId, b[:0])
b[2] = v.header b[2] = u.header
b[3] = v.extension b[3] = u.extension
return 4, nil return 4, nil
} }
// NewUTP creates a new UTP header for the given config.
func NewUTP(ctx context.Context, config interface{}) (interface{}, error) { func NewUTP(ctx context.Context, config interface{}) (interface{}, error) {
return &UTP{ return &UTP{
header: 1, header: 1,

View File

@ -16,17 +16,17 @@ func NewSimpleAuthenticator() cipher.AEAD {
} }
// NonceSize implements cipher.AEAD.NonceSize(). // NonceSize implements cipher.AEAD.NonceSize().
func (v *SimpleAuthenticator) NonceSize() int { func (*SimpleAuthenticator) NonceSize() int {
return 0 return 0
} }
// Overhead implements cipher.AEAD.NonceSize(). // Overhead implements cipher.AEAD.NonceSize().
func (v *SimpleAuthenticator) Overhead() int { func (*SimpleAuthenticator) Overhead() int {
return 6 return 6
} }
// Seal implements cipher.AEAD.Seal(). // Seal implements cipher.AEAD.Seal().
func (v *SimpleAuthenticator) Seal(dst, nonce, plain, extra []byte) []byte { func (a *SimpleAuthenticator) Seal(dst, nonce, plain, extra []byte) []byte {
dst = append(dst, 0, 0, 0, 0) dst = append(dst, 0, 0, 0, 0)
dst = serial.Uint16ToBytes(uint16(len(plain)), dst) dst = serial.Uint16ToBytes(uint16(len(plain)), dst)
dst = append(dst, plain...) dst = append(dst, plain...)
@ -48,7 +48,7 @@ func (v *SimpleAuthenticator) Seal(dst, nonce, plain, extra []byte) []byte {
} }
// Open implements cipher.AEAD.Open(). // Open implements cipher.AEAD.Open().
func (v *SimpleAuthenticator) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) { func (a *SimpleAuthenticator) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) {
dst = append(dst, cipherText...) dst = append(dst, cipherText...)
dstLen := len(dst) dstLen := len(dst)
xtra := 4 - dstLen%4 xtra := 4 - dstLen%4

View File

@ -29,75 +29,75 @@ type ClientConnection struct {
writer PacketWriter writer PacketWriter
} }
func (o *ClientConnection) Overhead() int { func (c *ClientConnection) Overhead() int {
o.RLock() c.RLock()
defer o.RUnlock() defer c.RUnlock()
if o.writer == nil { if c.writer == nil {
return 0 return 0
} }
return o.writer.Overhead() return c.writer.Overhead()
} }
func (o *ClientConnection) Write(b []byte) (int, error) { func (c *ClientConnection) Write(b []byte) (int, error) {
o.RLock() c.RLock()
defer o.RUnlock() defer c.RUnlock()
if o.writer == nil { if c.writer == nil {
return len(b), nil return len(b), nil
} }
return o.writer.Write(b) return c.writer.Write(b)
} }
func (o *ClientConnection) Read([]byte) (int, error) { func (*ClientConnection) Read([]byte) (int, error) {
panic("KCP|ClientConnection: Read should not be called.") panic("KCP|ClientConnection: Read should not be called.")
} }
func (o *ClientConnection) Close() error { func (c *ClientConnection) Close() error {
return o.Conn.Close() return c.Conn.Close()
} }
func (o *ClientConnection) Reset(inputCallback func([]Segment)) { func (c *ClientConnection) Reset(inputCallback func([]Segment)) {
o.Lock() c.Lock()
o.input = inputCallback c.input = inputCallback
o.Unlock() c.Unlock()
} }
func (o *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) { func (c *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) {
o.Lock() c.Lock()
if o.reader == nil { if c.reader == nil {
o.reader = new(KCPPacketReader) c.reader = new(KCPPacketReader)
} }
o.reader.(*KCPPacketReader).Header = header c.reader.(*KCPPacketReader).Header = header
o.reader.(*KCPPacketReader).Security = security c.reader.(*KCPPacketReader).Security = security
if o.writer == nil { if c.writer == nil {
o.writer = new(KCPPacketWriter) c.writer = new(KCPPacketWriter)
} }
o.writer.(*KCPPacketWriter).Header = header c.writer.(*KCPPacketWriter).Header = header
o.writer.(*KCPPacketWriter).Security = security c.writer.(*KCPPacketWriter).Security = security
o.writer.(*KCPPacketWriter).Writer = o.Conn c.writer.(*KCPPacketWriter).Writer = c.Conn
o.Unlock() c.Unlock()
} }
func (o *ClientConnection) Run() { func (c *ClientConnection) Run() {
payload := buf.NewSmall() payload := buf.NewSmall()
defer payload.Release() defer payload.Release()
for { for {
err := payload.Reset(buf.ReadFrom(o.Conn)) err := payload.Reset(buf.ReadFrom(c.Conn))
if err != nil { if err != nil {
payload.Release() payload.Release()
return return
} }
o.RLock() c.RLock()
if o.input != nil { if c.input != nil {
segments := o.reader.Read(payload.Bytes()) segments := c.reader.Read(payload.Bytes())
if len(segments) > 0 { if len(segments) > 0 {
o.input(segments) c.input(segments)
} }
} }
o.RUnlock() c.RUnlock()
} }
} }

View File

@ -22,13 +22,13 @@ type KCPPacketReader struct {
Header internet.PacketHeader Header internet.PacketHeader
} }
func (v *KCPPacketReader) Read(b []byte) []Segment { func (r *KCPPacketReader) Read(b []byte) []Segment {
if v.Header != nil { if r.Header != nil {
b = b[v.Header.Size():] b = b[r.Header.Size():]
} }
if v.Security != nil { if r.Security != nil {
nonceSize := v.Security.NonceSize() nonceSize := r.Security.NonceSize()
out, err := v.Security.Open(b[nonceSize:nonceSize], b[:nonceSize], b[nonceSize:], nil) out, err := r.Security.Open(b[nonceSize:nonceSize], b[:nonceSize], b[nonceSize:], nil)
if err != nil { if err != nil {
return nil return nil
} }
@ -54,39 +54,39 @@ type KCPPacketWriter struct {
buffer [32 * 1024]byte buffer [32 * 1024]byte
} }
func (v *KCPPacketWriter) Overhead() int { func (w *KCPPacketWriter) Overhead() int {
overhead := 0 overhead := 0
if v.Header != nil { if w.Header != nil {
overhead += v.Header.Size() overhead += w.Header.Size()
} }
if v.Security != nil { if w.Security != nil {
overhead += v.Security.Overhead() overhead += w.Security.Overhead()
} }
return overhead return overhead
} }
func (v *KCPPacketWriter) Write(b []byte) (int, error) { func (w *KCPPacketWriter) Write(b []byte) (int, error) {
x := v.buffer[:] x := w.buffer[:]
size := 0 size := 0
if v.Header != nil { if w.Header != nil {
nBytes, _ := v.Header.Write(x) nBytes, _ := w.Header.Write(x)
size += nBytes size += nBytes
x = x[nBytes:] x = x[nBytes:]
} }
if v.Security != nil { if w.Security != nil {
nonceSize := v.Security.NonceSize() nonceSize := w.Security.NonceSize()
var nonce []byte var nonce []byte
if nonceSize > 0 { if nonceSize > 0 {
nonce = x[:nonceSize] nonce = x[:nonceSize]
rand.Read(nonce) rand.Read(nonce)
x = x[nonceSize:] x = x[nonceSize:]
} }
x = v.Security.Seal(x[:0], nonce, b, nil) x = w.Security.Seal(x[:0], nonce, b, nil)
size += nonceSize + len(x) size += nonceSize + len(x)
} else { } else {
size += copy(x, b) size += copy(x, b)
} }
_, err := v.Writer.Write(v.buffer[:size]) _, err := w.Writer.Write(w.buffer[:size])
return len(b), err return len(b), err
} }

View File

@ -1,11 +1,10 @@
package ray_test package ray_test
import ( import (
"context"
"io" "io"
"testing" "testing"
"context"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/ray" . "v2ray.com/core/transport/ray"