Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fd50932dbc | |||
| 8c693e7fcf |
@@ -18,7 +18,7 @@ Written in golang, it is very easy to install (single binary with no dependencie
|
||||
- `v1/completions`
|
||||
- `v1/chat/completions`
|
||||
- `v1/embeddings`
|
||||
- `v1/rerank`
|
||||
- `v1/rerank`, `v1/reranking`, `rerank`
|
||||
- `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36))
|
||||
- `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867))
|
||||
- ✅ llama-swap custom API endpoints
|
||||
|
||||
@@ -17,6 +17,7 @@ func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
@@ -24,15 +25,16 @@ func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set("ls-real-model-name", realModelName)
|
||||
|
||||
writer := &MetricsResponseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
|
||||
+14
-5
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -160,8 +161,10 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
||||
|
||||
// Support embeddings
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
|
||||
|
||||
// Support audio/speech endpoint
|
||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||
@@ -365,9 +368,15 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
realModelName := c.GetString("ls-real-model-name") // Should be set in MetricsMiddleware
|
||||
if realModelName == "" {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, "ls-real-model-name not set")
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user