Model capabilities 734 (#842)

internal/config,server: implement model capabilities

- define the capabilities of a model using a simple config block on the
model
- v1/models renders out capabilities to be compatible with openrouter,
huggingface chat, and mistral formats for broader compatibility
- add support for capabilities in UI

Fixes #734
This commit is contained in:
Benson Wong
2026-06-13 23:23:19 -07:00
committed by GitHub
parent 62aea0e83d
commit 92b90447e8
16 changed files with 868 additions and 35 deletions
+54 -1
View File
@@ -378,6 +378,59 @@
}, },
"timeouts": { "timeouts": {
"$ref": "#/definitions/timeouts" "$ref": "#/definitions/timeouts"
},
"capabilities": {
"type": "object",
"properties": {
"in": {
"type": "array",
"minItems": 1,
"uniqueItems": true,
"default": [],
"items": {
"type": "string",
"enum": [
"text",
"audio",
"image"
]
},
"description": "List of input modalities understood by the model."
},
"out": {
"type": "array",
"minItems": 1,
"uniqueItems": true,
"default": [],
"items": {
"type": "string",
"enum": [
"text",
"audio",
"image"
]
},
"description": "List of output modalities generated by the model."
},
"tools": {
"type": "boolean",
"default": false,
"description": "Whether the model supports function calling."
},
"reranker": {
"type": "boolean",
"default": false,
"description": "Whether the model supports the /v1/rerank endpoint."
},
"context": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Maximum token context length supported by the model."
}
},
"additionalProperties": false,
"description": "Defines what the model accepts for input, output and other metadata. Used in v1/models to inform clients what the model can do. An empty capabilities block (all zero values) is treated as not configured."
} }
} }
} }
@@ -619,4 +672,4 @@
} }
} }
] ]
} }
+31
View File
@@ -312,6 +312,37 @@ models:
tlsHandshake: 10 tlsHandshake: 10
idleConn: 90 idleConn: 90
# capabilities: defines what the model accepts for input, output and other metadata
# - optional; omitted or all-zero means no capabilities
# - used in v1/models to inform clients what the model can do
capabilities:
# in: list of modalities understood by the model
# - default: []
# - valid: text, audio, image
in:
- text
- audio
- image
# out: list of modalities generated by the model
# - default: []
# - valid: text, audio, image
out:
- text
- audio
- image
# tools: the model supports function calling
# - default: false
tools: true
# reranker: the model supports the /v1/rerank endpoint
# - default: false
reranker: false
# context: the maximum token context length supported
# - default: 0
# - must be an integer > 0
context: 32000
# Unlisted model example: # Unlisted model example:
"qwen-unlisted": "qwen-unlisted":
# unlisted: boolean, true or false # unlisted: boolean, true or false
+4
View File
@@ -447,6 +447,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
} }
} }
if err = modelConfig.Capabilities.Validate(); err != nil {
return Config{}, fmt.Errorf("model %s: %w", modelId, err)
}
// Validate SetParamsByID keys and values // Validate SetParamsByID keys and values
for key, paramMap := range modelConfig.Filters.SetParamsByID { for key, paramMap := range modelConfig.Filters.SetParamsByID {
if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 { if matches := macroPatternRegex.FindAllStringSubmatch(key, -1); len(matches) > 0 {
+45
View File
@@ -2,6 +2,7 @@ package config
import ( import (
"errors" "errors"
"fmt"
"runtime" "runtime"
) )
@@ -9,6 +10,47 @@ const (
MODEL_CONFIG_DEFAULT_TTL = -1 MODEL_CONFIG_DEFAULT_TTL = -1
) )
var validModalities = map[string]struct{}{
"text": {},
"audio": {},
"image": {},
}
// ModelCapConfig defines what modalities and features a model supports.
// Used in /v1/models to inform clients. An empty block (all zero values) is
// treated as not configured.
type ModelCapConfig struct {
In []string `yaml:"in"`
Out []string `yaml:"out"`
Tools bool `yaml:"tools"`
Reranker bool `yaml:"reranker"`
Context int `yaml:"context"`
}
// Empty returns true when all fields are at their zero values.
func (c ModelCapConfig) Empty() bool {
return len(c.In) == 0 && len(c.Out) == 0 && !c.Tools && !c.Reranker && c.Context == 0
}
// Validate checks that all modality values are recognized and context is
// non-negative. Returns an error if any value is invalid.
func (c ModelCapConfig) Validate() error {
for _, m := range c.In {
if _, ok := validModalities[m]; !ok {
return fmt.Errorf("capabilities.in: invalid modality %q, must be one of: text, audio, image", m)
}
}
for _, m := range c.Out {
if _, ok := validModalities[m]; !ok {
return fmt.Errorf("capabilities.out: invalid modality %q, must be one of: text, audio, image", m)
}
}
if c.Context < 0 {
return errors.New("capabilities.context: must be >= 0")
}
return nil
}
// TimeoutsConfig holds timeout settings for proxy connections // TimeoutsConfig holds timeout settings for proxy connections
// 0 = no timeout // 0 = no timeout
type TimeoutsConfig struct { type TimeoutsConfig struct {
@@ -55,6 +97,9 @@ type ModelConfig struct {
// Timeout settings for proxy connections // Timeout settings for proxy connections
Timeouts TimeoutsConfig `yaml:"timeouts"` Timeouts TimeoutsConfig `yaml:"timeouts"`
// Capabilities defines what modalities and features the model supports.
Capabilities ModelCapConfig `yaml:"capabilities"`
// Copy of HealthCheckTimeout from global config // Copy of HealthCheckTimeout from global config
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
} }
+165 -1
View File
@@ -152,7 +152,7 @@ models:
stop: stop:
- "<|end|>" - "<|end|>"
- "<|stop|>" - "<|stop|>"
` `
config, err := LoadConfigFromReader(strings.NewReader(content)) config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err) assert.NoError(t, err)
@@ -170,3 +170,167 @@ models:
assert.Equal(t, 0.7, setParams["temperature"]) assert.Equal(t, 0.7, setParams["temperature"])
assert.Equal(t, 0.9, setParams["top_p"]) assert.Equal(t, 0.9, setParams["top_p"])
} }
func TestConfig_ModelCapabilities(t *testing.T) {
t.Run("all fields", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
in:
- text
- audio
- image
out:
- text
- audio
- image
tools: true
context: 32000
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.False(t, mc.Capabilities.Empty())
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.In)
assert.Equal(t, []string{"text", "audio", "image"}, mc.Capabilities.Out)
assert.True(t, mc.Capabilities.Tools)
assert.Equal(t, 32000, mc.Capabilities.Context)
})
t.Run("partial fields", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
tools: true
context: 8192
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.False(t, mc.Capabilities.Empty())
assert.Nil(t, mc.Capabilities.In)
assert.Nil(t, mc.Capabilities.Out)
assert.True(t, mc.Capabilities.Tools)
assert.Equal(t, 8192, mc.Capabilities.Context)
})
t.Run("not set", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.True(t, mc.Capabilities.Empty())
})
t.Run("tools false is empty", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
tools: false
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.True(t, mc.Capabilities.Empty())
})
t.Run("reranker true is not empty", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
reranker: true
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.False(t, mc.Capabilities.Empty())
assert.True(t, mc.Capabilities.Reranker)
})
t.Run("reranker false is empty", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
reranker: false
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)
mc := config.Models["model1"]
assert.True(t, mc.Capabilities.Empty())
})
}
func TestConfig_ModelCapabilities_Validate(t *testing.T) {
t.Run("valid_modalities", func(t *testing.T) {
caps := ModelCapConfig{
In: []string{"text", "image"},
Out: []string{"text", "audio"},
Tools: true,
Context: 100000,
}
assert.NoError(t, caps.Validate())
})
t.Run("empty_is_valid", func(t *testing.T) {
caps := ModelCapConfig{}
assert.NoError(t, caps.Validate())
})
t.Run("invalid_in_modality", func(t *testing.T) {
caps := ModelCapConfig{In: []string{"video"}}
err := caps.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "capabilities.in")
assert.Contains(t, err.Error(), "video")
})
t.Run("invalid_out_modality", func(t *testing.T) {
caps := ModelCapConfig{Out: []string{"video"}}
err := caps.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "capabilities.out")
assert.Contains(t, err.Error(), "video")
})
t.Run("negative_context", func(t *testing.T) {
caps := ModelCapConfig{Context: -1}
err := caps.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "capabilities.context")
})
t.Run("rejects_invalid_at_load", func(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
capabilities:
in:
- text
- video
`
_, err := LoadConfigFromReader(strings.NewReader(content))
assert.Error(t, err)
assert.Contains(t, err.Error(), "video")
})
}
+120 -11
View File
@@ -17,13 +17,118 @@ const apiUnloadTimeout = 10 * time.Second
// modelRecord is one entry in the OpenAI-compatible /v1/models listing. // modelRecord is one entry in the OpenAI-compatible /v1/models listing.
type modelRecord struct { type modelRecord struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
OwnedBy string `json:"owned_by"` OwnedBy string `json:"owned_by"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Meta map[string]any `json:"meta,omitempty"` Architecture map[string]any `json:"architecture,omitempty"`
Capabilities map[string]any `json:"capabilities,omitempty"`
SupportedParameters []string `json:"supported_parameters,omitempty"`
ContextLength int `json:"context_length,omitempty"`
Meta map[string]any `json:"meta,omitempty"`
}
// cappedMetadataKeys are top-level /v1/models fields produced by the
// capabilities renderer. If a model's metadata block defines any of these
// keys, the renderer's values win and the metadata keys are dropped.
var cappedMetadataKeys = map[string]struct{}{
"architecture": {},
"capabilities": {},
"supported_parameters": {},
"context_length": {},
}
// renderCapabilities converts a model's capabilities config into additional
// /v1/models fields. Returns zero values when caps.Empty() is true.
func renderCapabilities(caps config.ModelCapConfig) (arch map[string]any, capsMap map[string]any, params []string, ctxLen int) {
if caps.Empty() {
return
}
hasIn := len(caps.In) > 0
hasOut := len(caps.Out) > 0
if hasIn || hasOut {
arch = make(map[string]any)
}
if hasIn {
arch["input_modalities"] = caps.In
}
if hasOut {
arch["output_modalities"] = caps.Out
}
if hasIn && hasOut {
arch["modality"] = strings.Join(caps.In, "+") + "->" + strings.Join(caps.Out, "+")
}
// Build capabilities map only if there's something to put in it.
if hasIn || hasOut || caps.Tools || caps.Reranker {
capsMap = make(map[string]any)
}
if hasIn {
if contains(caps.In, "image") {
capsMap["vision"] = true
}
}
if hasIn && hasOut {
if contains(caps.In, "audio") && contains(caps.Out, "text") {
capsMap["audio_transcriptions"] = true
}
if contains(caps.In, "text") && contains(caps.Out, "audio") {
capsMap["audio_speech"] = true
}
if contains(caps.In, "text") && contains(caps.Out, "image") {
capsMap["image_generation"] = true
}
if contains(caps.In, "image") && contains(caps.Out, "image") {
capsMap["image_to_image"] = true
}
}
if caps.Tools {
capsMap["function_calling"] = true
params = []string{"tools", "tool_choice"}
}
if caps.Reranker {
capsMap["reranker"] = true
}
if caps.Context > 0 {
ctxLen = caps.Context
}
return
}
// contains reports whether s is present in ss.
func contains(ss []string, s string) bool {
for _, v := range ss {
if v == s {
return true
}
}
return false
}
// filterCappedMetadata returns metadata with renderer-owned keys removed.
func filterCappedMetadata(md map[string]any) map[string]any {
if len(md) == 0 {
return nil
}
filtered := make(map[string]any, len(md))
for k, v := range md {
if _, capped := cappedMetadataKeys[k]; !capped {
filtered[k] = v
}
}
if len(filtered) == 0 {
return nil
}
return filtered
} }
// handleListModels serves the OpenAI-compatible model listing: local models // handleListModels serves the OpenAI-compatible model listing: local models
@@ -32,7 +137,7 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
created := time.Now().Unix() created := time.Now().Unix()
data := make([]modelRecord, 0, len(s.cfg.Models)) data := make([]modelRecord, 0, len(s.cfg.Models))
newRecord := func(id, name, description string, metadata map[string]any) modelRecord { newRecord := func(id, name, description string, metadata map[string]any, caps config.ModelCapConfig) modelRecord {
rec := modelRecord{ rec := modelRecord{
ID: id, ID: id,
Object: "model", Object: "model",
@@ -41,6 +146,10 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
Name: strings.TrimSpace(name), Name: strings.TrimSpace(name),
Description: strings.TrimSpace(description), Description: strings.TrimSpace(description),
} }
rec.Architecture, rec.Capabilities, rec.SupportedParameters, rec.ContextLength = renderCapabilities(caps)
if !caps.Empty() {
metadata = filterCappedMetadata(metadata)
}
if len(metadata) > 0 { if len(metadata) > 0 {
rec.Meta = map[string]any{"llamaswap": metadata} rec.Meta = map[string]any{"llamaswap": metadata}
} }
@@ -51,12 +160,12 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
if mc.Unlisted { if mc.Unlisted {
continue continue
} }
data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata)) data = append(data, newRecord(id, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
if s.cfg.IncludeAliasesInList { if s.cfg.IncludeAliasesInList {
for _, alias := range mc.Aliases { for _, alias := range mc.Aliases {
if alias := strings.TrimSpace(alias); alias != "" { if alias := strings.TrimSpace(alias); alias != "" {
data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata)) data = append(data, newRecord(alias, mc.Name, mc.Description, mc.Metadata, mc.Capabilities))
} }
} }
} }
@@ -64,7 +173,7 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
for peerID, peer := range s.cfg.Peers { for peerID, peer := range s.cfg.Peers {
for _, modelID := range peer.Models { for _, modelID := range peer.Models {
data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID})) data = append(data, newRecord(modelID, peerID+": "+modelID, "", map[string]any{"peerID": peerID}, config.ModelCapConfig{}))
} }
} }
+259
View File
@@ -157,3 +157,262 @@ func TestServer_Redirects(t *testing.T) {
} }
} }
} }
func TestServer_HandleListModels_Capabilities(t *testing.T) {
newServer := func(mc config.ModelConfig) *Server {
s := newTestServer(newStubRouter(nil, ""), newStubRouter(nil, ""))
s.cfg = config.Config{Models: map[string]config.ModelConfig{"m": mc}}
return s
}
getModel := func(t *testing.T, s *Server) modelRecord {
t.Helper()
w := httptest.NewRecorder()
s.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/v1/models", nil))
var resp struct {
Data []modelRecord `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode: %v", err)
}
if len(resp.Data) != 1 {
t.Fatalf("expected 1 model, got %d", len(resp.Data))
}
return resp.Data[0]
}
t.Run("all_fields", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{
In: []string{"text", "image"},
Out: []string{"text", "audio"},
Tools: true,
Context: 100000,
},
}))
if m.Architecture == nil {
t.Fatal("architecture is nil")
}
if !anySliceStrEqual(m.Architecture["input_modalities"], []string{"text", "image"}) {
t.Errorf("input_modalities = %v", m.Architecture["input_modalities"])
}
if !anySliceStrEqual(m.Architecture["output_modalities"], []string{"text", "audio"}) {
t.Errorf("output_modalities = %v", m.Architecture["output_modalities"])
}
if m.Architecture["modality"] != "text+image->text+audio" {
t.Errorf("modality = %v", m.Architecture["modality"])
}
if m.Capabilities == nil || m.Capabilities["vision"] != true {
t.Errorf("vision = %v", m.Capabilities)
}
if m.Capabilities["audio_speech"] != true {
t.Errorf("audio_speech = %v", m.Capabilities["audio_speech"])
}
if m.Capabilities["function_calling"] != true {
t.Errorf("function_calling = %v", m.Capabilities["function_calling"])
}
if !stringSliceEqual(m.SupportedParameters, []string{"tools", "tool_choice"}) {
t.Errorf("supported_parameters = %v", m.SupportedParameters)
}
if m.ContextLength != 100000 {
t.Errorf("context_length = %d", m.ContextLength)
}
})
t.Run("in_only", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{In: []string{"text", "image"}},
}))
if m.Architecture == nil {
t.Fatal("architecture is nil")
}
if _, ok := m.Architecture["output_modalities"]; ok {
t.Error("should not have output_modalities")
}
if _, ok := m.Architecture["modality"]; ok {
t.Error("should not have modality")
}
if m.Capabilities == nil || m.Capabilities["vision"] != true {
t.Error("expected vision: true")
}
if m.SupportedParameters != nil {
t.Error("should not have supported_parameters")
}
if m.ContextLength != 0 {
t.Error("should not have context_length")
}
})
t.Run("out_only", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{Out: []string{"audio"}},
}))
if m.Architecture == nil {
t.Fatal("architecture is nil")
}
if _, ok := m.Architecture["input_modalities"]; ok {
t.Error("should not have input_modalities")
}
if len(m.Capabilities) > 0 {
t.Errorf("expected no capabilities, got %v", m.Capabilities)
}
})
t.Run("tools", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{Tools: true},
}))
if m.Capabilities == nil || m.Capabilities["function_calling"] != true {
t.Error("expected function_calling: true")
}
if !stringSliceEqual(m.SupportedParameters, []string{"tools", "tool_choice"}) {
t.Errorf("supported_parameters = %v", m.SupportedParameters)
}
if m.Architecture != nil {
t.Error("should not have architecture")
}
})
t.Run("reranker", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{Reranker: true},
}))
if m.Capabilities == nil || m.Capabilities["reranker"] != true {
t.Error("expected reranker: true")
}
if m.Architecture != nil {
t.Error("should not have architecture")
}
})
t.Run("context", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{Context: 32768},
}))
if m.ContextLength != 32768 {
t.Errorf("context_length = %d", m.ContextLength)
}
if m.Architecture != nil {
t.Error("should not have architecture")
}
})
t.Run("audio_transcriptions", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{In: []string{"audio"}, Out: []string{"text"}},
}))
if m.Capabilities == nil || m.Capabilities["audio_transcriptions"] != true {
t.Error("expected audio_transcriptions: true")
}
})
t.Run("image_generation", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{In: []string{"text"}, Out: []string{"image"}},
}))
if m.Capabilities == nil || m.Capabilities["image_generation"] != true {
t.Error("expected image_generation: true")
}
})
t.Run("image_to_image", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{In: []string{"image"}, Out: []string{"image"}},
}))
if m.Capabilities == nil || m.Capabilities["image_to_image"] != true {
t.Error("expected image_to_image: true")
}
})
t.Run("empty_skip", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{}))
if m.Architecture != nil {
t.Error("should not have architecture")
}
if m.Capabilities != nil {
t.Error("should not have capabilities")
}
if m.SupportedParameters != nil {
t.Error("should not have supported_parameters")
}
if m.ContextLength != 0 {
t.Error("should not have context_length")
}
})
t.Run("metadata_precedence", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Capabilities: config.ModelCapConfig{In: []string{"text"}},
Metadata: map[string]any{
"architecture": "should-be-dropped",
"custom_field": "should-remain",
"capabilities": "also-dropped",
"other_metadata": "also-remain",
},
}))
if m.Architecture == nil || m.Architecture["input_modalities"] == nil {
t.Fatal("architecture should be rendered, not from metadata")
}
if m.Meta == nil || m.Meta["llamaswap"] == nil {
t.Fatal("meta.llamaswap should exist")
}
meta := m.Meta["llamaswap"].(map[string]any)
if _, ok := meta["architecture"]; ok {
t.Error("architecture should be filtered from metadata")
}
if _, ok := meta["custom_field"]; !ok {
t.Error("custom_field should remain in metadata")
}
})
t.Run("metadata_passthrough_no_caps", func(t *testing.T) {
m := getModel(t, newServer(config.ModelConfig{
Metadata: map[string]any{
"architecture": "preserved",
"context_length": 4096,
"capabilities": "preserved",
"custom_field": "preserved",
},
}))
if m.Architecture != nil {
t.Error("should not have architecture when caps is empty")
}
if m.Meta == nil || m.Meta["llamaswap"] == nil {
t.Fatal("meta.llamaswap should exist")
}
meta := m.Meta["llamaswap"].(map[string]any)
if _, ok := meta["architecture"]; !ok {
t.Error("architecture should be preserved in metadata when caps is empty")
}
if _, ok := meta["context_length"]; !ok {
t.Error("context_length should be preserved in metadata when caps is empty")
}
})
}
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func anySliceStrEqual(v any, want []string) bool {
arr, ok := v.([]any)
if !ok {
return false
}
if len(arr) != len(want) {
return false
}
for i := range arr {
if s, ok := arr[i].(string); !ok || s != want[i] {
return false
}
}
return true
}
+16 -13
View File
@@ -17,13 +17,14 @@ import (
// apiModel is one entry in the /api/events modelStatus payload. // apiModel is one entry in the /api/events modelStatus payload.
type apiModel struct { type apiModel struct {
Id string `json:"id"` Id string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
State string `json:"state"` State string `json:"state"`
Unlisted bool `json:"unlisted"` Unlisted bool `json:"unlisted"`
PeerID string `json:"peerID"` PeerID string `json:"peerID"`
Aliases []string `json:"aliases,omitempty"` Aliases []string `json:"aliases,omitempty"`
Capabilities map[string]any `json:"capabilities,omitempty"`
} }
// modelStatus returns every configured model joined with its current process // modelStatus returns every configured model joined with its current process
@@ -44,13 +45,15 @@ func (s *Server) modelStatus() []apiModel {
if st, ok := running[id]; ok { if st, ok := running[id]; ok {
state = string(st) state = string(st)
} }
_, capsMap, _, _ := renderCapabilities(mc.Capabilities)
models = append(models, apiModel{ models = append(models, apiModel{
Id: id, Id: id,
Name: mc.Name, Name: mc.Name,
Description: mc.Description, Description: mc.Description,
State: state, State: state,
Unlisted: mc.Unlisted, Unlisted: mc.Unlisted,
Aliases: mc.Aliases, Aliases: mc.Aliases,
Capabilities: capsMap,
}) })
} }
@@ -145,7 +145,7 @@
<div class="flex flex-col h-full"> <div class="flex flex-col h-full">
<!-- Model selector --> <!-- Model selector -->
<div class="shrink-0 flex flex-wrap gap-2 mb-4"> <div class="shrink-0 flex flex-wrap gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} /> <ModelSelector bind:value={$selectedModelStore} placeholder="Select an audio model..." disabled={isTranscribing} capabilities={["audio_transcriptions"]} />
</div> </div>
<!-- Empty state for no models configured --> <!-- Empty state for no models configured -->
@@ -193,7 +193,7 @@
<div class="flex flex-col h-full"> <div class="flex flex-col h-full">
<!-- Model selector and mode toggle --> <!-- Model selector and mode toggle -->
<div class="shrink-0 flex flex-wrap gap-2 mb-4"> <div class="shrink-0 flex flex-wrap gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} /> <ModelSelector bind:value={$selectedModelStore} placeholder="Select an image model..." disabled={isGenerating} capabilities={["image_generation", "image_to_image"]} matchAny={true} />
<select <select
class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary" class="px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
@@ -6,12 +6,15 @@
value: string; value: string;
placeholder?: string; placeholder?: string;
disabled?: boolean; disabled?: boolean;
capabilities?: string[];
matchAny?: boolean;
} }
let { value = $bindable(), placeholder = "Select a model...", disabled = false }: Props = $props(); let { value = $bindable(), placeholder = "Select a model...", disabled = false, capabilities, matchAny = false }: Props = $props();
let grouped = $derived(groupModels($models)); let grouped = $derived(groupModels($models, capabilities, matchAny));
let hasModels = $derived(grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0); let hasMatching = $derived(grouped.localMatching.length > 0);
let hasModels = $derived(hasMatching || grouped.local.length > 0 || Object.keys(grouped.peersByProvider).length > 0);
</script> </script>
{#if hasModels} {#if hasModels}
@@ -21,6 +24,18 @@
{disabled} {disabled}
> >
<option value="">{placeholder}</option> <option value="">{placeholder}</option>
{#if hasMatching}
<optgroup label="Matching Capabilities">
{#each grouped.localMatching as model (model.id)}
<option value={model.id}>{model.id}</option>
{#if model.aliases}
{#each model.aliases as alias (alias)}
<option value={alias}> {alias}</option>
{/each}
{/if}
{/each}
</optgroup>
{/if}
{#if grouped.local.length > 0} {#if grouped.local.length > 0}
<optgroup label="Local"> <optgroup label="Local">
{#each grouped.local as model (model.id)} {#each grouped.local as model (model.id)}
@@ -264,7 +264,7 @@
<div class="flex flex-col h-full"> <div class="flex flex-col h-full">
<!-- Top bar: model selector + query input (table mode) + mode toggle --> <!-- Top bar: model selector + query input (table mode) + mode toggle -->
<div class="shrink-0 flex flex-wrap gap-2 mb-4"> <div class="shrink-0 flex flex-wrap gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a rerank model..." disabled={isLoading} /> <ModelSelector bind:value={$selectedModelStore} placeholder="Select a rerank model..." disabled={isLoading} capabilities={["reranker"]} />
{#if editorMode === "table"} {#if editorMode === "table"}
<input <input
type="text" type="text"
@@ -206,7 +206,7 @@
<div class="flex flex-col h-full"> <div class="flex flex-col h-full">
<!-- Model and voice selectors --> <!-- Model and voice selectors -->
<div class="shrink-0 flex gap-2 mb-4"> <div class="shrink-0 flex gap-2 mb-4">
<ModelSelector bind:value={$selectedModelStore} placeholder="Select a speech model..." disabled={isGenerating} /> <ModelSelector bind:value={$selectedModelStore} placeholder="Select a speech model..." disabled={isGenerating} capabilities={["audio_speech"]} />
<div class="flex gap-2"> <div class="flex gap-2">
<select <select
class="shrink-0 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary" class="shrink-0 px-3 py-2 rounded border border-gray-200 dark:border-white/10 bg-surface focus:outline-none focus:ring-2 focus:ring-primary"
+113
View File
@@ -0,0 +1,113 @@
import { describe, it, expect } from "vitest";
import { matchesCapabilities, groupModels } from "./modelUtils";
import type { Model } from "./types";
function makeModel(overrides: Partial<Model> = {}): Model {
return {
id: "test-model",
state: "ready",
name: "Test Model",
description: "",
unlisted: false,
peerID: "",
...overrides,
};
}
describe("matchesCapabilities", () => {
it("returns true when required is empty", () => {
const model = makeModel();
expect(matchesCapabilities(model, [])).toBe(true);
});
it("returns false when model has no capabilities", () => {
const model = makeModel();
expect(matchesCapabilities(model, ["vision"])).toBe(false);
});
it("returns false when model has empty capabilities object", () => {
const model = makeModel({ capabilities: {} });
expect(matchesCapabilities(model, ["vision"])).toBe(false);
});
it("returns true when model has the single required capability", () => {
const model = makeModel({ capabilities: { vision: true } });
expect(matchesCapabilities(model, ["vision"])).toBe(true);
});
it("returns false when model lacks the required capability", () => {
const model = makeModel({ capabilities: { vision: true } });
expect(matchesCapabilities(model, ["audio_transcriptions"])).toBe(false);
});
it("AND semantics: returns true only when all required are present", () => {
const model = makeModel({ capabilities: { vision: true, audio_transcriptions: true } });
expect(matchesCapabilities(model, ["vision", "audio_transcriptions"])).toBe(true);
expect(matchesCapabilities(model, ["vision", "reranker"])).toBe(false);
});
it("matchAny=true: returns true when at least one required is present", () => {
const model = makeModel({ capabilities: { vision: true } });
expect(matchesCapabilities(model, ["vision", "reranker"], true)).toBe(true);
expect(matchesCapabilities(model, ["audio_transcriptions", "reranker"], true)).toBe(false);
});
it("matchAny=true with empty required returns true", () => {
const model = makeModel();
expect(matchesCapabilities(model, [], true)).toBe(true);
});
});
describe("groupModels", () => {
const models: Model[] = [
makeModel({ id: "chat-model", capabilities: { vision: true } }),
makeModel({ id: "audio-model", capabilities: { audio_transcriptions: true } }),
makeModel({ id: "no-caps-model" }),
makeModel({ id: "peer-model", peerID: "peer1" }),
makeModel({ id: "unlisted-model", unlisted: true, capabilities: { vision: true } }),
];
it("filters out unlisted models", () => {
const result = groupModels(models);
expect(result.localMatching.length + result.local.length).toBe(3);
expect([...result.localMatching, ...result.local].every((m) => !m.unlisted)).toBe(true);
});
it("separates peer models into peersByProvider", () => {
const result = groupModels(models);
expect(result.peersByProvider["peer1"]).toHaveLength(1);
expect(result.peersByProvider["peer1"][0].id).toBe("peer-model");
});
it("without capabilities, all local models go to local (non-matching)", () => {
const result = groupModels(models);
expect(result.localMatching).toHaveLength(0);
expect(result.local).toHaveLength(3);
});
it("with capabilities, matching models go to localMatching", () => {
const result = groupModels(models, ["vision"]);
expect(result.localMatching).toHaveLength(1);
expect(result.localMatching[0].id).toBe("chat-model");
expect(result.local).toHaveLength(2);
});
it("with capabilities, models without capabilities go to local", () => {
const result = groupModels(models, ["vision"]);
expect(result.local.find((m) => m.id === "no-caps-model")).toBeDefined();
});
it("with matchAny, matches models with any listed capability", () => {
const result = groupModels(models, ["vision", "audio_transcriptions"], true);
expect(result.localMatching).toHaveLength(2);
expect(result.localMatching.map((m) => m.id)).toContain("chat-model");
expect(result.localMatching.map((m) => m.id)).toContain("audio-model");
expect(result.local).toHaveLength(1);
});
it("with empty capabilities array, all local go to local (non-matching)", () => {
const result = groupModels(models, []);
expect(result.localMatching).toHaveLength(0);
expect(result.local).toHaveLength(3);
});
});
+28 -2
View File
@@ -2,14 +2,40 @@ import type { Model } from "./types";
export interface GroupedModels { export interface GroupedModels {
local: Model[]; local: Model[];
localMatching: Model[];
peersByProvider: Record<string, Model[]>; peersByProvider: Record<string, Model[]>;
} }
export function groupModels(models: Model[]): GroupedModels { export function matchesCapabilities(model: Model, required: string[], matchAny = false): boolean {
if (!required.length) return true;
if (!model.capabilities) return false;
const caps = model.capabilities as Record<string, boolean>;
if (matchAny) {
return required.some((cap) => caps[cap] === true);
}
return required.every((cap) => caps[cap] === true);
}
export function groupModels(models: Model[], capabilities?: string[], matchAny = false): GroupedModels {
const available = models.filter((m) => !m.unlisted); const available = models.filter((m) => !m.unlisted);
const local = available.filter((m) => !m.peerID); const local = available.filter((m) => !m.peerID);
const peerModels = available.filter((m) => m.peerID); const peerModels = available.filter((m) => m.peerID);
let localMatching: Model[] = [];
let localRest: Model[] = [];
if (capabilities && capabilities.length > 0) {
for (const model of local) {
if (matchesCapabilities(model, capabilities, matchAny)) {
localMatching.push(model);
} else {
localRest.push(model);
}
}
} else {
localRest = local;
}
const peersByProvider = peerModels.reduce( const peersByProvider = peerModels.reduce(
(acc, model) => { (acc, model) => {
const peerId = model.peerID || "unknown"; const peerId = model.peerID || "unknown";
@@ -20,5 +46,5 @@ export function groupModels(models: Model[]): GroupedModels {
{} as Record<string, Model[]> {} as Record<string, Model[]>
); );
return { local, peersByProvider }; return { local: localRest, localMatching, peersByProvider };
} }
+11
View File
@@ -2,6 +2,16 @@ export type ConnectionState = "connected" | "connecting" | "disconnected";
export type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown"; export type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
export interface ModelCapabilities {
vision?: boolean;
audio_transcriptions?: boolean;
audio_speech?: boolean;
image_generation?: boolean;
image_to_image?: boolean;
function_calling?: boolean;
reranker?: boolean;
}
export interface Model { export interface Model {
id: string; id: string;
state: ModelStatus; state: ModelStatus;
@@ -10,6 +20,7 @@ export interface Model {
unlisted: boolean; unlisted: boolean;
peerID: string; peerID: string;
aliases?: string[]; aliases?: string[];
capabilities?: ModelCapabilities;
} }
export interface TokenMetrics { export interface TokenMetrics {