From 50ef0dc954592666a13ff92ef20811f0127c3b49 Mon Sep 17 00:00:00 2001 From: Oleg Zaytsev Date: Fri, 11 Oct 2024 15:21:15 +0200 Subject: [PATCH] Fix `MemPostings.Add` and `MemPostings.Get` data race (#15141) * Tests for Mempostings.{Add,Get} data race * Fix MemPostings.{Add,Get} data race We can't modify the postings list that are held in MemPostings as they might already be in use by some readers. * Modify BenchmarkHeadStripeSeriesCreate to have common labels If there are no common labels on the series, we don't excercise the ordering part of MemSeries, as we're just creating slices of one element for each label value. --------- Signed-off-by: Oleg Zaytsev --- tsdb/head_bench_test.go | 8 +++--- tsdb/index/postings.go | 27 ++++++++++++------ tsdb/index/postings_test.go | 55 +++++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 12 deletions(-) diff --git a/tsdb/head_bench_test.go b/tsdb/head_bench_test.go index a03794810..51de50ec2 100644 --- a/tsdb/head_bench_test.go +++ b/tsdb/head_bench_test.go @@ -36,7 +36,7 @@ func BenchmarkHeadStripeSeriesCreate(b *testing.B) { defer h.Close() for i := 0; i < b.N; i++ { - h.getOrCreate(uint64(i), labels.FromStrings("a", strconv.Itoa(i))) + h.getOrCreate(uint64(i), labels.FromStrings(labels.MetricName, "test", "a", strconv.Itoa(i), "b", strconv.Itoa(i%10), "c", strconv.Itoa(i%100), "d", strconv.Itoa(i/2), "e", strconv.Itoa(i/4))) } } @@ -54,8 +54,8 @@ func BenchmarkHeadStripeSeriesCreateParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - i := count.Inc() - h.getOrCreate(uint64(i), labels.FromStrings("a", strconv.Itoa(int(i)))) + i := int(count.Inc()) + h.getOrCreate(uint64(i), labels.FromStrings(labels.MetricName, "test", "a", strconv.Itoa(i), "b", strconv.Itoa(i%10), "c", strconv.Itoa(i%100), "d", strconv.Itoa(i/2), "e", strconv.Itoa(i/4))) } }) } @@ -75,7 +75,7 @@ func BenchmarkHeadStripeSeriesCreate_PreCreationFailure(b *testing.B) { defer h.Close() for i := 0; i < b.N; i++ { - h.getOrCreate(uint64(i), labels.FromStrings("a", strconv.Itoa(i))) + h.getOrCreate(uint64(i), labels.FromStrings(labels.MetricName, "test", "a", strconv.Itoa(i), "b", strconv.Itoa(i%10), "c", strconv.Itoa(i%100), "d", strconv.Itoa(i/2), "e", strconv.Itoa(i/4))) } } diff --git a/tsdb/index/postings.go b/tsdb/index/postings.go index e909a3717..7bc5629ac 100644 --- a/tsdb/index/postings.go +++ b/tsdb/index/postings.go @@ -392,13 +392,14 @@ func (p *MemPostings) Add(id storage.SeriesRef, lset labels.Labels) { p.mtx.Unlock() } -func appendWithExponentialGrowth[T any](a []T, v T) []T { +func appendWithExponentialGrowth[T any](a []T, v T) (_ []T, copied bool) { if cap(a) < len(a)+1 { newList := make([]T, len(a), len(a)*2+1) copy(newList, a) a = newList + copied = true } - return append(a, v) + return append(a, v), copied } func (p *MemPostings) addFor(id storage.SeriesRef, l labels.Label) { @@ -407,16 +408,26 @@ func (p *MemPostings) addFor(id storage.SeriesRef, l labels.Label) { nm = map[string][]storage.SeriesRef{} p.m[l.Name] = nm } - list := appendWithExponentialGrowth(nm[l.Value], id) + list, copied := appendWithExponentialGrowth(nm[l.Value], id) nm[l.Value] = list - if !p.ordered { + // Return if it shouldn't be ordered, if it only has one element or if it's already ordered. + // The invariant is that the first n-1 items in the list are already sorted. + if !p.ordered || len(list) == 1 || list[len(list)-1] >= list[len(list)-2] { return } - // There is no guarantee that no higher ID was inserted before as they may - // be generated independently before adding them to postings. - // We repair order violations on insert. The invariant is that the first n-1 - // items in the list are already sorted. + + if !copied { + // We have appended to the existing slice, + // and readers may already have a copy of this postings slice, + // so we need to copy it before sorting. + old := list + list = make([]storage.SeriesRef, len(old), cap(old)) + copy(list, old) + nm[l.Value] = list + } + + // Repair order violations. for i := len(list) - 1; i >= 1; i-- { if list[i] >= list[i-1] { break diff --git a/tsdb/index/postings_test.go b/tsdb/index/postings_test.go index b41fb54e6..8ee9b9943 100644 --- a/tsdb/index/postings_test.go +++ b/tsdb/index/postings_test.go @@ -1507,3 +1507,58 @@ func TestMemPostings_PostingsForLabelMatchingHonorsContextCancel(t *testing.T) { require.Error(t, p.Err()) require.Equal(t, failAfter+1, ctx.Count()) // Plus one for the Err() call that puts the error in the result. } + +func TestMemPostings_Unordered_Add_Get(t *testing.T) { + mp := NewMemPostings() + for ref := storage.SeriesRef(1); ref < 8; ref += 2 { + // First, add next series. + next := ref + 1 + mp.Add(next, labels.FromStrings(labels.MetricName, "test", "series", strconv.Itoa(int(next)))) + nextPostings := mp.Get(labels.MetricName, "test") + + // Now add current ref. + mp.Add(ref, labels.FromStrings(labels.MetricName, "test", "series", strconv.Itoa(int(ref)))) + + // Next postings should still reference the next series. + nextExpanded, err := ExpandPostings(nextPostings) + require.NoError(t, err) + require.Len(t, nextExpanded, int(ref)) + require.Equal(t, next, nextExpanded[len(nextExpanded)-1]) + } +} + +func TestMemPostings_Concurrent_Add_Get(t *testing.T) { + refs := make(chan storage.SeriesRef) + wg := sync.WaitGroup{} + wg.Add(1) + t.Cleanup(wg.Wait) + t.Cleanup(func() { close(refs) }) + + mp := NewMemPostings() + go func() { + defer wg.Done() + for ref := range refs { + mp.Add(ref, labels.FromStrings(labels.MetricName, "test", "series", strconv.Itoa(int(ref)))) + p := mp.Get(labels.MetricName, "test") + + _, err := ExpandPostings(p) + if err != nil { + t.Errorf("unexpected error: %s", err) + } + } + }() + + for ref := storage.SeriesRef(1); ref < 8; ref += 2 { + // Add next ref in another goroutine so they would race. + refs <- ref + 1 + // Add current ref here + mp.Add(ref, labels.FromStrings(labels.MetricName, "test", "series", strconv.Itoa(int(ref)))) + + // We don't read the value of the postings here, + // this is tested in TestMemPostings_Unordered_Add_Get where it's easier to achieve the determinism. + // This test just checks that there's no data race. + p := mp.Get(labels.MetricName, "test") + _, err := ExpandPostings(p) + require.NoError(t, err) + } +}