diff --git a/promql/engine_test.go b/promql/engine_test.go index 6c792f5f3..b7435d473 100644 --- a/promql/engine_test.go +++ b/promql/engine_test.go @@ -21,6 +21,7 @@ import ( "os" "sort" "strconv" + "sync" "testing" "time" @@ -58,7 +59,9 @@ func TestQueryConcurrency(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(dir) queryTracker := promql.NewActiveQueryTracker(dir, maxConcurrency, nil) - t.Cleanup(queryTracker.Close) + t.Cleanup(func() { + require.NoError(t, queryTracker.Close()) + }) opts := promql.EngineOpts{ Logger: nil, @@ -90,9 +93,14 @@ func TestQueryConcurrency(t *testing.T) { return nil } + var wg sync.WaitGroup for i := 0; i < maxConcurrency; i++ { q := engine.NewTestQuery(f) - go q.Exec(ctx) + wg.Add(1) + go func() { + q.Exec(ctx) + wg.Done() + }() select { case <-processing: // Expected. @@ -102,7 +110,11 @@ func TestQueryConcurrency(t *testing.T) { } q := engine.NewTestQuery(f) - go q.Exec(ctx) + wg.Add(1) + go func() { + q.Exec(ctx) + wg.Done() + }() select { case <-processing: @@ -125,6 +137,8 @@ func TestQueryConcurrency(t *testing.T) { for i := 0; i < maxConcurrency; i++ { block <- struct{}{} } + + wg.Wait() } // contextDone returns an error if the context was canceled or timed out. diff --git a/promql/query_logger.go b/promql/query_logger.go index 7ddd8c2d5..7e06ebb97 100644 --- a/promql/query_logger.go +++ b/promql/query_logger.go @@ -16,6 +16,8 @@ package promql import ( "context" "encoding/json" + "errors" + "fmt" "io" "os" "path/filepath" @@ -36,6 +38,8 @@ type ActiveQueryTracker struct { maxConcurrent int } +var _ io.Closer = &ActiveQueryTracker{} + type Entry struct { Query string `json:"query"` Timestamp int64 `json:"timestamp_sec"` @@ -83,6 +87,23 @@ func logUnfinishedQueries(filename string, filesize int, logger log.Logger) { } } +type mmapedFile struct { + f io.Closer + m mmap.MMap +} + +func (f *mmapedFile) Close() error { + err := f.m.Unmap() + if err != nil { + err = fmt.Errorf("mmapedFile: unmapping: %w", err) + } + if fErr := f.f.Close(); fErr != nil { + return errors.Join(fmt.Errorf("close mmapedFile.f: %w", fErr), err) + } + + return err +} + func getMMapedFile(filename string, filesize int, logger log.Logger) ([]byte, io.Closer, error) { file, err := os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o666) if err != nil { @@ -108,7 +129,7 @@ func getMMapedFile(filename string, filesize int, logger log.Logger) ([]byte, io return nil, nil, err } - return fileAsBytes, file, err + return fileAsBytes, &mmapedFile{f: file, m: fileAsBytes}, err } func NewActiveQueryTracker(localStoragePath string, maxConcurrent int, logger log.Logger) *ActiveQueryTracker { @@ -204,9 +225,13 @@ func (tracker ActiveQueryTracker) Insert(ctx context.Context, query string) (int } } -func (tracker *ActiveQueryTracker) Close() { +// Close closes tracker. +func (tracker *ActiveQueryTracker) Close() error { if tracker == nil || tracker.closer == nil { - return + return nil + } + if err := tracker.closer.Close(); err != nil { + return fmt.Errorf("close ActiveQueryTracker.closer: %w", err) } - tracker.closer.Close() + return nil } diff --git a/promql/query_logger_test.go b/promql/query_logger_test.go index 376d61b64..7bd93781e 100644 --- a/promql/query_logger_test.go +++ b/promql/query_logger_test.go @@ -16,6 +16,7 @@ package promql import ( "context" "os" + "path/filepath" "testing" "github.com/grafana/regexp" @@ -104,26 +105,26 @@ func TestIndexReuse(t *testing.T) { } func TestMMapFile(t *testing.T) { - file, err := os.CreateTemp("", "mmapedFile") - require.NoError(t, err) - - filename := file.Name() - defer os.Remove(filename) - - fileAsBytes, _, err := getMMapedFile(filename, 2, nil) + dir := t.TempDir() + fpath := filepath.Join(dir, "mmapedFile") + const data = "ab" + fileAsBytes, closer, err := getMMapedFile(fpath, 2, nil) require.NoError(t, err) - copy(fileAsBytes, "ab") + copy(fileAsBytes, data) + require.NoError(t, closer.Close()) - f, err := os.Open(filename) + f, err := os.Open(fpath) require.NoError(t, err) + t.Cleanup(func() { + _ = f.Close() + }) bytes := make([]byte, 4) n, err := f.Read(bytes) - require.Equal(t, 2, n) require.NoError(t, err, "Unexpected error while reading file.") - - require.Equal(t, fileAsBytes, bytes[:2], "Mmap failed") + require.Equal(t, 2, n) + require.Equal(t, []byte(data), bytes[:2], "Mmap failed") } func TestParseBrokenJSON(t *testing.T) {