62aea0e83d
- refactor shared http functionality into internal/shared/http.go - remove stripping of Authorization and x-api-key - add Request Context middleware to internal/server - add /ui and /metrics behind auth middleware, fixes #717 Fix #717 Updates: #834
189 lines
4.7 KiB
Go
189 lines
4.7 KiB
Go
package router
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"runtime"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/mostlygeek/llama-swap/internal/config"
|
|
"github.com/mostlygeek/llama-swap/internal/logmon"
|
|
"github.com/mostlygeek/llama-swap/internal/shared"
|
|
)
|
|
|
|
type peerMember struct {
|
|
peerID string
|
|
reverseProxy *httputil.ReverseProxy
|
|
apiKey string
|
|
}
|
|
|
|
type Peer struct {
|
|
cfg config.Config
|
|
logger *logmon.Monitor
|
|
peers map[string]*peerMember
|
|
|
|
shutdownCtx context.Context
|
|
shutdownFn context.CancelFunc
|
|
shuttingDown atomic.Bool
|
|
inflight sync.WaitGroup
|
|
}
|
|
|
|
func NewPeer(cfg config.Config, logger *logmon.Monitor) (*Peer, error) {
|
|
peers := cfg.Peers
|
|
modelMap := make(map[string]*peerMember)
|
|
|
|
peerIDs := make([]string, 0, len(peers))
|
|
for peerID := range peers {
|
|
peerIDs = append(peerIDs, peerID)
|
|
}
|
|
sort.Strings(peerIDs)
|
|
|
|
for _, peerID := range peerIDs {
|
|
peer := peers[peerID]
|
|
|
|
peerTransport := &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
DialContext: (&net.Dialer{
|
|
Timeout: time.Duration(peer.Timeouts.Connect) * time.Second,
|
|
KeepAlive: time.Duration(peer.Timeouts.KeepAlive) * time.Second,
|
|
}).DialContext,
|
|
TLSHandshakeTimeout: time.Duration(peer.Timeouts.TLSHandshake) * time.Second,
|
|
ResponseHeaderTimeout: time.Duration(peer.Timeouts.ResponseHeader) * time.Second,
|
|
ExpectContinueTimeout: time.Duration(peer.Timeouts.ExpectContinue) * time.Second,
|
|
ForceAttemptHTTP2: true,
|
|
MaxIdleConns: 100,
|
|
MaxIdleConnsPerHost: 10,
|
|
IdleConnTimeout: time.Duration(peer.Timeouts.IdleConn) * time.Second,
|
|
}
|
|
|
|
reverseProxy := &httputil.ReverseProxy{
|
|
Transport: peerTransport,
|
|
Rewrite: func(r *httputil.ProxyRequest) {
|
|
r.SetURL(peer.ProxyURL)
|
|
r.Out.Host = r.Out.URL.Host
|
|
},
|
|
}
|
|
|
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
|
resp.Header.Set("X-Accel-Buffering", "no")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
reverseProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
|
logger.Warnf("peer %s: proxy error: %v", peerID, err)
|
|
errMsg := fmt.Sprintf("peer proxy error: %v", err)
|
|
if runtime.GOOS == "darwin" && strings.Contains(err.Error(), "connect: no route to host") {
|
|
errMsg += " (hint: on macOS, check System Settings > Privacy & Security > Local Network permissions)"
|
|
}
|
|
http.Error(w, errMsg, http.StatusBadGateway)
|
|
}
|
|
|
|
pp := &peerMember{
|
|
peerID: peerID,
|
|
reverseProxy: reverseProxy,
|
|
apiKey: peer.ApiKey,
|
|
}
|
|
|
|
for _, modelID := range peer.Models {
|
|
if _, found := modelMap[modelID]; found {
|
|
logger.Warnf("peer %s: model %s already mapped to another peer, skipping", peerID, modelID)
|
|
continue
|
|
}
|
|
modelMap[modelID] = pp
|
|
}
|
|
}
|
|
|
|
shutdownCtx, shutdownFn := context.WithCancel(context.Background())
|
|
|
|
return &Peer{
|
|
cfg: cfg,
|
|
logger: logger,
|
|
peers: modelMap,
|
|
shutdownCtx: shutdownCtx,
|
|
shutdownFn: shutdownFn,
|
|
}, nil
|
|
}
|
|
|
|
func (r *Peer) Handles(model string) bool {
|
|
_, ok := r.peers[model]
|
|
return ok
|
|
}
|
|
|
|
func (r *Peer) Shutdown(timeout time.Duration) error {
|
|
if !r.shuttingDown.CompareAndSwap(false, true) {
|
|
return fmt.Errorf("shutdown already in progress")
|
|
}
|
|
|
|
if timeout == 0 {
|
|
r.shutdownFn()
|
|
r.inflight.Wait()
|
|
return nil
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
r.inflight.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
return nil
|
|
case <-time.After(timeout):
|
|
r.shutdownFn()
|
|
r.inflight.Wait()
|
|
return fmt.Errorf("peer shutdown timed out after %v", timeout)
|
|
}
|
|
}
|
|
|
|
func (r *Peer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
if r.shuttingDown.Load() {
|
|
shared.SendError(w, req, fmt.Errorf("peer proxy is shutting down"))
|
|
return
|
|
}
|
|
r.inflight.Add(1)
|
|
defer r.inflight.Done()
|
|
|
|
data, err := shared.FetchContext(req, r.cfg)
|
|
if err != nil {
|
|
shared.SendError(w, req, err)
|
|
return
|
|
}
|
|
|
|
pp, found := r.peers[data.ModelID]
|
|
if !found {
|
|
r.logger.Warnf("peer model not found: %s", data.ModelID)
|
|
shared.SendError(w, req, ErrNoPeerModelFound)
|
|
return
|
|
}
|
|
|
|
r.logger.Debugf("peer: routing model %s to peer %s", data.ModelID, pp.peerID)
|
|
|
|
if pp.apiKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+pp.apiKey)
|
|
req.Header.Set("x-api-key", pp.apiKey)
|
|
}
|
|
|
|
// Cancel the proxy request when the client disconnects or shutdown times out.
|
|
// AfterFunc links both parent contexts to our child without a goroutine leak.
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
stopReq := context.AfterFunc(req.Context(), cancel)
|
|
stopShutdown := context.AfterFunc(r.shutdownCtx, cancel)
|
|
req = req.WithContext(ctx)
|
|
|
|
pp.reverseProxy.ServeHTTP(w, req)
|
|
|
|
stopShutdown()
|
|
stopReq()
|
|
cancel()
|
|
}
|