diff --git a/postings.go b/postings.go index 180ac099d..950cd2e3c 100644 --- a/postings.go +++ b/postings.go @@ -33,7 +33,7 @@ func (p *memPostings) get(t term) Postings { if l == nil { return emptyPostings } - return &listPostings{list: l, idx: -1} + return newListPostings(l) } // add adds a document to the index. The caller has to ensure that no @@ -70,18 +70,13 @@ func (e errPostings) Seek(uint32) bool { return false } func (e errPostings) At() uint32 { return 0 } func (e errPostings) Err() error { return e.err } -func expandPostings(p Postings) (res []uint32, err error) { - for p.Next() { - res = append(res, p.At()) - } - return res, p.Err() -} +var emptyPostings = errPostings{} // Intersect returns a new postings list over the intersection of the // input postings. func Intersect(its ...Postings) Postings { if len(its) == 0 { - return errPostings{err: nil} + return emptyPostings } a := its[0] @@ -91,8 +86,6 @@ func Intersect(its ...Postings) Postings { return a } -var emptyPostings = errPostings{} - type intersectPostings struct { a, b Postings aok, bok bool @@ -100,41 +93,44 @@ type intersectPostings struct { } func newIntersectPostings(a, b Postings) *intersectPostings { - it := &intersectPostings{a: a, b: b} - it.aok = it.a.Next() - it.bok = it.b.Next() - - return it + return &intersectPostings{a: a, b: b} } func (it *intersectPostings) At() uint32 { return it.cur } -func (it *intersectPostings) Next() bool { +func (it *intersectPostings) doNext(id uint32) bool { for { - if !it.aok || !it.bok { + if !it.b.Seek(id) { return false } - av, bv := it.a.At(), it.b.At() - - if av < bv { - it.aok = it.a.Seek(bv) - } else if bv < av { - it.bok = it.b.Seek(av) - } else { - it.cur = av - it.aok = it.a.Next() - it.bok = it.b.Next() - return true + if vb := it.b.At(); vb != id { + if !it.a.Seek(vb) { + return false + } + id = it.a.At() + if vb != id { + continue + } } + it.cur = id + return true + } +} + +func (it *intersectPostings) Next() bool { + if !it.a.Next() { + return false } + return it.doNext(it.a.At()) } func (it *intersectPostings) Seek(id uint32) bool { - it.aok = it.a.Seek(id) - it.bok = it.b.Seek(id) - return it.Next() + if !it.a.Seek(id) { + return false + } + return it.doNext(it.a.At()) } func (it *intersectPostings) Err() error { @@ -158,17 +154,14 @@ func Merge(its ...Postings) Postings { } type mergedPostings struct { - a, b Postings - aok, bok bool - cur uint32 + a, b Postings + initialized bool + aok, bok bool + cur uint32 } func newMergedPostings(a, b Postings) *mergedPostings { - it := &mergedPostings{a: a, b: b} - it.aok = it.a.Next() - it.bok = it.b.Next() - - return it + return &mergedPostings{a: a, b: b} } func (it *mergedPostings) At() uint32 { @@ -176,6 +169,12 @@ func (it *mergedPostings) At() uint32 { } func (it *mergedPostings) Next() bool { + if !it.initialized { + it.aok = it.a.Next() + it.bok = it.b.Next() + it.initialized = true + } + if !it.aok && !it.bok { return false } @@ -196,25 +195,31 @@ func (it *mergedPostings) Next() bool { if acur < bcur { it.cur = acur it.aok = it.a.Next() - return true - } - if bcur < acur { + } else if acur > bcur { it.cur = bcur it.bok = it.b.Next() - return true + } else { + it.cur = acur + it.aok = it.a.Next() + it.bok = it.b.Next() } - it.cur = acur - it.aok = it.a.Next() - it.bok = it.b.Next() - return true } func (it *mergedPostings) Seek(id uint32) bool { it.aok = it.a.Seek(id) it.bok = it.b.Seek(id) - - return it.Next() + it.initialized = true + acur, bcur := it.a.At(), it.b.At() + if acur < bcur { + it.cur = acur + } else if acur > bcur { + it.cur = bcur + } else { + it.cur = acur + it.bok = it.b.Next() + } + return it.aok && it.bok } func (it *mergedPostings) Err() error { @@ -227,28 +232,38 @@ func (it *mergedPostings) Err() error { // listPostings implements the Postings interface over a plain list. type listPostings struct { list []uint32 - idx int + cur uint32 } func newListPostings(list []uint32) *listPostings { - return &listPostings{list: list, idx: -1} + return &listPostings{list: list} } func (it *listPostings) At() uint32 { - return it.list[it.idx] + return it.cur } func (it *listPostings) Next() bool { - it.idx++ - return it.idx < len(it.list) + if len(it.list) > 0 { + it.cur = it.list[0] + it.list = it.list[1:] + return true + } + return false } func (it *listPostings) Seek(x uint32) bool { // Do binary search between current position and end. - it.idx += sort.Search(len(it.list)-it.idx, func(i int) bool { - return it.list[i+it.idx] >= x + i := sort.Search(len(it.list), func(i int) bool { + return it.list[i] >= x }) - return it.idx < len(it.list) + if i < len(it.list) { + it.cur = it.list[i] + it.list = it.list[i+1:] + return true + } + it.list = nil + return false } func (it *listPostings) Err() error { @@ -259,32 +274,40 @@ func (it *listPostings) Err() error { // big endian numbers. type bigEndianPostings struct { list []byte - idx int + cur uint32 } func newBigEndianPostings(list []byte) *bigEndianPostings { - return &bigEndianPostings{list: list, idx: -1} + return &bigEndianPostings{list: list} } func (it *bigEndianPostings) At() uint32 { - idx := 4 * it.idx - return binary.BigEndian.Uint32(it.list[idx : idx+4]) + return it.cur } func (it *bigEndianPostings) Next() bool { - it.idx++ - return it.idx*4 < len(it.list) + if len(it.list) >= 4 { + it.cur = binary.BigEndian.Uint32(it.list) + it.list = it.list[4:] + return true + } + return false } func (it *bigEndianPostings) Seek(x uint32) bool { num := len(it.list) / 4 // Do binary search between current position and end. - it.idx += sort.Search(num-it.idx, func(i int) bool { - idx := 4 * (it.idx + i) - val := binary.BigEndian.Uint32(it.list[idx : idx+4]) - return val >= x + i := sort.Search(num, func(i int) bool { + return binary.BigEndian.Uint32(it.list[i*4:]) >= x }) - return it.idx*4 < len(it.list) + if i < num { + j := i * 4 + it.cur = binary.BigEndian.Uint32(it.list[j:]) + it.list = it.list[j+4:] + return true + } + it.list = nil + return false } func (it *bigEndianPostings) Err() error { diff --git a/postings_test.go b/postings_test.go index d9154cab6..61db9ad44 100644 --- a/postings_test.go +++ b/postings_test.go @@ -33,6 +33,13 @@ func (m *mockPostings) Seek(v uint32) bool { return m.seek(v) } func (m *mockPostings) Value() uint32 { return m.value() } func (m *mockPostings) Err() error { return m.err() } +func expandPostings(p Postings) (res []uint32, err error) { + for p.Next() { + res = append(res, p.At()) + } + return res, p.Err() +} + func TestIntersect(t *testing.T) { var cases = []struct { a, b []uint32 @@ -233,19 +240,9 @@ func TestMergedPostingsSeek(t *testing.T) { p := newMergedPostings(a, b) require.Equal(t, c.success, p.Seek(c.seek)) - - if c.success { - // check the current element and then proceed to check the rest. - i := 0 - require.Equal(t, c.res[i], p.At()) - - for p.Next() { - i++ - require.Equal(t, int(c.res[i]), int(p.At())) - } - - require.Equal(t, len(c.res)-1, i) - } + lst, err := expandPostings(p) + require.NoError(t, err) + require.Equal(t, c.res, lst) } return @@ -296,16 +293,16 @@ func TestBigEndian(t *testing.T) { ls[600] + 1, ls[601], true, }, { - ls[600] + 1, ls[601], true, + ls[600] + 1, ls[602], true, }, { - ls[600] + 1, ls[601], true, + ls[600] + 1, ls[603], true, }, { - ls[0], ls[601], true, + ls[0], ls[604], true, }, { - ls[600], ls[601], true, + ls[600], ls[605], true, }, { ls[999], ls[999], true, @@ -316,15 +313,11 @@ func TestBigEndian(t *testing.T) { } bep := newBigEndianPostings(beLst) - bep.Next() for _, v := range table { require.Equal(t, v.found, bep.Seek(v.seek)) - // Once you seek beyond, At() will panic. - if v.found { - require.Equal(t, v.val, bep.At()) - require.Nil(t, bep.Err()) - } + require.Equal(t, v.val, bep.At()) + require.Nil(t, bep.Err()) } }) }