Files
llama-swap/internal/router/peer.go
T
Benson Wong 0cfe5a6639 Makefile,internal: fix websocket regression and other small things (#830)
- fix websocket regression and add test to prevent in the future
- fix staticheck errors
- remove proxy package remnants from Makefile 

fix #829
2026-06-09 21:37:53 -07:00

188 lines
4.6 KiB
Go

package router
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mostlygeek/llama-swap/internal/config"
"github.com/mostlygeek/llama-swap/internal/logmon"
)
type peerMember struct {
peerID string
reverseProxy *httputil.ReverseProxy
apiKey string
}
type Peer struct {
cfg config.Config
logger *logmon.Monitor
peers map[string]*peerMember
shutdownCtx context.Context
shutdownFn context.CancelFunc
shuttingDown atomic.Bool
inflight sync.WaitGroup
}
func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
peers := cfg.Peers
modelMap := make(map[string]*peerMember)
peerIDs := make([]string, 0, len(peers))
for peerID := range peers {
peerIDs = append(peerIDs, peerID)
}
sort.Strings(peerIDs)
for _, peerID := range peerIDs {
peer := peers[peerID]
peerTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
}).DialContext,
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
}
reverseProxy := &httputil.ReverseProxy{
Transport: peerTransport,
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(peer.ProxyURL)
r.Out.Host = r.Out.URL.Host
},
}
reverseProxy.ModifyResponse = func(resp *http.Response) error {
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
resp.Header.Set("X-Accel-Buffering", "no")
}
return nil
}
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
logger.Warnf("peer %s: proxy error: %v", peerID, err)
errMsg := fmt.Sprintf("peer proxy error: %v", err)
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
}
http.Error(w, errMsg, http.StatusBadGateway)
}
pp := &peerMember{
peerID: peerID,
reverseProxy: reverseProxy,
apiKey: peer.ApiKey,
}
for _, modelID := range peer.Models {
if _, found := modelMap[modelID]; found {
logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
continue
}
modelMap[modelID] = pp
}
}
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
return &Peer{
cfg: cfg,
logger: logger,
peers: modelMap,
shutdownCtx: shutdownCtx,
shutdownFn: shutdownFn,
}, nil
}
func (r *Peer) Handles(model string) bool {
_, ok := r.peers[model]
return ok
}
func (r *Peer) Shutdown(timeout time.Duration) error {
if !r.shuttingDown.CompareAndSwap(false, true) {
return fmt.Errorf("shutdown already in progress")
}
if timeout == 0 {
r.shutdownFn()
r.inflight.Wait()
return nil
}
done := make(chan struct{})
go func() {
r.inflight.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-time.After(timeout):
r.shutdownFn()
r.inflight.Wait()
return fmt.Errorf("peer shutdown timed out after %v", timeout)
}
}
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.shuttingDown.Load() {
SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
return
}
r.inflight.Add(1)
defer r.inflight.Done()
data, err := FetchContext(req, r.cfg)
if err != nil {
SendError(w, req, err)
return
}
pp, found := r.peers[data.ModelID]
if !found {
r.logger.Warnf("peer model not found: %s", data.ModelID)
SendError(w, req, ErrNoPeerModelFound)
return
}
r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID)
if pp.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+pp.apiKey)
req.Header.Set("x-api-key", pp.apiKey)
}
// Cancel the proxy request when the client disconnects or shutdown times out.
// AfterFunc links both parent contexts to our child without a goroutine leak.
ctx, cancel := context.WithCancel(context.Background())
stopReq := context.AfterFunc(req.Context(), cancel)
stopShutdown := context.AfterFunc(r.shutdownCtx, cancel)
req = req.WithContext(ctx)
pp.reverseProxy.ServeHTTP(w, req)
stopShutdown()
stopReq()
cancel()
}