mirror of https://github.com/XTLS/Xray-core
				
				
				
			Disable VMess drain when not pure connection
							parent
							
								
									ff9bb2d8df
								
							
						
					
					
						commit
						f390047b37
					
				| 
						 | 
				
			
			@ -57,14 +57,14 @@ func TestRequestSerialization(t *testing.T) {
 | 
			
		|||
	defer common.Close(userValidator)
 | 
			
		||||
 | 
			
		||||
	server := NewServerSession(userValidator, sessionHistory)
 | 
			
		||||
	actualRequest, err := server.DecodeRequestHeader(buffer)
 | 
			
		||||
	actualRequest, err := server.DecodeRequestHeader(buffer, false)
 | 
			
		||||
	common.Must(err)
 | 
			
		||||
 | 
			
		||||
	if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
 | 
			
		||||
		t.Error(r)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = server.DecodeRequestHeader(buffer2)
 | 
			
		||||
	_, err = server.DecodeRequestHeader(buffer2, false)
 | 
			
		||||
	// anti replay attack
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("nil error")
 | 
			
		||||
| 
						 | 
				
			
			@ -107,7 +107,7 @@ func TestInvalidRequest(t *testing.T) {
 | 
			
		|||
	defer common.Close(userValidator)
 | 
			
		||||
 | 
			
		||||
	server := NewServerSession(userValidator, sessionHistory)
 | 
			
		||||
	_, err := server.DecodeRequestHeader(buffer)
 | 
			
		||||
	_, err := server.DecodeRequestHeader(buffer, false)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Error("nil error")
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -148,7 +148,7 @@ func TestMuxRequest(t *testing.T) {
 | 
			
		|||
	defer common.Close(userValidator)
 | 
			
		||||
 | 
			
		||||
	server := NewServerSession(userValidator, sessionHistory)
 | 
			
		||||
	actualRequest, err := server.DecodeRequestHeader(buffer)
 | 
			
		||||
	actualRequest, err := server.DecodeRequestHeader(buffer, false)
 | 
			
		||||
	common.Must(err)
 | 
			
		||||
 | 
			
		||||
	if r := cmp.Diff(actualRequest, expectedRequest, cmp.AllowUnexported(protocol.ID{})); r != "" {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -131,7 +131,7 @@ func parseSecurityType(b byte) protocol.SecurityType {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
 | 
			
		||||
func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) {
 | 
			
		||||
func (s *ServerSession) DecodeRequestHeader(reader io.Reader, isDrain bool) (*protocol.RequestHeader, error) {
 | 
			
		||||
	buffer := buf.New()
 | 
			
		||||
	behaviorRand := dice.NewDeterministicDice(int64(s.userValidator.GetBehaviorSeed()))
 | 
			
		||||
	BaseDrainSize := behaviorRand.Roll(3266)
 | 
			
		||||
| 
						 | 
				
			
			@ -143,7 +143,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 | 
			
		|||
	drainConnection := func(e error) error {
 | 
			
		||||
		// We read a deterministic generated length of data before closing the connection to offset padding read pattern
 | 
			
		||||
		readSizeRemain -= int(buffer.Len())
 | 
			
		||||
		if readSizeRemain > 0 {
 | 
			
		||||
		if readSizeRemain > 0 && isDrain {
 | 
			
		||||
			err := s.DrainConnN(reader, readSizeRemain)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return newError("failed to drain connection DrainSize = ", BaseDrainSize, " ", RandDrainMax, " ", RandDrainRolled).Base(err).Base(e)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -220,9 +220,18 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 | 
			
		|||
		return newError("unable to set read deadline").Base(err).AtWarning()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	iConn := connection
 | 
			
		||||
	if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
 | 
			
		||||
		iConn = statConn.Connection
 | 
			
		||||
	}
 | 
			
		||||
	_, isDrain := iConn.(*net.TCPConn)
 | 
			
		||||
	if !isDrain {
 | 
			
		||||
		_, isDrain = iConn.(*net.UnixConn)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reader := &buf.BufferedReader{Reader: buf.NewReader(connection)}
 | 
			
		||||
	svrSession := encoding.NewServerSession(h.clients, h.sessionHistory)
 | 
			
		||||
	request, err := svrSession.DecodeRequestHeader(reader)
 | 
			
		||||
	request, err := svrSession.DecodeRequestHeader(reader, isDrain)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if errors.Cause(err) != io.EOF {
 | 
			
		||||
			log.Record(&log.AccessMessage{
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue