From a0c6dbfe2aca40874fbc77dd62e5d4e448100491 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Mon, 1 Sep 2014 10:42:35 -0700 Subject: [PATCH] agent: testing remote exec writer --- command/agent/remote_exec.go | 55 ++++++++++++---------- command/agent/remote_exec_test.go | 77 +++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 24 deletions(-) create mode 100644 command/agent/remote_exec_test.go diff --git a/command/agent/remote_exec.go b/command/agent/remote_exec.go index d7d6986221..d8c8915d42 100644 --- a/command/agent/remote_exec.go +++ b/command/agent/remote_exec.go @@ -53,12 +53,15 @@ type remoteExecSpec struct { } type rexecWriter struct { - bufCh chan []byte - buf []byte - bufLen int - bufLock sync.Mutex - cancelCh chan struct{} - flush *time.Timer + BufCh chan []byte + BufSize int + BufIdle time.Duration + CancelCh chan struct{} + + buf []byte + bufLen int + bufLock sync.Mutex + flush *time.Timer } func (r *rexecWriter) Write(b []byte) (int, error) { @@ -69,10 +72,13 @@ func (r *rexecWriter) Write(b []byte) (int, error) { r.flush = nil } inpLen := len(b) + if r.buf == nil { + r.buf = make([]byte, r.BufSize) + } COPY: remain := len(r.buf) - r.bufLen - if remain >= len(b) { + if remain > len(b) { copy(r.buf[r.bufLen:], b) r.bufLen += len(b) } else { @@ -80,31 +86,30 @@ COPY: b = b[remain:] r.bufLen += remain r.bufLock.Unlock() - r.flushBuf() + r.Flush() r.bufLock.Lock() goto COPY } - r.flush = time.AfterFunc(remoteExecOutputDeadline, r.flushBuf) + r.flush = time.AfterFunc(r.BufIdle, r.Flush) return inpLen, nil } -func (r *rexecWriter) Close() { - r.flushBuf() - close(r.bufCh) -} - -func (r *rexecWriter) flushBuf() { +func (r *rexecWriter) Flush() { r.bufLock.Lock() defer r.bufLock.Unlock() + if r.flush != nil { + r.flush.Stop() + r.flush = nil + } if r.bufLen == 0 { return } select { - case r.bufCh <- r.buf: - r.buf = make([]byte, remoteExecOutputSize) + case r.BufCh <- r.buf[:r.bufLen]: + r.buf = make([]byte, r.BufSize) r.bufLen = 0 - case <-r.cancelCh: + case <-r.CancelCh: r.bufLen = 0 } } @@ -162,9 +167,10 @@ func (a *Agent) handleRemoteExec(msg *UserEvent) { // Setup the output streaming writer := &rexecWriter{ - bufCh: make(chan []byte, 16), - buf: make([]byte, remoteExecOutputSize), - cancelCh: make(chan struct{}), + BufCh: make(chan []byte, 16), + BufSize: remoteExecOutputSize, + BufIdle: remoteExecOutputDeadline, + CancelCh: make(chan struct{}), } cmd.Stdout = writer cmd.Stderr = writer @@ -181,7 +187,8 @@ func (a *Agent) handleRemoteExec(msg *UserEvent) { exitCh := make(chan int, 1) go func() { err := cmd.Wait() - writer.Close() + writer.Flush() + close(writer.BufCh) if err != nil { exitCh <- 0 return @@ -201,12 +208,12 @@ func (a *Agent) handleRemoteExec(msg *UserEvent) { WAIT: for num := 0; ; num++ { select { - case out := <-writer.bufCh: + case out := <-writer.BufCh: if out == nil { break WAIT } if !a.remoteExecWriteOutput(&event, num, out) { - close(writer.cancelCh) + close(writer.CancelCh) exitCode = 255 return } diff --git a/command/agent/remote_exec_test.go b/command/agent/remote_exec_test.go new file mode 100644 index 0000000000..af65e5427b --- /dev/null +++ b/command/agent/remote_exec_test.go @@ -0,0 +1,77 @@ +package agent + +import ( + "testing" + "time" +) + +func TestRexecWriter(t *testing.T) { + writer := &rexecWriter{ + BufCh: make(chan []byte, 16), + BufSize: 16, + BufIdle: 10 * time.Millisecond, + CancelCh: make(chan struct{}), + } + + // Write short, wait for idle + start := time.Now() + n, err := writer.Write([]byte("test")) + if err != nil { + t.Fatalf("err: %v", err) + } + if n != 4 { + t.Fatalf("bad: %v", n) + } + + select { + case b := <-writer.BufCh: + if len(b) != 4 { + t.Fatalf("Bad: %v", b) + } + if time.Now().Sub(start) < writer.BufIdle { + t.Fatalf("too early") + } + case <-time.After(2 * writer.BufIdle): + t.Fatalf("timeout") + } + + // Write in succession to prevent the timeout + writer.Write([]byte("test")) + time.Sleep(writer.BufIdle / 2) + writer.Write([]byte("test")) + time.Sleep(writer.BufIdle / 2) + start = time.Now() + writer.Write([]byte("test")) + + select { + case b := <-writer.BufCh: + if len(b) != 12 { + t.Fatalf("Bad: %v", b) + } + if time.Now().Sub(start) < writer.BufIdle { + t.Fatalf("too early") + } + case <-time.After(2 * writer.BufIdle): + t.Fatalf("timeout") + } + + // Write large values, multiple flushes required + writer.Write([]byte("01234567890123456789012345678901")) + + select { + case b := <-writer.BufCh: + if string(b) != "0123456789012345" { + t.Fatalf("bad: %s", b) + } + default: + t.Fatalf("should have buf") + } + select { + case b := <-writer.BufCh: + if string(b) != "6789012345678901" { + t.Fatalf("bad: %s", b) + } + default: + t.Fatalf("should have buf") + } +}