Files
executus/tools/classify.go
T
steve 1e201550b3 P3: meta + primitive tool group (think/now/cite + classify/extract/summarize)
Grow executus/tools into a real generic tool library:

- Register(reg): the always-available, zero-config tools — think, now (UTC
  unless a CurrentTimeProvider is wired), cite (inert unless a CitationStorage
  is wired). All nil-safe; a light host calls Register and is useful.
- RegisterMeta(reg, MetaDeps): the LLM-backed meta tools — classify,
  extract_entities, summarize — over the llmmeta helper. Budget defaults to the
  shipped in-memory per-run cap; Files optional; caps default.
- Seams moved (interface/type-only, no host coupling): research_providers.go
  (CurrentTimeProvider/CitationStorage/SearchBudget/PageExtractor/PDFFetcher/…)
  and file_storage.go (FileStorage + FileDomainMeta). Plus the in-memory budget
  default (research_defaults.go) and scope_validate.go.

calculate deferred (drags github.com/Krognol/go-wolfram + a module-path replace
— not worth it in the lean core for one tool). Core go.sum still free of
gorm/redis/discordgo/sqlite/wolfram.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-27 00:11:54 -04:00

320 lines
10 KiB
Go

// Package tools — v12 classify.
//
// Classification primitive: text + categories → labels + per-category
// scores. Single-label mode (default) returns the top-1 category;
// multi-label mode returns every category whose score crosses the
// threshold.
//
// Why a dedicated tool (vs reusing extract_entities for one-of-N
// classification): classification has a typed result (labels[] +
// scores{}) that downstream agents consume programmatically. Folding
// it into extract_entities would force every author to re-spec the
// scoring schema.
//
// Score handling: each score in the LLM's reply is clamped into
// [0, 1]; categories missing from the reply score 0. Both modes return
// scores for ALL requested categories so the author can read the
// distribution. Single-label sets labels[] to the top-scoring category;
// multi-label sets it to every category scoring 0.5 or higher.
//
// Test: classify_test.go covers single-label, multi-label, score
// normalisation, > 20 categories rejected, unknown category in the
// reply silently dropped.
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"gitea.stevedudenhoeffer.com/steve/executus/llmmeta"
"gitea.stevedudenhoeffer.com/steve/executus/tool"
)
// Limits and defaults for the classify tool.
const (
	// classifyMaxInputBytes is the largest input, in bytes, that is
	// forwarded to the model.
	classifyMaxInputBytes = 16 * 1024

	// classifyMaxCategories is the hard cap on how many categories a
	// single call may supply.
	classifyMaxCategories = 20

	// classifyMultiLabelThreshold is the score a category must reach to
	// be listed in labels[] in multi-label mode.
	classifyMultiLabelThreshold = 0.5

	// classifyFallbackMaxPerRun is the per-run call cap used when no
	// ClassifyConfig is supplied.
	classifyFallbackMaxPerRun = 20
)
// ClassifyConfig is the narrow per-deployment config surface.
//
// NewClassify consults it only when it was constructed without a
// SearchBudget: MaxPerRun then sizes the fallback in-memory per-run
// budget. A nil ClassifyConfig falls back to classifyFallbackMaxPerRun.
type ClassifyConfig interface {
// MaxPerRun returns the maximum number of classify calls allowed in a
// single run.
MaxPerRun(ctx context.Context) int
}
// classifyArgs is the LLM-facing param struct.
//
// The description tags are part of the schema shown to the model, so
// they are runtime text rather than documentation.
type classifyArgs struct {
// Text is the input to classify. Whitespace-only input is rejected;
// input longer than classifyMaxInputBytes is truncated before it is
// sent to the model.
Text string `json:"text" description:"The text to classify. Required. Capped at 16KB."`
// Categories is the candidate label set. Entries are trimmed and
// de-duplicated before prompting; the classifyMaxCategories limit is
// checked against the raw list, before de-duplication.
Categories []string `json:"categories" description:"List of categories to score the text against. Required. Max 20."`
// MultiLabel switches from top-1 selection to threshold selection.
// NOTE(review): the tag text says "above 0.5", but label selection
// uses >= classifyMultiLabelThreshold, so exactly 0.5 is included.
MultiLabel bool `json:"multi_label,omitempty" description:"When true, returns every category scoring above 0.5. Default false → single-label (top-1) result."`
}
// classifyResult is the JSON payload the classify tool returns to the
// caller. Every field is omitted when empty, so a success carries
// labels/scores/model_used and a failure carries error plus whichever
// diagnostic fields apply.
type classifyResult struct {
// Labels holds the selected categories: the top-1 category in
// single-label mode, or every category at or above the threshold in
// multi-label mode.
Labels []string `json:"labels,omitempty"`
// Scores maps each requested category to its clamped score.
Scores map[string]float64 `json:"scores,omitempty"`
// ModelUsed echoes the model reported by the llmmeta helper.
ModelUsed string `json:"model_used,omitempty"`
// RawReply carries the model's unparsed text when the reply could not
// be interpreted as a JSON object.
RawReply string `json:"raw_reply,omitempty"`
// Error is a validation message or an error-kind identifier.
Error string `json:"error,omitempty"`
// BudgetMsg is the human-readable explanation that accompanies a
// classify_budget_exceeded error.
BudgetMsg string `json:"budget_message,omitempty"`
}
// NewClassify constructs the classify tool.
//
// helper performs the LLM call; when it is nil every invocation fails with
// a "not configured" error. cfg supplies the per-run cap used only when
// budget is nil; a nil cfg falls back to classifyFallbackMaxPerRun. budget
// is the per-run call counter; when nil an in-memory budget is created on
// the first invocation and reused afterwards.
//
// The handler validates the arguments, trims and de-duplicates the
// categories, truncates over-long text, charges the per-run budget, asks
// the model for per-category scores and returns a JSON-encoded
// classifyResult. Validation, budget and model-reply problems are reported
// inside that JSON (error field) with a nil Go error; only a missing
// helper or an error returned by helper.Call surfaces as a Go error.
func NewClassify(helper *llmmeta.Helper, cfg ClassifyConfig, budget SearchBudget) tool.Tool {
	return tool.NewGatedTool[classifyArgs](
		"classify",
		"Classify text into one of N categories (or multiple via multi_label=true). Returns labels[] (top-1 by default) + scores{category: 0..1}. Counts against per-run and 7-day cost budgets.",
		tool.Permission{
			AuthoringRequirement: tool.RequirementAnyone,
			OperatesOn:           tool.ScopeCaller,
			SafeForShare:         true,
			Categories:           []string{"llm-meta", "cost-bearing"},
		},
		func(ctx context.Context, inv tool.Invocation, args classifyArgs) (string, error) {
			if helper == nil {
				return "", fmt.Errorf("classify: not configured")
			}
			text := args.Text
			if strings.TrimSpace(text) == "" {
				return marshalClassifyResult(classifyResult{Error: "text is empty"}), nil
			}
			if len(args.Categories) == 0 {
				return marshalClassifyResult(classifyResult{Error: "categories is empty"}), nil
			}
			// The cap is checked on the raw list, before de-duplication.
			if len(args.Categories) > classifyMaxCategories {
				return marshalClassifyResult(classifyResult{
					Error: fmt.Sprintf("too many categories (%d > %d)", len(args.Categories), classifyMaxCategories),
				}), nil
			}

			// Trim + dedupe categories so the LLM sees a clean
			// schema. Order is preserved for the prompt; the result
			// map is order-agnostic.
			categories := make([]string, 0, len(args.Categories))
			seen := make(map[string]bool, len(args.Categories))
			for _, c := range args.Categories {
				c = strings.TrimSpace(c)
				if c == "" || seen[c] {
					continue
				}
				seen[c] = true
				categories = append(categories, c)
			}
			if len(categories) == 0 {
				return marshalClassifyResult(classifyResult{Error: "categories has no non-empty entries"}), nil
			}

			if len(text) > classifyMaxInputBytes {
				// Cut on a rune boundary: a plain byte slice could split a
				// multi-byte UTF-8 sequence and hand the model invalid text.
				// A valid sequence has at most three continuation bytes
				// (10xxxxxx), so backing off more than that cannot help.
				cut := classifyMaxInputBytes
				for back := 0; back < 3 && cut > 0 && text[cut]&0xC0 == 0x80; back++ {
					cut--
				}
				text = text[:cut]
			}

			// Per-run budget gate.
			//
			// NOTE(review): this assigns to the captured constructor
			// parameter, so the fallback budget is created on the first
			// call (with that call's MaxPerRun value) and reused by later
			// calls. The write is unsynchronised — confirm the host never
			// runs this handler concurrently, or guard it with sync.Once.
			if budget == nil {
				maxPerRun := classifyFallbackMaxPerRun
				if cfg != nil {
					maxPerRun = cfg.MaxPerRun(ctx)
				}
				budget = NewInMemorySearchBudget(map[string]int{
					"classify": maxPerRun,
				})
			}
			used, limit, exceeded := budget.CheckAndIncrement(ctx, inv.RunID, "classify")
			if exceeded {
				return marshalClassifyResult(classifyResult{
					Error:     "classify_budget_exceeded",
					BudgetMsg: fmt.Sprintf("per-run classify budget exceeded (%d/%d). Ask an admin to raise skills.classify.max_per_run.", used, limit),
				}), nil
			}

			systemPrompt := "You classify text into a fixed set of categories. Return ONLY JSON. Score each category in [0,1] (1 = perfect fit). Sum of all scores does NOT need to be 1 — high overlap across categories is allowed."
			userPrompt := buildClassifyPrompt(text, categories, args.MultiLabel)
			res, callErr := helper.Call(ctx, llmmeta.CallSpec{
				Tier:                 "fast",
				SystemPrompt:         systemPrompt,
				UserPrompt:           userPrompt,
				MaxOutputTokens:      2048,
				ResponseFormat:       "json",
				RetryOnMalformedJSON: true,
				ToolName:             "classify",
				RunID:                inv.RunID,
				SkillID:              inv.SkillID,
				CallerID:             inv.CallerID,
			})
			if callErr != nil {
				return "", callErr
			}
			if !res.Success {
				kind := res.ErrorKind
				if kind == "" {
					kind = "llm_unavailable"
				}
				return marshalClassifyResult(classifyResult{Error: kind}), nil
			}
			if res.ErrorKind == llmmeta.ErrorKindMalformedJSON || res.Parsed == nil {
				return marshalClassifyResult(classifyResult{
					Error:     "classification_failed",
					RawReply:  res.Text,
					ModelUsed: res.ModelUsed,
				}), nil
			}
			parsedMap, ok := res.Parsed.(map[string]any)
			if !ok {
				return marshalClassifyResult(classifyResult{
					Error:     "classification_failed_not_object",
					RawReply:  res.Text,
					ModelUsed: res.ModelUsed,
				}), nil
			}

			scores := normaliseClassifyScores(parsedMap, categories)
			labels := selectClassifyLabels(scores, categories, args.MultiLabel)
			return marshalClassifyResult(classifyResult{
				Labels:    labels,
				Scores:    scores,
				ModelUsed: res.ModelUsed,
			}), nil
		},
	)
}
// buildClassifyPrompt renders the user message for the classify call:
// a bulleted category list, the text to classify, the required JSON
// reply shape, and a closing hint that depends on the labelling mode.
func buildClassifyPrompt(text string, categories []string, multiLabel bool) string {
	// Pick the mode-specific closing sentence up front.
	var modeHint string
	if multiLabel {
		modeHint = " The same text may score high in MULTIPLE categories — score each independently."
	} else {
		modeHint = " Score each category; the highest-scoring one will be the chosen label."
	}

	var prompt strings.Builder
	prompt.WriteString("Classify the text below.\n\nCategories:\n")
	for i := range categories {
		prompt.WriteString("- ")
		prompt.WriteString(categories[i])
		prompt.WriteString("\n")
	}
	prompt.WriteString("\nText:\n")
	prompt.WriteString(text)
	prompt.WriteString("\n\nReturn ONLY a JSON object: {\"scores\": {\"<category>\": <0..1 float>, ...}}.")
	prompt.WriteString(modeHint)
	return prompt.String()
}
// normaliseClassifyScores extracts the scores map from the LLM's
// reply and clamps each value into [0, 1]. The result has exactly one
// entry per requested category: categories absent from the reply, or
// whose value cannot be read as a number, score 0. Keys in the reply
// that are not requested categories are ignored.
//
// Why we accept either {"scores": {...}} or {...}: some models reply
// with the inner object directly, dropping the wrapping key. Both
// shapes are valid as long as the keys match the requested category
// names.
func normaliseClassifyScores(parsed map[string]any, categories []string) map[string]float64 {
	scoresIn, ok := parsed["scores"].(map[string]any)
	if !ok {
		// Accept the bare-map shape too.
		scoresIn = parsed
	}
	out := make(map[string]float64, len(categories))
	for _, c := range categories {
		// Every requested category gets an entry, defaulting to 0.
		out[c] = 0
		v, has := scoresIn[c]
		if !has {
			continue
		}
		f, ok := coerceClassifyScore(v)
		if !ok {
			continue
		}
		// Clamp into [0, 1]. The comparisons are written so that NaN,
		// which fails every ordered comparison, falls through to the 0
		// default: a NaN in the map would make json.Marshal reject the
		// whole result.
		switch {
		case f >= 1:
			out[c] = 1
		case f > 0:
			out[c] = f
		}
	}
	return out
}
// coerceClassifyScore reads a JSON value as a float score. It accepts
// float64, int, int64, json.Number, and numeric strings, including
// percent-strings ("85%" → 0.85). The value is NOT clamped; callers are
// expected to bring it into [0, 1].
//
// The second result is false when the value is of an unsupported type,
// is a string that does not start with a number, or is NaN. NaN is
// rejected because fmt's float scanning accepts the literal "NaN", and a
// NaN score cannot be JSON-encoded later on.
func coerceClassifyScore(raw any) (float64, bool) {
	var f float64
	switch v := raw.(type) {
	case float64:
		f = v
	case int:
		f = float64(v)
	case int64:
		f = float64(v)
	case json.Number:
		// Produced by decoders configured with UseNumber.
		n, err := v.Float64()
		if err != nil {
			return 0, false
		}
		f = n
	case string:
		s := strings.TrimSpace(v)
		percent := strings.HasSuffix(s, "%")
		if percent {
			s = strings.TrimSuffix(s, "%")
		}
		if _, err := fmt.Sscanf(s, "%f", &f); err != nil {
			return 0, false
		}
		if percent {
			f /= 100.0
		}
	default:
		return 0, false
	}
	// NaN is the only float that is not equal to itself.
	if f != f {
		return 0, false
	}
	return f, true
}
// selectClassifyLabels chooses which categories to report as labels.
//
// In single-label mode it returns the one category with the highest
// score; on a tie the category listed first wins. In multi-label mode it
// returns every category whose score is at or above
// classifyMultiLabelThreshold, ordered by score descending with
// category-list order kept for ties. It returns nil when nothing
// qualifies.
func selectClassifyLabels(scores map[string]float64, categories []string, multiLabel bool) []string {
	if !multiLabel {
		// Single-label: keep a running winner; strict > keeps the
		// earliest category on ties.
		winner, top := "", -1.0
		for _, name := range categories {
			if s := scores[name]; s > top {
				winner, top = name, s
			}
		}
		if winner == "" {
			return nil
		}
		return []string{winner}
	}

	// Multi-label: collect in category-list order, then rank by score.
	var picked []string
	for _, name := range categories {
		if scores[name] >= classifyMultiLabelThreshold {
			picked = append(picked, name)
		}
	}
	sortClassifyLabelsByScore(picked, scores)
	return picked
}
// sortClassifyLabelsByScore reorders labels in place so that higher
// scores come first. It is an insertion sort that only moves an element
// past strictly lower-scoring neighbours, so labels with equal scores
// keep their incoming relative order.
func sortClassifyLabelsByScore(labels []string, scores map[string]float64) {
	for i := 1; i < len(labels); i++ {
		current := labels[i]
		currentScore := scores[current]
		// Shift lower-scoring predecessors one slot right, then drop
		// the current label into the gap.
		pos := i
		for ; pos > 0 && currentScore > scores[labels[pos-1]]; pos-- {
			labels[pos] = labels[pos-1]
		}
		labels[pos] = current
	}
}
// marshalClassifyResult encodes r as the JSON string returned to the
// caller. If encoding fails it returns a minimal {"error": "..."} object
// instead, so the caller always receives valid JSON.
func marshalClassifyResult(r classifyResult) string {
	b, err := json.Marshal(r)
	if err == nil {
		return string(b)
	}
	// JSON-encode the message rather than splicing it into a literal: an
	// error text containing a quote or backslash would otherwise produce
	// invalid JSON.
	msg, msgErr := json.Marshal("marshal_failed: " + err.Error())
	if msgErr != nil {
		return `{"error":"marshal_failed"}`
	}
	return `{"error":` + string(msg) + `}`
}