package valkey

import (
	"bufio"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/valkey-io/valkey-go/internal/cmds"
)

const LibName = "valkey"
const LibVer = "1.0.55"

var noHello = regexp.MustCompile("unknown command .?(HELLO|hello).?")

// isUnsubReply reports whether msg is the reply to the PING that is written
// right after every unsubscribe command in _backgroundWrite.
// See https://github.com/redis/rueidis/pull/691
func isUnsubReply(msg *ValkeyMessage) bool {
	// ex. NOPERM User limiteduser has no permissions to run the 'ping' command
	// ex. LOADING server is loading the dataset in memory
	// ex. BUSY
	if msg.typ == '-' && (strings.HasPrefix(msg.string, "LOADING") || strings.HasPrefix(msg.string, "BUSY") || strings.Contains(msg.string, "'ping'")) {
		msg.typ = '+'
		msg.string = "PONG"
		return true
	}
	return msg.string == "PONG" || (len(msg.values) != 0 && msg.values[0].string == "pong")
}

type wire interface {
	Do(ctx context.Context, cmd Completed) ValkeyResult
	DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) ValkeyResult
	DoMulti(ctx context.Context, multi ...Completed) *valkeyresults
	DoMultiCache(ctx context.Context, multi ...CacheableTTL) *valkeyresults
	Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error
	DoStream(ctx context.Context, pool *pool, cmd Completed) ValkeyResultStream
	DoMultiStream(ctx context.Context, pool *pool, multi ...Completed) MultiValkeyResultStream
	Info() map[string]ValkeyMessage
	Version() int
	AZ() string
	Error() error
	Close()
	CleanSubscriptions()
	SetPubSubHooks(hooks PubSubHooks) <-chan error
	SetOnCloseHook(fn func(error))
}

var _ wire = (*pipe)(nil)

type pipe struct {
	conn            net.Conn
	error           atomic.Value
	clhks           atomic.Value // closed hook, invoked after the conn is closed
	pshks           atomic.Value // pubsub hook, registered by the SetPubSubHooks
	queue           queue
	cache           CacheStore
	r               *bufio.Reader
	w               *bufio.Writer
	close           chan struct{}
	onInvalidations func([]ValkeyMessage)
	r2psFn          func() (p *pipe, err error) // func to build pipe for resp2 pubsub
	r2pipe          *pipe                       // internal pipe for resp2 pubsub only
	ssubs           *subs                       // pubsub smessage subscriptions
	nsubs           *subs                       // pubsub message subscriptions
	psubs           *subs                       // pubsub pmessage subscriptions
	info            map[string]ValkeyMessage
	timeout         time.Duration
	pinggap         time.Duration
	maxFlushDelay   time.Duration
	once            sync.Once
	r2mu            sync.Mutex
	version         int32
	_               [10]int32
	blcksig         int32
	state           int32
	waits           int32
	recvs           int32
	r2ps            bool // identify this pipe is used for resp2 pubsub or not
	noNoDelay       bool
}

type pipeFn func(connFn func() (net.Conn, error), option *ClientOption) (p *pipe, err error)

func newPipe(connFn func() (net.Conn, error), option *ClientOption) (p *pipe, err error) {
	return _newPipe(connFn, option, false, false)
}

func newPipeNoBg(connFn func() (net.Conn, error), option *ClientOption) (p *pipe, err error) {
	return _newPipe(connFn, option, false, true)
}

func _newPipe(connFn func() (net.Conn, error), option *ClientOption, r2ps, nobg bool) (p *pipe, err error) {
	conn, err := connFn()
	if err != nil {
		return nil, err
	}
	p = &pipe{
		conn:          conn,
		r:             bufio.NewReaderSize(conn, option.ReadBufferEachConn),
		w:             bufio.NewWriterSize(conn, option.WriteBufferEachConn),
		timeout:       option.ConnWriteTimeout,
		pinggap:       option.Dialer.KeepAlive,
		maxFlushDelay: option.MaxFlushDelay,
		noNoDelay:     option.DisableTCPNoDelay,
		r2ps:          r2ps,
	}
	if !nobg {
		p.queue = newRing(option.RingScaleEachConn)
		p.nsubs = newSubs()
		p.psubs = newSubs()
		p.ssubs = newSubs()
		p.close = make(chan struct{})
	}
	if !r2ps {
		p.r2psFn = func() (p *pipe, err error) {
			return _newPipe(connFn, option, true, nobg)
		}
	}
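	// Build the client-side cache store before the handshake below enables
	// CLIENT TRACKING, so invalidation pushes always have a cache to land in.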
	if !nobg && !option.DisableCache {
		cacheStoreFn := option.NewCacheStoreFn
		if cacheStoreFn == nil {
			cacheStoreFn = newLRU
		}
		p.cache = cacheStoreFn(CacheStoreOption{CacheSizeEachConn: option.CacheSizeEachConn})
	}
	p.pshks.Store(emptypshks)
	p.clhks.Store(emptyclhks)

	username := option.Username
	password := option.Password
	if option.AuthCredentialsFn != nil {
		authCredentialsContext := AuthCredentialsContext{
			Address: conn.RemoteAddr(),
		}
		authCredentials, err := option.AuthCredentialsFn(authCredentialsContext)
		if err != nil {
			p.Close()
			return nil, err
		}
		username = authCredentials.Username
		password = authCredentials.Password
	}

	helloCmd := []string{"HELLO", "3"}
	if password != "" && username == "" {
		helloCmd = append(helloCmd, "AUTH", "default", password)
	} else if username != "" {
		helloCmd = append(helloCmd, "AUTH", username, password)
	}
	if option.ClientName != "" {
		helloCmd = append(helloCmd, "SETNAME", option.ClientName)
	}

	init := make([][]string, 0, 5)
	if option.ClientTrackingOptions == nil {
		init = append(init, helloCmd, []string{"CLIENT", "TRACKING", "ON", "OPTIN"})
	} else {
		init = append(init, helloCmd, append([]string{"CLIENT", "TRACKING", "ON"}, option.ClientTrackingOptions...))
	}
	if option.DisableCache {
		init = init[:1]
	}
	if option.SelectDB != 0 {
		init = append(init, []string{"SELECT", strconv.Itoa(option.SelectDB)})
	}
	if option.ReplicaOnly && option.Sentinel.MasterSet == "" {
		init = append(init, []string{"READONLY"})
	}
	if option.ClientNoTouch {
		init = append(init, []string{"CLIENT", "NO-TOUCH", "ON"})
	}
	if option.ClientNoEvict {
		init = append(init, []string{"CLIENT", "NO-EVICT", "ON"})
	}

	addClientSetInfoCmds := true
	if len(option.ClientSetInfo) == 2 {
		init = append(init, []string{"CLIENT", "SETINFO", "LIB-NAME", option.ClientSetInfo[0]}, []string{"CLIENT", "SETINFO", "LIB-VER", option.ClientSetInfo[1]})
	} else if option.ClientSetInfo == nil {
		init = append(init, []string{"CLIENT", "SETINFO", "LIB-NAME", LibName}, []string{"CLIENT", "SETINFO", "LIB-VER", LibVer})
	} else {
		addClientSetInfoCmds = false
	}

	timeout := option.Dialer.Timeout
	if timeout <= 0 {
		timeout = DefaultDialTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	r2 := option.AlwaysRESP2
	if !r2 && !r2ps {
		resp := p.DoMulti(ctx, cmds.NewMultiCompleted(init)...)
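		// The replies line up with init: index 0 is the HELLO reply, kept as p.info.
		// When addClientSetInfoCmds is set, the last two replies belong to the
		// best-effort CLIENT SETINFO pair and are excluded from error checking.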
		defer resultsp.Put(resp)
		count := len(resp.s)
		if addClientSetInfoCmds {
			// skip error checking on the last CLIENT SETINFO
			count -= 2
		}
		for i, r := range resp.s[:count] {
			if i == 0 {
				p.info, err = r.AsMap()
			} else {
				err = r.Error()
			}
			if err != nil {
				if init[i][0] == "READONLY" {
					// ignore READONLY command error
					continue
				}
				if re, ok := err.(*ValkeyError); ok {
					if !r2 && noHello.MatchString(re.string) {
						r2 = true
						continue
					} else if init[i][0] == "CLIENT" {
						err = fmt.Errorf("%s: %v\n%w", re.string, init[i], ErrNoCache)
					} else if r2 {
						continue
					}
				}
				p.Close()
				return nil, err
			}
		}
	}
	if proto := p.info["proto"]; proto.integer < 3 {
		r2 = true
	}
	if !r2 && !r2ps {
		if ver, ok := p.info["version"]; ok {
			if v := strings.Split(ver.string, "."); len(v) != 0 {
				vv, _ := strconv.ParseInt(v[0], 10, 32)
				p.version = int32(vv)
			}
		}
		p.onInvalidations = option.OnInvalidations
	} else {
		if !option.DisableCache {
			p.Close()
			return nil, ErrNoCache
		}
		init = init[:0]
		init = append(init, []string{"HELLO", "2"})
		if password != "" && username == "" {
			init = append(init, []string{"AUTH", password})
		} else if username != "" {
			init = append(init, []string{"AUTH", username, password})
		}
		if option.ClientName != "" {
			init = append(init, []string{"CLIENT", "SETNAME", option.ClientName})
		}
		if option.SelectDB != 0 {
			init = append(init, []string{"SELECT", strconv.Itoa(option.SelectDB)})
		}
		if option.ReplicaOnly && option.Sentinel.MasterSet == "" {
			init = append(init, []string{"READONLY"})
		}
		if option.ClientNoTouch {
			init = append(init, []string{"CLIENT", "NO-TOUCH", "ON"})
		}
		if option.ClientNoEvict {
			init = append(init, []string{"CLIENT", "NO-EVICT", "ON"})
		}

		addClientSetInfoCmds := true
		if len(option.ClientSetInfo) == 2 {
			init = append(init, []string{"CLIENT", "SETINFO", "LIB-NAME", option.ClientSetInfo[0]}, []string{"CLIENT", "SETINFO", "LIB-VER", option.ClientSetInfo[1]})
		} else if option.ClientSetInfo == nil {
			init = append(init, []string{"CLIENT", "SETINFO", "LIB-NAME", LibName}, []string{"CLIENT", "SETINFO", "LIB-VER", LibVer})
		} else {
			addClientSetInfoCmds = false
		}

		p.version = 5
		if len(init) != 0 {
			resp := p.DoMulti(ctx, cmds.NewMultiCompleted(init)...)
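			// RESP2 fallback handshake: replies are checked like the RESP3 path above,
			// except that "unknown command" errors matched by noHello are tolerated
			// for servers that predate the HELLO command.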
			defer resultsp.Put(resp)
			count := len(resp.s)
			if addClientSetInfoCmds {
				// skip error checking on the last CLIENT SETINFO
				count -= 2
			}
			for i, r := range resp.s[:count] {
				if init[i][0] == "READONLY" {
					// ignore READONLY command error
					continue
				}
				if err = r.Error(); err != nil {
					if re, ok := err.(*ValkeyError); ok && noHello.MatchString(re.string) {
						continue
					}
					p.Close()
					return nil, err
				}
				if i == 0 {
					p.info, err = r.AsMap()
				}
			}
		}
	}
	if !nobg {
		if p.onInvalidations != nil || option.AlwaysPipelining {
			p.background()
		}
		if p.timeout > 0 && p.pinggap > 0 {
			go p.backgroundPing()
		}
	}
	return p, nil
}

func (p *pipe) background() {
	if p.queue != nil {
		atomic.CompareAndSwapInt32(&p.state, 0, 1)
		p.once.Do(func() {
			go p._background()
		})
	}
}

func (p *pipe) _exit(err error) {
	p.error.CompareAndSwap(nil, &errs{error: err})
	atomic.CompareAndSwapInt32(&p.state, 1, 2) // stop accepting new requests
	_ = p.conn.Close()                         // force both read & write goroutine to exit
	p.clhks.Load().(func(error))(err)
}

func disableNoDelay(conn net.Conn) {
	if c, ok := conn.(*tls.Conn); ok {
		conn = c.NetConn()
	}
	if c, ok := conn.(*net.TCPConn); ok {
		c.SetNoDelay(false)
	}
}

func (p *pipe) _background() {
	p.conn.SetDeadline(time.Time{})
	if p.noNoDelay {
		disableNoDelay(p.conn)
	}
	go func() {
		p._exit(p._backgroundWrite())
		close(p.close)
	}()
	{
		p._exit(p._backgroundRead())
		select {
		case <-p.close:
		default:
			atomic.AddInt32(&p.waits, 1)
			go func() {
				<-p.queue.PutOne(cmds.PingCmd) // avoid _backgroundWrite hanging at p.queue.WaitForWrite()
				atomic.AddInt32(&p.waits, -1)
			}()
		}
	}
	err := p.Error()
	p.nsubs.Close()
	p.psubs.Close()
	p.ssubs.Close()
	if old := p.pshks.Swap(emptypshks).(*pshks); old.close != nil {
		old.close <- err
		close(old.close)
	}

	var (
		resps []ValkeyResult
		ch    chan ValkeyResult
		cond  *sync.Cond
	)

	// clean up cache and free pending calls
	if p.cache != nil {
		p.cache.Close(ErrDoCacheAborted)
	}
	if p.onInvalidations != nil {
		p.onInvalidations(nil)
	}

	resp := newErrResult(err)
	for atomic.LoadInt32(&p.waits) != 0 {
		select {
		case <-p.close: // p.queue.NextWriteCmd() can only be called after _backgroundWrite
			_, _, _ = p.queue.NextWriteCmd()
		default:
		}
		if _, _, ch, resps, cond = p.queue.NextResultCh(); ch != nil {
			for i := range resps {
				resps[i] = resp
			}
			ch <- resp
			cond.L.Unlock()
			cond.Signal()
		} else {
			cond.L.Unlock()
			cond.Signal()
			runtime.Gosched()
		}
	}
	<-p.close
	atomic.StoreInt32(&p.state, 4)
}

func (p *pipe) _backgroundWrite() (err error) {
	var (
		ones  = make([]Completed, 1)
		multi []Completed
		ch    chan ValkeyResult

		flushDelay = p.maxFlushDelay
		flushStart = time.Time{}
	)

	for err == nil {
		if ones[0], multi, ch = p.queue.NextWriteCmd(); ch == nil {
			if flushDelay != 0 {
				flushStart = time.Now()
			}
			if p.w.Buffered() != 0 {
				if err = p.w.Flush(); err != nil {
					break
				}
			}
			ones[0], multi, ch = p.queue.WaitForWrite()
			if flushDelay != 0 && atomic.LoadInt32(&p.waits) > 1 { // do not delay for sequential usage
				// Blocking commands are executed in a dedicated client acquired from the pool,
				// so there is no point in waiting for other commands to be written.
				// https://github.com/redis/rueidis/issues/379
				var blocked bool
				for i := 0; i < len(multi) && !blocked; i++ {
					blocked = multi[i].IsBlock()
				}
				if !blocked {
					time.Sleep(flushDelay - time.Since(flushStart)) // ref: https://github.com/redis/rueidis/issues/156
				}
			}
		}
		if ch != nil && multi == nil {
			multi = ones
		}
		for _, cmd := range multi {
			err = writeCmd(p.w, cmd.Commands())
			if cmd.IsUnsub() {
				// See https://github.com/redis/rueidis/pull/691
				err = writeCmd(p.w, cmds.PingCmd.Commands())
			}
		}
	}
	return
}

func (p *pipe) _backgroundRead() (err error) {
	var (
		msg            ValkeyMessage
		cond           *sync.Cond
		ones           = make([]Completed, 1)
		multi          []Completed
		resps          []ValkeyResult
		ch             chan ValkeyResult
		ff             int // fulfilled count
		skip           int // skip rest push messages
		ver            = p.version
		prply          bool // push reply
		unsub          bool // unsubscribe notification
		skipUnsubReply bool // if unsubscribe is replied
		r2ps           = p.r2ps
	)

	defer func() {
		resp := newErrResult(err)
		if err != nil && ff < len(multi) {
			for ; ff < len(resps); ff++ {
				resps[ff] = resp
			}
			ch <- resp
			cond.L.Unlock()
			cond.Signal()
		}
	}()

	for {
		if msg, err = readNextMessage(p.r); err != nil {
			return
		}
		if msg.typ == '>' || (r2ps && len(msg.values) != 0 && msg.values[0].string != "pong") {
			if prply, unsub = p.handlePush(msg.values); !prply {
				continue
			}
			if skip > 0 {
				skip--
				prply = false
				unsub = false
				continue
			}
		} else if ver == 6 && len(msg.values) != 0 {
			// This is a workaround for Valkey 6's broken invalidation protocol: https://github.com/redis/redis/issues/8935
			// When Valkey 6 handles MULTI, MGET, or other multi-key commands, it sends
			// invalidation messages immediately if it finds the keys are expired, which
			// breaks the multi-key command response. We fix this by fetching the next
			// message and patching it back into the response.
			i := 0
			for j, v := range msg.values {
				if v.typ == '>' {
					p.handlePush(v.values)
				} else {
					if i != j {
						msg.values[i] = v
					}
					i++
				}
			}
			for ; i < len(msg.values); i++ {
				if msg.values[i], err = readNextMessage(p.r); err != nil {
					return
				}
			}
		}
		if ff == len(multi) {
			ff = 0
			ones[0], multi, ch, resps, cond = p.queue.NextResultCh() // ch should not be nil, otherwise it must be a protocol bug
			if ch == nil {
				cond.L.Unlock()
				// Valkey proactively sends sunsubscribe notifications in the event of slot migration.
				// We should ignore them and go fetch the next message.
				// We also treat all other unsubscribe notifications just like sunsubscribe,
				// so that we don't need to track how many channels we have subscribed to
				// in order to handle the wildcard unsubscribe command.
				// See https://github.com/redis/rueidis/pull/691
				if unsub {
					prply = false
					unsub = false
					continue
				}
				if skipUnsubReply && isUnsubReply(&msg) {
					skipUnsubReply = false
					continue
				}
				panic(protocolbug)
			}
			if multi == nil {
				multi = ones
			}
		} else if ff >= 4 && len(msg.values) >= 2 && multi[0].IsOptIn() { // if unfulfilled multi commands are led by opt-in and got a successful response
			now := time.Now()
			if cacheable := Cacheable(multi[ff-1]); cacheable.IsMGet() {
				cc := cmds.MGetCacheCmd(cacheable)
				msgs := msg.values[len(msg.values)-1].values
				for i, cp := range msgs {
					ck := cmds.MGetCacheKey(cacheable, i)
					cp.attrs = cacheMark
					if pttl := msg.values[i].integer; pttl >= 0 {
						cp.setExpireAt(now.Add(time.Duration(pttl) * time.Millisecond).UnixMilli())
					}
					msgs[i].setExpireAt(p.cache.Update(ck, cc, cp))
				}
			} else {
				ck, cc := cmds.CacheKey(cacheable)
				ci := len(msg.values) - 1
				cp := msg.values[ci]
				cp.attrs = cacheMark
				if pttl := msg.values[ci-1].integer; pttl >= 0 {
					cp.setExpireAt(now.Add(time.Duration(pttl) * time.Millisecond).UnixMilli())
				}
				msg.values[ci].setExpireAt(p.cache.Update(ck, cc, cp))
			}
		}
		if prply {
			// Valkey proactively sends sunsubscribe notifications in the event of slot migration.
			// We should ignore them and go fetch the next message. We also treat all other
			// unsubscribe notifications just like sunsubscribe, so that we don't need to
			// track how many channels we have subscribed to in order to handle the wildcard
			// unsubscribe command. See https://github.com/redis/rueidis/pull/691
			if unsub {
				prply = false
				unsub = false
				continue
			}
			prply = false
			unsub = false
			if !multi[ff].NoReply() {
				panic(protocolbug)
			}
			skip = len(multi[ff].Commands()) - 2
			msg = ValkeyMessage{} // override successful subscribe/unsubscribe response to empty
		} else if multi[ff].NoReply() && msg.string == "QUEUED" {
			panic(multiexecsub)
		} else if multi[ff].IsUnsub() && !isUnsubReply(&msg) {
			// See https://github.com/redis/rueidis/pull/691
			skipUnsubReply = true
		} else if skipUnsubReply {
			// See https://github.com/redis/rueidis/pull/691
			if !isUnsubReply(&msg) {
				panic(protocolbug)
			}
			skipUnsubReply = false
			continue
		}
		resp := newResult(msg, err)
		if resps != nil {
			resps[ff] = resp
		}
		if ff++; ff == len(multi) {
			ch <- resp
			cond.L.Unlock()
			cond.Signal()
		}
	}
}

func (p *pipe) backgroundPing() {
	var err error
	var prev, recv int32

	ticker := time.NewTicker(p.pinggap)
	defer ticker.Stop()
	for ; err == nil; prev = recv {
		select {
		case <-ticker.C:
			recv = atomic.LoadInt32(&p.recvs)
			if recv != prev || atomic.LoadInt32(&p.blcksig) != 0 || (atomic.LoadInt32(&p.state) == 0 && atomic.LoadInt32(&p.waits) != 0) {
				continue
			}
			ch := make(chan error, 1)
			tm := time.NewTimer(p.timeout)
			go func() {
				ch <- p.Do(context.Background(), cmds.PingCmd).NonValkeyError()
			}()
			select {
			case <-tm.C:
				err = os.ErrDeadlineExceeded
			case err = <-ch:
				tm.Stop()
			}
			if err != nil && atomic.LoadInt32(&p.blcksig) != 0 {
				err = nil
			}
		case <-p.close:
			return
		}
	}
	if err != ErrClosing {
		p._exit(err)
	}
}

func (p *pipe) handlePush(values []ValkeyMessage) (reply bool, unsubscribe bool) {
	if len(values) < 2 {
		return
	}
	// TODO: handle other push data
	// tracking-redir-broken
	// server-cpu-usage
	switch values[0].string {
	case "invalidate":
		if p.cache != nil {
			if values[1].IsNil() {
				p.cache.Delete(nil)
			} else {
				p.cache.Delete(values[1].values)
			}
		}
		if p.onInvalidations != nil {
			if values[1].IsNil() {
				p.onInvalidations(nil)
			} else {
				p.onInvalidations(values[1].values)
			}
		}
	case "message":
		if len(values) >= 3 {
			m := PubSubMessage{Channel: values[1].string, Message: values[2].string}
			p.nsubs.Publish(values[1].string, m)
			p.pshks.Load().(*pshks).hooks.OnMessage(m)
		}
	case "pmessage":
		if len(values) >= 4 {
			m := PubSubMessage{Pattern: values[1].string, Channel: values[2].string, Message: values[3].string}
			p.psubs.Publish(values[1].string, m)
			p.pshks.Load().(*pshks).hooks.OnMessage(m)
		}
	case "smessage":
		if len(values) >= 3 {
			m := PubSubMessage{Channel: values[1].string, Message: values[2].string}
			p.ssubs.Publish(values[1].string, m)
			p.pshks.Load().(*pshks).hooks.OnMessage(m)
		}
	case "unsubscribe":
		p.nsubs.Unsubscribe(values[1].string)
		if len(values) >= 3 {
			p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
		}
		return true, true
	case "punsubscribe":
		p.psubs.Unsubscribe(values[1].string)
		if len(values) >= 3 {
			p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
		}
		return true, true
	case "sunsubscribe":
		p.ssubs.Unsubscribe(values[1].string)
		if len(values) >= 3 {
			p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
		}
		return true, true
	case "subscribe", "psubscribe", "ssubscribe":
		if len(values) >= 3 {
			p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
		}
		return true, false
	}
	return false, false
}

func (p *pipe) _r2pipe() (r2p *pipe) {
	p.r2mu.Lock()
	if p.r2pipe != nil {
		r2p = p.r2pipe
	} else {
		var err error
		if r2p, err = p.r2psFn(); err != nil {
			r2p = epipeFn(err)
		} else {
			p.r2pipe = r2p
		}
	}
	p.r2mu.Unlock()
	return r2p
}

func (p *pipe) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
	if p.nsubs == nil || p.psubs == nil || p.ssubs == nil {
		return p.Error()
	}
	if p.version < 6 && p.r2psFn != nil {
		return p._r2pipe().Receive(ctx, subscribe, fn)
	}

	cmds.CompletedCS(subscribe).Verify()

	var sb *subs
	cmd, args := subscribe.Commands()[0], subscribe.Commands()[1:]

	switch cmd {
	case "SUBSCRIBE":
		sb = p.nsubs
	case "PSUBSCRIBE":
		sb = p.psubs
	case "SSUBSCRIBE":
		sb = p.ssubs
	default:
		panic(wrongreceive)
	}

	if ch, cancel := sb.Subscribe(args); ch != nil {
		defer cancel()
		if err := p.Do(ctx, subscribe).Error(); err != nil {
			return err
		}
		if ctxCh := ctx.Done(); ctxCh == nil {
			for msg := range ch {
				fn(msg)
			}
		} else {
		next:
			select {
			case msg, ok := <-ch:
				if ok {
					fn(msg)
					goto next
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	}
	return p.Error()
}

func (p *pipe) CleanSubscriptions() {
	if atomic.LoadInt32(&p.blcksig) != 0 {
		p.Close()
	} else if atomic.LoadInt32(&p.state) == 1 {
		if p.version >= 7 {
			p.DoMulti(context.Background(), cmds.UnsubscribeCmd, cmds.PUnsubscribeCmd, cmds.SUnsubscribeCmd, cmds.DiscardCmd)
		} else {
			p.DoMulti(context.Background(), cmds.UnsubscribeCmd, cmds.PUnsubscribeCmd, cmds.DiscardCmd)
		}
	}
}

func (p *pipe) SetPubSubHooks(hooks PubSubHooks) <-chan error {
	if p.version < 6 && p.r2psFn != nil {
		return p._r2pipe().SetPubSubHooks(hooks)
	}
	if hooks.isZero() {
		if old := p.pshks.Swap(emptypshks).(*pshks); old.close != nil {
			close(old.close)
		}
		return nil
	}
	if hooks.OnMessage == nil {
		hooks.OnMessage = func(m PubSubMessage) {}
	}
	if hooks.OnSubscription == nil {
		hooks.OnSubscription = func(s PubSubSubscription) {}
	}
	ch := make(chan error, 1)
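	// Swapping in the new hooks closes any previously registered hook channel,
	// so an earlier SetPubSubHooks caller sees its error channel closed.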
	if old := p.pshks.Swap(&pshks{hooks: hooks, close: ch}).(*pshks); old.close != nil {
		close(old.close)
	}
	if err := p.Error(); err != nil {
		if old := p.pshks.Swap(emptypshks).(*pshks); old.close != nil {
			old.close <- err
			close(old.close)
		}
	}
	if atomic.AddInt32(&p.waits, 1) == 1 && atomic.LoadInt32(&p.state) == 0 {
		p.background()
	}
	atomic.AddInt32(&p.waits, -1)
	return ch
}

func (p *pipe) SetOnCloseHook(fn func(error)) {
	p.clhks.Store(fn)
}

func (p *pipe) Info() map[string]ValkeyMessage {
	return p.info
}

func (p *pipe) Version() int {
	return int(p.version)
}

func (p *pipe) AZ() string {
	return p.info["availability_zone"].string
}

func (p *pipe) Do(ctx context.Context, cmd Completed) (resp ValkeyResult) {
	if err := ctx.Err(); err != nil {
		return newErrResult(err)
	}

	cmds.CompletedCS(cmd).Verify()

	if cmd.IsBlock() {
		atomic.AddInt32(&p.blcksig, 1)
		defer func() {
			if resp.err == nil {
				atomic.AddInt32(&p.blcksig, -1)
			}
		}()
	}

	if cmd.NoReply() {
		if p.version < 6 && p.r2psFn != nil {
			return p._r2pipe().Do(ctx, cmd)
		}
	}
	waits := atomic.AddInt32(&p.waits, 1) // if this is 1, and background worker is not started, no need to queue
	state := atomic.LoadInt32(&p.state)

	if state == 1 {
		goto queue
	}

	if state == 0 {
		if waits != 1 {
			goto queue
		}
		if cmd.NoReply() {
			p.background()
			goto queue
		}
		dl, ok := ctx.Deadline()
		if p.queue != nil && !ok && ctx.Done() != nil {
			p.background()
			goto queue
		}
		resp = p.syncDo(dl, ok, cmd)
	} else {
		resp = newErrResult(p.Error())
	}
	if left := atomic.AddInt32(&p.waits, -1); state == 0 && left != 0 {
		p.background()
	}
	atomic.AddInt32(&p.recvs, 1)
	return resp

queue:
	ch := p.queue.PutOne(cmd)
	if ctxCh := ctx.Done(); ctxCh == nil {
		resp = <-ch
	} else {
		select {
		case resp = <-ch:
		case <-ctxCh:
			goto abort
		}
	}
	atomic.AddInt32(&p.waits, -1)
	atomic.AddInt32(&p.recvs, 1)
	return resp

abort:
	go func(ch chan ValkeyResult) {
		<-ch
		atomic.AddInt32(&p.waits, -1)
		atomic.AddInt32(&p.recvs, 1)
	}(ch)
	return newErrResult(ctx.Err())
}

func (p *pipe) DoMulti(ctx context.Context, multi ...Completed) *valkeyresults {
	resp := resultsp.Get(len(multi), len(multi))
	if err := ctx.Err(); err != nil {
		for i := 0; i < len(resp.s); i++ {
			resp.s[i] = newErrResult(err)
		}
		return resp
	}

	cmds.CompletedCS(multi[0]).Verify()

	isOptIn := multi[0].IsOptIn() // len(multi) > 0 should have already been checked by upper layer
	noReply := 0

	for _, cmd := range multi {
		if cmd.NoReply() {
			noReply++
		}
	}

	if p.version < 6 && noReply != 0 {
		if noReply != len(multi) {
			for i := 0; i < len(resp.s); i++ {
				resp.s[i] = newErrResult(ErrRESP2PubSubMixed)
			}
			return resp
		} else if p.r2psFn != nil {
			resultsp.Put(resp)
			return p._r2pipe().DoMulti(ctx, multi...)
		}
	}

	for _, cmd := range multi {
		if cmd.IsBlock() {
			if noReply != 0 {
				for i := 0; i < len(resp.s); i++ {
					resp.s[i] = newErrResult(ErrBlockingPubSubMixed)
				}
				return resp
			}
			atomic.AddInt32(&p.blcksig, 1)
			defer func() {
				for _, r := range resp.s {
					if r.err != nil {
						return
					}
				}
				atomic.AddInt32(&p.blcksig, -1)
			}()
			break
		}
	}

	waits := atomic.AddInt32(&p.waits, 1) // if this is 1, and background worker is not started, no need to queue
	state := atomic.LoadInt32(&p.state)

	if state == 1 {
		goto queue
	}

	if state == 0 {
		if waits != 1 {
			goto queue
		}
		if isOptIn || noReply != 0 {
			p.background()
			goto queue
		}
		dl, ok := ctx.Deadline()
		if p.queue != nil && !ok && ctx.Done() != nil {
			p.background()
			goto queue
		}
		p.syncDoMulti(dl, ok, resp.s, multi)
	} else {
		err := newErrResult(p.Error())
		for i := 0; i < len(resp.s); i++ {
			resp.s[i] = err
		}
	}
	if left := atomic.AddInt32(&p.waits, -1); state == 0 && left != 0 {
		p.background()
	}
	atomic.AddInt32(&p.recvs, 1)
	return resp

queue:
	ch := p.queue.PutMulti(multi, resp.s)
	if ctxCh := ctx.Done(); ctxCh == nil {
		<-ch
	} else {
		select {
		case <-ch:
		case <-ctxCh:
			goto abort
		}
	}
	atomic.AddInt32(&p.waits, -1)
	atomic.AddInt32(&p.recvs, 1)
	return resp

abort:
	go func(resp *valkeyresults, ch chan ValkeyResult) {
		<-ch
		resultsp.Put(resp)
		atomic.AddInt32(&p.waits, -1)
		atomic.AddInt32(&p.recvs, 1)
	}(resp, ch)
	resp = resultsp.Get(len(multi), len(multi))
	err := newErrResult(ctx.Err())
	for i := 0; i < len(resp.s); i++ {
		resp.s[i] = err
	}
	return resp
}

type MultiValkeyResultStream = ValkeyResultStream

type ValkeyResultStream struct {
	p *pool
	w *pipe
	e error
	n int
}

// HasNext can be used in a for-loop condition to check if a further WriteTo call is needed.
func (s *ValkeyResultStream) HasNext() bool {
	return s.n > 0 && s.e == nil
}

// Error returns the error that occurred while sending commands to valkey or reading the response from valkey.
// Usually a user is not required to call this function because the error is also reported by WriteTo.
func (s *ValkeyResultStream) Error() error {
	return s.e
}

// WriteTo reads a response from valkey and then writes it to the given writer.
// This function is not thread safe and should be called sequentially to read multiple responses.
// An io.EOF error will be reported once all responses have been read.
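// The underlying connection is returned to the pool only once the stream has been
// fully drained, so keep calling WriteTo until HasNext reports false.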
func (s *ValkeyResultStream) WriteTo(w io.Writer) (n int64, err error) {
	if err = s.e; err == nil && s.n > 0 {
		var clean bool
		if n, err, clean = streamTo(s.w.r, w); !clean {
			s.e = err // err must not be nil in case of !clean
			s.n = 1
		}
		if s.n--; s.n == 0 {
			atomic.AddInt32(&s.w.blcksig, -1)
			atomic.AddInt32(&s.w.waits, -1)
			if s.e == nil {
				s.e = io.EOF
			} else {
				s.w.Close()
			}
			s.p.Store(s.w)
		}
	}
	return n, err
}

func (p *pipe) DoStream(ctx context.Context, pool *pool, cmd Completed) ValkeyResultStream {
	cmds.CompletedCS(cmd).Verify()

	if err := ctx.Err(); err != nil {
		return ValkeyResultStream{e: err}
	}

	state := atomic.LoadInt32(&p.state)

	if state == 1 {
		panic("DoStream with auto pipelining is a bug")
	}

	if state == 0 {
		atomic.AddInt32(&p.blcksig, 1)
		waits := atomic.AddInt32(&p.waits, 1)
		if waits != 1 {
			panic("DoStream with racing is a bug")
		}
		dl, ok := ctx.Deadline()
		if ok {
			if p.timeout > 0 && !cmd.IsBlock() {
				defaultDeadline := time.Now().Add(p.timeout)
				if dl.After(defaultDeadline) {
					dl = defaultDeadline
				}
			}
			p.conn.SetDeadline(dl)
		} else if p.timeout > 0 && !cmd.IsBlock() {
			p.conn.SetDeadline(time.Now().Add(p.timeout))
		} else {
			p.conn.SetDeadline(time.Time{})
		}
		_ = writeCmd(p.w, cmd.Commands())
		if err := p.w.Flush(); err != nil {
			p.error.CompareAndSwap(nil, &errs{error: err})
			p.conn.Close()
			p.background() // start the background worker to clean up goroutines
		} else {
			return ValkeyResultStream{p: pool, w: p, n: 1}
		}
	}
	atomic.AddInt32(&p.blcksig, -1)
	atomic.AddInt32(&p.waits, -1)
	pool.Store(p)
	return ValkeyResultStream{e: p.Error()}
}

func (p *pipe) DoMultiStream(ctx context.Context, pool *pool, multi ...Completed) MultiValkeyResultStream {
	for _, cmd := range multi {
		cmds.CompletedCS(cmd).Verify()
	}

	if err := ctx.Err(); err != nil {
		return ValkeyResultStream{e: err}
	}

	state := atomic.LoadInt32(&p.state)

	if state == 1 {
		panic("DoMultiStream with auto pipelining is a bug")
	}

	if state == 0 {
		atomic.AddInt32(&p.blcksig, 1)
		waits := atomic.AddInt32(&p.waits, 1)
		if waits != 1 {
			panic("DoMultiStream with racing is a bug")
		}
		dl, ok := ctx.Deadline()
		if ok {
			if p.timeout > 0 {
				for _, cmd := range multi {
					if cmd.IsBlock() {
						p.conn.SetDeadline(dl)
						goto process
					}
				}
				defaultDeadline := time.Now().Add(p.timeout)
				if dl.After(defaultDeadline) {
					dl = defaultDeadline
				}
			}
			p.conn.SetDeadline(dl)
		} else if p.timeout > 0 {
			for _, cmd := range multi {
				if cmd.IsBlock() {
					p.conn.SetDeadline(time.Time{})
					goto process
				}
			}
			p.conn.SetDeadline(time.Now().Add(p.timeout))
		} else {
			p.conn.SetDeadline(time.Time{})
		}
	process:
		for _, cmd := range multi {
			_ = writeCmd(p.w, cmd.Commands())
		}
		if err := p.w.Flush(); err != nil {
			p.error.CompareAndSwap(nil, &errs{error: err})
			p.conn.Close()
			p.background() // start the background worker to clean up goroutines
		} else {
			return ValkeyResultStream{p: pool, w: p, n: len(multi)}
		}
	}
	atomic.AddInt32(&p.blcksig, -1)
	atomic.AddInt32(&p.waits, -1)
	pool.Store(p)
	return ValkeyResultStream{e: p.Error()}
}

func (p *pipe) syncDo(dl time.Time, dlOk bool, cmd Completed) (resp ValkeyResult) {
	if dlOk {
		if p.timeout > 0 && !cmd.IsBlock() {
			defaultDeadline := time.Now().Add(p.timeout)
			if dl.After(defaultDeadline) {
				dl = defaultDeadline
				dlOk = false
			}
		}
		p.conn.SetDeadline(dl)
	} else if p.timeout > 0 && !cmd.IsBlock() {
		p.conn.SetDeadline(time.Now().Add(p.timeout))
	} else {
		p.conn.SetDeadline(time.Time{})
	}

	var msg ValkeyMessage
	err := flushCmd(p.w, cmd.Commands())
	if err == nil {
		msg, err = syncRead(p.r)
	}
	if err != nil {
		if dlOk && errors.Is(err, os.ErrDeadlineExceeded) {
			err = context.DeadlineExceeded
		}
		p.error.CompareAndSwap(nil, &errs{error: err})
		p.conn.Close()
		p.background() // start the background worker to clean up goroutines
	}
	return newResult(msg, err)
}

func (p *pipe) syncDoMulti(dl time.Time, dlOk bool, resp []ValkeyResult, multi []Completed) {
	if dlOk {
		if p.timeout > 0 {
			for _, cmd := range multi {
				if cmd.IsBlock() {
					p.conn.SetDeadline(dl)
					goto process
				}
			}
			defaultDeadline := time.Now().Add(p.timeout)
			if dl.After(defaultDeadline) {
				dl = defaultDeadline
				dlOk = false
			}
		}
		p.conn.SetDeadline(dl)
	} else if p.timeout > 0 {
		for _, cmd := range multi {
			if cmd.IsBlock() {
				p.conn.SetDeadline(time.Time{})
				goto process
			}
		}
		p.conn.SetDeadline(time.Now().Add(p.timeout))
	} else {
		p.conn.SetDeadline(time.Time{})
	}
process:
	var err error
	var msg ValkeyMessage
	for _, cmd := range multi {
		_ = writeCmd(p.w, cmd.Commands())
	}
	if err = p.w.Flush(); err != nil {
		goto abort
	}
	for i := 0; i < len(resp); i++ {
		if msg, err = syncRead(p.r); err != nil {
			goto abort
		}
		resp[i] = newResult(msg, err)
	}
	return
abort:
	if dlOk && errors.Is(err, os.ErrDeadlineExceeded) {
		err = context.DeadlineExceeded
	}
	p.error.CompareAndSwap(nil, &errs{error: err})
	p.conn.Close()
	p.background() // start the background worker to clean up goroutines
	for i := 0; i < len(resp); i++ {
		resp[i] = newErrResult(err)
	}
}

func syncRead(r *bufio.Reader) (m ValkeyMessage, err error) {
next:
	if m, err = readNextMessage(r); err != nil {
		return m, err
	}
	if m.typ == '>' {
		goto next
	}
	return m, nil
}

func (p *pipe) DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) ValkeyResult {
	if p.cache == nil {
		return p.Do(ctx, Completed(cmd))
	}

	cmds.CacheableCS(cmd).Verify()

	if cmd.IsMGet() {
		return p.doCacheMGet(ctx, cmd, ttl)
	}

	ck, cc := cmds.CacheKey(cmd)
	now := time.Now()
	if v, entry := p.cache.Flight(ck, cc, ttl, now); v.typ != 0 {
		return newResult(v, nil)
	} else if entry != nil {
		return newResult(entry.Wait(ctx))
	}
	resp := p.DoMulti(
		ctx,
		cmds.OptInCmd,
		cmds.MultiCmd,
		cmds.NewCompleted([]string{"PTTL", ck}),
		Completed(cmd),
		cmds.ExecCmd,
	)
	defer resultsp.Put(resp)
	exec, err := resp.s[4].ToArray()
	if err != nil {
		if _, ok := err.(*ValkeyError); ok {
			err = ErrDoCacheAborted
			if preErr := resp.s[3].Error(); preErr != nil { // if {cmd} gets a ValkeyError
				if _, ok := preErr.(*ValkeyError); ok {
					err = preErr
				}
			}
		}
		p.cache.Cancel(ck, cc, err)
		return newErrResult(err)
	}
	return newResult(exec[1], nil)
}

func (p *pipe) doCacheMGet(ctx context.Context, cmd Cacheable, ttl time.Duration) ValkeyResult {
	commands := cmd.Commands()
	keys := len(commands) - 1
	builder := cmds.NewBuilder(cmds.InitSlot)
	result := ValkeyResult{val: ValkeyMessage{typ: '*', values: nil}}
	mgetcc := cmds.MGetCacheCmd(cmd)
	if mgetcc[0] == 'J' {
		keys-- // the last argument of JSON.MGET is a path, not a key
	}
	entries := entriesp.Get(keys, keys)
	defer entriesp.Put(entries)
	var now = time.Now()
	var rewrite cmds.Arbitrary
	for i, key := range commands[1 : keys+1] {
		v, entry := p.cache.Flight(key, mgetcc, ttl, now)
		if v.typ != 0 { // cache hit for one key
			if len(result.val.values) == 0 {
				result.val.values = make([]ValkeyMessage, keys)
			}
			result.val.values[i] = v
			continue
		}
		if entry != nil {
			entries.e[i] = entry // store entries for later entry.Wait() to avoid MGET commands deadlocking each other
			continue
		}
		if rewrite.IsZero() {
			rewrite = builder.Arbitrary(commands[0])
		}
		rewrite = rewrite.Args(key)
	}

	var partial []ValkeyMessage
	if !rewrite.IsZero() {
		var rewritten Completed
		var keys int
		if mgetcc[0] == 'J' { // rewrite JSON.MGET path
			rewritten = rewrite.Args(commands[len(commands)-1]).MultiGet()
			keys = len(rewritten.Commands()) - 2
		} else {
			rewritten = rewrite.MultiGet()
			keys = len(rewritten.Commands()) - 1
		}

		multi := make([]Completed, 0, keys+4)
		multi = append(multi, cmds.OptInCmd, cmds.MultiCmd)
		for _, key := range rewritten.Commands()[1 : keys+1] {
			multi = append(multi, builder.Pttl().Key(key).Build())
		}
		multi = append(multi, rewritten, cmds.ExecCmd)

		resp := p.DoMulti(ctx, multi...)
		defer resultsp.Put(resp)
		exec, err := resp.s[len(multi)-1].ToArray()
		if err != nil {
			if _, ok := err.(*ValkeyError); ok {
				err = ErrDoCacheAborted
				if preErr := resp.s[len(multi)-2].Error(); preErr != nil { // if {rewritten} gets a ValkeyError
					if _, ok := preErr.(*ValkeyError); ok {
						err = preErr
					}
				}
			}
			for _, key := range rewritten.Commands()[1 : keys+1] {
				p.cache.Cancel(key, mgetcc, err)
			}
			return newErrResult(err)
		}
		defer func() {
			for _, cmd := range multi[2 : len(multi)-1] {
				cmds.PutCompleted(cmd)
			}
		}()
		last := len(exec) - 1
		if len(rewritten.Commands()) == len(commands) { // all cache miss
			return newResult(exec[last], nil)
		}
		partial = exec[last].values
	} else { // all cache hit
		result.val.attrs = cacheMark
	}

	if len(result.val.values) == 0 {
		result.val.values = make([]ValkeyMessage, keys)
	}
	for i, entry := range entries.e {
		v, err := entry.Wait(ctx)
		if err != nil {
			return newErrResult(err)
		}
		result.val.values[i] = v
	}

	j := 0
	for _, ret := range partial {
		for ; j < len(result.val.values); j++ {
			if result.val.values[j].typ == 0 {
				result.val.values[j] = ret
				break
			}
		}
	}
	return result
}

func (p *pipe) DoMultiCache(ctx context.Context, multi ...CacheableTTL) *valkeyresults {
	if p.cache == nil {
		commands := make([]Completed, len(multi))
		for i, ct := range multi {
			commands[i] = Completed(ct.Cmd)
		}
		return p.DoMulti(ctx, commands...)
	}

	cmds.CacheableCS(multi[0].Cmd).Verify()

	results := resultsp.Get(len(multi), len(multi))
	entries := entriesp.Get(len(multi), len(multi))
	defer entriesp.Put(entries)
	var missing []Completed
	now := time.Now()
	for _, ct := range multi {
		if ct.Cmd.IsMGet() {
			panic(panicmgetcsc)
		}
	}
	if cache, ok := p.cache.(*lru); ok {
		missed := cache.Flights(now, multi, results.s, entries.e)
		for _, i := range missed {
			ct := multi[i]
			ck, _ := cmds.CacheKey(ct.Cmd)
			missing = append(missing, cmds.OptInCmd, cmds.MultiCmd, cmds.NewCompleted([]string{"PTTL", ck}), Completed(ct.Cmd), cmds.ExecCmd)
		}
	} else {
		for i, ct := range multi {
			ck, cc := cmds.CacheKey(ct.Cmd)
			v, entry := p.cache.Flight(ck, cc, ct.TTL, now)
			if v.typ != 0 { // cache hit for one key
				results.s[i] = newResult(v, nil)
				continue
			}
			if entry != nil {
				entries.e[i] = entry // store entries for later entry.Wait() to avoid MGET commands deadlocking each other
				continue
			}
			missing = append(missing, cmds.OptInCmd, cmds.MultiCmd, cmds.NewCompleted([]string{"PTTL", ck}), Completed(ct.Cmd), cmds.ExecCmd)
		}
	}

	var resp *valkeyresults
	if len(missing) > 0 {
		resp = p.DoMulti(ctx, missing...)
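		// Each cache miss expands into five commands above (OPTIN, MULTI, PTTL,
		// the command itself, EXEC), so EXEC replies sit at indices 4, 9, 14, ...
		// and the wrapped command's reply is inspected at the index just before.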
		defer resultsp.Put(resp)
		for i := 4; i < len(resp.s); i += 5 {
			if err := resp.s[i].Error(); err != nil {
				if _, ok := err.(*ValkeyError); ok {
					err = ErrDoCacheAborted
					if preErr := resp.s[i-1].Error(); preErr != nil { // if {cmd} gets a ValkeyError
						if _, ok := preErr.(*ValkeyError); ok {
							err = preErr
						}
					}
				}
				ck, cc := cmds.CacheKey(Cacheable(missing[i-1]))
				p.cache.Cancel(ck, cc, err)
			}
		}
	}

	for i, entry := range entries.e {
		results.s[i] = newResult(entry.Wait(ctx))
	}

	if len(missing) == 0 {
		return results
	}

	j := 0
	for i := 4; i < len(resp.s); i += 5 {
		for ; j < len(results.s); j++ {
			if results.s[j].val.typ == 0 && results.s[j].err == nil {
				exec, err := resp.s[i].ToArray()
				if err != nil {
					if _, ok := err.(*ValkeyError); ok {
						err = ErrDoCacheAborted
						if preErr := resp.s[i-1].Error(); preErr != nil { // if {cmd} gets a ValkeyError
							if _, ok := preErr.(*ValkeyError); ok {
								err = preErr
							}
						}
					}
					results.s[j] = newErrResult(err)
				} else {
					results.s[j] = newResult(exec[len(exec)-1], nil)
				}
				break
			}
		}
	}
	return results
}

func (p *pipe) Error() error {
	if err, ok := p.error.Load().(*errs); ok {
		return err.error
	}
	return nil
}

func (p *pipe) Close() {
	p.error.CompareAndSwap(nil, errClosing)
	block := atomic.AddInt32(&p.blcksig, 1)
	waits := atomic.AddInt32(&p.waits, 1)
	stopping1 := atomic.CompareAndSwapInt32(&p.state, 0, 2)
	stopping2 := atomic.CompareAndSwapInt32(&p.state, 1, 2)
	if p.queue != nil {
		if stopping1 && waits == 1 { // make sure there is no sync read
			p.background()
		}
		if block == 1 && (stopping1 || stopping2) { // make sure there is no block cmd
			atomic.AddInt32(&p.waits, 1)
			ch := p.queue.PutOne(cmds.PingCmd)
			select {
			case <-ch:
				atomic.AddInt32(&p.waits, -1)
			case <-time.After(time.Second):
				go func(ch chan ValkeyResult) {
					<-ch
					atomic.AddInt32(&p.waits, -1)
				}(ch)
			}
		}
	}
	atomic.AddInt32(&p.waits, -1)
	atomic.AddInt32(&p.blcksig, -1)
	if p.conn != nil {
		p.conn.Close()
	}
	p.r2mu.Lock()
	if p.r2pipe != nil {
		p.r2pipe.Close()
	}
	p.r2mu.Unlock()
}

type pshks struct {
	hooks PubSubHooks
	close chan error
}

var emptypshks = &pshks{
	hooks: PubSubHooks{
		OnMessage:      func(m PubSubMessage) {},
		OnSubscription: func(s PubSubSubscription) {},
	},
	close: nil,
}

var emptyclhks = func(error) {}

func deadFn() *pipe {
	dead := &pipe{state: 3}
	dead.error.Store(errClosing)
	dead.pshks.Store(emptypshks)
	dead.clhks.Store(emptyclhks)
	return dead
}

func epipeFn(err error) *pipe {
	dead := &pipe{state: 3}
	dead.error.Store(&errs{error: err})
	dead.pshks.Store(emptypshks)
	dead.clhks.Store(emptyclhks)
	return dead
}

const (
	protocolbug  = "protocol bug, message handled out of order"
	wrongreceive = "only SUBSCRIBE, SSUBSCRIBE, or PSUBSCRIBE commands are allowed in Receive"
	multiexecsub = "SUBSCRIBE/UNSUBSCRIBE are not allowed in MULTI/EXEC block"
	panicmgetcsc = "MGET and JSON.MGET in DoMultiCache are not implemented, use DoCache instead"
)

var cacheMark = &(ValkeyMessage{})
var errClosing = &errs{error: ErrClosing}

type errs struct{ error }