Collection class for sorted results

This commit is contained in:
Kovid Goyal
2025-06-28 11:39:18 +05:30
parent e4d17f9864
commit 9669aac55e
2 changed files with 158 additions and 8 deletions

View File

@@ -3,6 +3,8 @@ package choose_files
import (
"fmt"
"io/fs"
"slices"
"sync"
)
var _ = fmt.Print
@@ -18,6 +20,11 @@ func (c CollectionIndex) Compare(o CollectionIndex) int {
return c.Slice - o.Slice
}
func (c *CollectionIndex) NextSlice() {
c.Slice++
c.Pos = 0
}
type ResultCollection struct {
slices [][]ResultItem
append_idx CollectionIndex
@@ -42,12 +49,10 @@ func (c *ResultCollection) NextAppendPointer() (ans *ResultItem) {
if c.append_idx.Pos+1 < len(s) {
c.append_idx.Pos++
} else if c.append_idx.Slice+1 < len(c.slices) {
c.append_idx.Slice++
c.append_idx.Pos = 0
c.append_idx.NextSlice()
} else {
c.slices = append(c.slices, make([]ResultItem, 4096))
c.append_idx.Slice++
c.append_idx.Pos = 0
c.append_idx.NextSlice()
}
return
}
@@ -60,8 +65,7 @@ func (c *ResultCollection) Batch(offset *CollectionIndex) (ans []ResultItem) {
}
} else if offset.Slice < c.append_idx.Slice {
ans = c.slices[offset.Slice][offset.Pos:]
offset.Slice++
offset.Pos = 0
offset.NextSlice()
}
return
}
@@ -73,9 +77,121 @@ func (c *ResultCollection) NextDir(offset *CollectionIndex) (ans string) {
}
offset.Pos++
if offset.Pos >= len(c.slices[offset.Slice]) {
offset.Slice++
offset.Pos = 0
offset.NextSlice()
}
}
return
}
type SortedResults struct {
slices [][]*ResultItem
mutex sync.Mutex
len int
}
func NewSortedResults() *SortedResults { return &SortedResults{} }
func (s *SortedResults) lock() { s.mutex.Lock() }
func (s *SortedResults) unlock() { s.mutex.Unlock() }
func (s *SortedResults) Len() int {
s.lock()
defer s.unlock()
return s.len
}
func (s *SortedResults) At(pos CollectionIndex) (ans *ResultItem) {
s.lock()
defer s.unlock()
if pos.Slice < len(s.slices) {
s := s.slices[pos.Slice]
if pos.Pos < len(s) {
ans = s[pos.Pos]
}
}
return
}
func (s *SortedResults) RenderedMatches(pos CollectionIndex, max_num int) (ans []*ResultItem) {
s.lock()
defer s.unlock()
if pos.Slice >= len(s.slices) {
return
}
ans = make([]*ResultItem, 0, max_num)
for ; pos.Slice < len(s.slices) && max_num > 0; pos.NextSlice() {
sl := s.slices[pos.Slice]
if pos.Pos >= len(sl) {
continue
}
sl = sl[pos.Pos:min(len(sl), pos.Pos+max_num)]
ans = append(ans, sl...)
max_num -= len(sl)
}
return
}
func (s *SortedResults) merge_slice(idx int, sl []*ResultItem) {
sz := len(s.slices[idx])
maxs := sl[len(sl)-1].score
limit := idx + 1
for limit < len(s.slices) {
q := s.slices[limit]
if q[0].score > maxs {
break
}
sz += len(q)
limit++
}
ans := make([]*ResultItem, 0, sz)
a := 0
b := CollectionIndex{Slice: idx}
ss := s.slices[b.Slice]
for a < len(sl) {
if sl[a].score <= ss[b.Pos].score {
ans = append(ans, sl[a])
a++
} else {
ans = append(ans, ss[b.Pos])
b.Pos++
if b.Pos >= len(ss) {
b.NextSlice()
if b.Slice >= limit {
break
}
ss = s.slices[b.Slice]
}
}
}
ans = append(ans, sl[a:]...)
for ; b.Slice < limit; b.NextSlice() {
ans = append(ans, s.slices[b.Slice][b.Pos:]...)
}
s.slices = slices.Replace(s.slices, idx, limit, ans)
}
func (s *SortedResults) AddSortedSlice(sl []*ResultItem) {
if len(sl) == 0 {
return
}
s.lock()
defer s.unlock()
s.len += len(sl)
if len(s.slices) == 0 {
s.slices = append(s.slices, sl)
return
}
sl_min, sl_max := sl[0].score, sl[len(sl)-1].score
for i, q := range s.slices {
switch {
case sl_max <= q[0].score:
s.slices = slices.Insert(s.slices, i, sl)
return
case sl_min >= q[len(q)-1].score:
continue
default:
s.merge_slice(i, sl)
return
}
}
s.slices = append(s.slices, sl)
}

View File

@@ -5,6 +5,7 @@ import (
"io/fs"
"math/rand"
"os"
"strconv"
"strings"
"sync"
"testing"
@@ -156,6 +157,39 @@ func TestChooseFilesScoring(t *testing.T) {
ae("sn", "x/s/n")
}
func TestSortedResults(t *testing.T) {
r := NewSortedResults()
m := func(items ...int) []*ResultItem {
ans := make([]*ResultItem, len(items))
for i, x := range items {
ans[i] = &ResultItem{text: strconv.Itoa(x), score: CombinedScore(x)}
}
return ans
}
v := func(slice, pos, num int) []int {
if num == 0 {
num = r.Len()
}
return utils.Map(func(r *ResultItem) int { return int(r.score) }, r.RenderedMatches(CollectionIndex{slice, pos}, num))
}
tv := func(slice, pos, num int, expected ...int) {
if diff := cmp.Diff(expected, v(slice, pos, num)); diff != "" {
t.Fatalf("view failed for %v num:%d\n%s", CollectionIndex{slice, pos}, num, diff)
}
}
r.AddSortedSlice(m(10, 20, 30))
r.AddSortedSlice(m(40, 50, 60))
r.AddSortedSlice(m(70, 80, 90))
tv(0, 0, 0, 10, 20, 30, 40, 50, 60, 70, 80, 90)
tv(0, 2, 3, 30, 40, 50)
tv(0, 3, 3, 40, 50, 60)
tv(1, 0, 4, 40, 50, 60, 70)
r.AddSortedSlice(m(100, 110, 120))
r.AddSortedSlice(m(41, 61, 71, 99))
tv(0, 0, 0, 10, 20, 30, 40, 41, 50, 60, 61, 70, 71, 80, 90, 99, 100, 110, 120)
}
func run_scoring(b *testing.B, depth, breadth int, query string) {
b.StopTimer()
root := node{name: string(os.PathSeparator)}