P1 (part 1): move skilltools core -> tool/ (clean, verbatim)
executus CI / test (push) Successful in 36s
executus CI / test (push) Successful in 36s
The tool registry core (registry, permission model, Invocation, gated-tool wrapper, ssrf guard, hmac, encryption, argcoerce, helpers, rootrun, session_tools, webhook_rate_limit) had zero mort coupling — it imports only majordomo/llm + x/crypto/hkdf — so it moves verbatim with a package rename (skilltools -> tool). All same-package tests came along and pass; the SSRF, gated-wrapper, encryption and output-pattern invariants are re-anchored here. majordomo re-enters the module graph (now pinned to the latest, incl. the front-loaded-output fix). model/ + llmmeta + structured follow next. Docs: CLAUDE.md now requires README/examples to stay in sync with changes in the same commit; CI skips docs/example-only pushes via paths-ignore. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -13,8 +13,24 @@ on:
|
||||
push:
|
||||
branches: [main]
|
||||
tags: ["v*"]
|
||||
# Docs/example/meta-only changes don't affect the build — skip the run.
|
||||
# (Path filters are not applied to tag pushes, so v* releases always run.)
|
||||
paths-ignore:
|
||||
- "**.md"
|
||||
- "LICENSE"
|
||||
- ".gitignore"
|
||||
- ".dockerignore"
|
||||
- "examples/**"
|
||||
- "docs/**"
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths-ignore:
|
||||
- "**.md"
|
||||
- "LICENSE"
|
||||
- ".gitignore"
|
||||
- ".dockerignore"
|
||||
- "examples/**"
|
||||
- "docs/**"
|
||||
workflow_dispatch: {}
|
||||
|
||||
concurrency:
|
||||
|
||||
@@ -47,8 +47,9 @@ CORE (majordomo + stdlib):
|
||||
nil-safe Ports + RunnableAgent later [P2]
|
||||
dispatchguard/ loop/depth/fan-out caps [P0 ✓]
|
||||
pendingattach/ attachment dedupe [P0 ✓]
|
||||
tool/ registry + 3-stage permissions + ssrf/llmmeta [P1]
|
||||
tool/ registry + 3-stage permissions + ssrf [P1 ✓]
|
||||
model/ config-driven tier resolution over majordomo [P1]
|
||||
llmmeta/ shared meta-LLM helper (moves with model/) [P1]
|
||||
compact/ context compactor (WithCompactor hook) [P2]
|
||||
tools/{web,net,store,compose,meta,comms} generic tools [P3]
|
||||
structured/ Generate[T] convenience over majordomo [P1]
|
||||
@@ -91,6 +92,11 @@ rewire mort + tag v0.1.0. The mort-side rewrite reuses mort's existing
|
||||
|
||||
## Conventions
|
||||
|
||||
- **Keep `README.md`, this `CLAUDE.md`, and `examples/` in sync with every change,
|
||||
in the SAME commit.** No aspirational docs: when you add/rename a package, change
|
||||
a seam or a default, or alter the public API, update the docs and the relevant
|
||||
example so they always reflect reality (mirrors majordomo's house rule). The
|
||||
status markers in the tier map above must track what's actually landed.
|
||||
- Mirror majordomo's house style: gofmt; check errors immediately and wrap with
|
||||
`fmt.Errorf("...: %w", err)`; `// Why:` comments where rationale isn't obvious;
|
||||
hermetic tests (majordomo's fake provider; no network in the default suite).
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
module gitea.stevedudenhoeffer.com/steve/executus
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
gitea.stevedudenhoeffer.com/steve/majordomo v0.0.0-20260626223738-1fd7109a42f3
|
||||
golang.org/x/crypto v0.53.0
|
||||
)
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
gitea.stevedudenhoeffer.com/steve/majordomo v0.0.0-20260626223738-1fd7109a42f3 h1:KYKIFFRsXzbbBJVDa99+Fhy0zxl9G0xV/MCrLipsLL4=
|
||||
gitea.stevedudenhoeffer.com/steve/majordomo v0.0.0-20260626223738-1fd7109a42f3/go.mod h1:UZLveG17SmENt4sne2RSLIbioix30RZbRIQUzBAnOyY=
|
||||
golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto=
|
||||
golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio=
|
||||
@@ -0,0 +1,161 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// unmarshalArgsLenient decodes the raw JSON arguments the model supplied
// into the tool's typed Args struct, tolerating the classic LLM tool-call
// bug of emitting numbers and booleans as strings ("3" where the schema
// said integer, "true" where it said boolean).
//
// Why here (vs relying on the library): legacy gollm's Define performed
// this coercion internally (tool_coerce.go), and several years of tool
// traffic depend on the tolerance. majordomo's DefineTool decodes strictly
// by design, so the gated wrappers re-create the leniency here. The strict
// path is tried first and coercion only runs on failure, so well-typed
// arguments never pay for the reflection walk.
//
// target must be a non-nil pointer to the Args value. An empty raw is a
// no-op. On any failure the error returned is the one from the strict
// decode, because it names the real problem.
func unmarshalArgsLenient(raw json.RawMessage, target any) error {
	if len(raw) == 0 {
		return nil
	}
	strictErr := json.Unmarshal(raw, target)
	if strictErr == nil {
		return nil
	}

	// Why: a nil or non-pointer target has already been rejected by
	// json.Unmarshal above; calling Elem() on its type would panic, so
	// report the strict error instead of crashing the tool call.
	targetType := reflect.TypeOf(target)
	if targetType == nil || targetType.Kind() != reflect.Pointer {
		return strictErr
	}

	coerced, err := coerceArgsToType(raw, targetType.Elem())
	if err != nil {
		// Malformed JSON (or a value that cannot be re-marshaled):
		// surface the original strict error.
		return strictErr
	}
	if err := json.Unmarshal(coerced, target); err != nil {
		return strictErr
	}
	return nil
}

// coerceArgsToType reparses argsJSON with leniency: where the target
// struct expects a numeric or boolean field but the JSON value is a
// string, it converts the string to the target kind. Recurses into
// nested structs, slices, maps, and pointer fields. Returns a freshly
// marshaled JSON byte slice intended to be decoded strictly into the
// target; values that could not be coerced are passed through unchanged.
func coerceArgsToType(argsJSON []byte, target reflect.Type) ([]byte, error) {
	var raw any
	if err := json.Unmarshal(argsJSON, &raw); err != nil {
		return nil, err
	}
	return json.Marshal(coerceValue(raw, target))
}

// coerceValue walks the generically-decoded JSON value v alongside the Go
// type t it is destined for and rewrites string leaves into numbers or
// booleans where t asks for one. Containers (maps and slices produced by
// encoding/json) are updated in place and returned. Anything that does not
// match the expected shape, or a string that does not parse, is returned
// untouched so the subsequent strict decode reports the error.
func coerceValue(v any, t reflect.Type) any {
	if t == nil {
		return v
	}
	// Pointer fields are coerced according to the type they point at.
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}

	switch t.Kind() {
	case reflect.Struct:
		obj, ok := v.(map[string]any)
		if !ok {
			return v
		}
		for i := 0; i < t.NumField(); i++ {
			field := t.Field(i)
			if !field.IsExported() {
				continue
			}
			key := jsonFieldName(field)
			if key == "-" {
				continue
			}
			if val, present := obj[key]; present {
				obj[key] = coerceValue(val, field.Type)
			}
		}
		return obj

	case reflect.Slice, reflect.Array:
		items, ok := v.([]any)
		if !ok {
			return v
		}
		elemType := t.Elem()
		for i := range items {
			items[i] = coerceValue(items[i], elemType)
		}
		return items

	case reflect.Map:
		obj, ok := v.(map[string]any)
		if !ok {
			return v
		}
		valType := t.Elem()
		for k := range obj {
			obj[k] = coerceValue(obj[k], valType)
		}
		return obj

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if s, ok := v.(string); ok {
			if n, ok := parseLenientInt(s); ok {
				return n
			}
		}

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if s, ok := v.(string); ok {
			if n, ok := parseLenientUint(s); ok {
				return n
			}
		}

	case reflect.Float32, reflect.Float64:
		if s, ok := v.(string); ok {
			if f, err := strconv.ParseFloat(strings.TrimSpace(s), 64); err == nil {
				return f
			}
		}

	case reflect.Bool:
		if s, ok := v.(string); ok {
			if b, err := strconv.ParseBool(strings.TrimSpace(s)); err == nil {
				return b
			}
		}
	}
	return v
}

// parseLenientInt converts a model-supplied string into an int64. It
// accepts surrounding whitespace, a leading "+", and float notation
// ("3.0", "3.7", "1e3"), truncating any fractional part toward zero.
//
// Why the range check: strconv.ParseFloat happily parses "NaN", "Inf" and
// huge magnitudes such as "1e30", and converting those to an integer is
// implementation-defined in Go, so the tool would silently receive a
// garbage value. Rejecting them leaves the string in place so the strict
// decode fails loudly instead.
func parseLenientInt(s string) (int64, bool) {
	s = strings.TrimPrefix(strings.TrimSpace(s), "+")
	if n, err := strconv.ParseInt(s, 10, 64); err == nil {
		return n, true
	}
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0, false
	}
	const limit = float64(1 << 63) // 2^63, the first value above the int64 range
	// Written as a positive range test so NaN (which fails every
	// comparison) is rejected too.
	if !(f >= -limit && f < limit) {
		return 0, false
	}
	return int64(f), true
}

// parseLenientUint is the unsigned counterpart of parseLenientInt: it
// accepts whitespace, a leading "+", and non-negative float notation, and
// rejects negative, NaN, infinite, or out-of-range values.
func parseLenientUint(s string) (uint64, bool) {
	s = strings.TrimPrefix(strings.TrimSpace(s), "+")
	if n, err := strconv.ParseUint(s, 10, 64); err == nil {
		return n, true
	}
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0, false
	}
	const limit = float64(1 << 64) // 2^64, the first value above the uint64 range
	if !(f >= 0 && f < limit) {
		return 0, false
	}
	return uint64(f), true
}

// jsonFieldName returns the effective JSON key for a struct field:
// the json tag's name part when present, the Go field name otherwise,
// or "-" when the field is excluded.
func jsonFieldName(f reflect.StructField) string {
	tag, ok := f.Tag.Lookup("json")
	if !ok {
		return f.Name
	}
	name, _, _ := strings.Cut(tag, ",")
	if name == "" {
		return f.Name
	}
	return name
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Fixture types for the lenient-coercion tests.
type (
	// coerceParams mirrors the field kinds legacy gollm's coercion
	// supported: plain scalars, a pointer, a slice, a nested struct, and
	// a string field that must pass through untouched.
	coerceParams struct {
		Count   int      `json:"count"`
		Ratio   float64  `json:"ratio"`
		Flag    bool     `json:"flag"`
		Limit   *int     `json:"limit"`
		Tags    []int    `json:"tags"`
		Nested  coerceIn `json:"nested"`
		Verbose string   `json:"verbose"`
	}

	// coerceIn is the nested struct used to check that coercion recurses;
	// its single field is unsigned.
	coerceIn struct {
		Depth uint `json:"depth"`
	}
)
|
||||
|
||||
// TestGatedTool_LenientArgCoercion anchors the legacy gollm-era tolerance the
|
||||
// conversion preserved: numeric and boolean fields supplied as strings
|
||||
// by the model ("3", "true") decode into the typed Args, recursing into
|
||||
// pointers, slices, and nested structs. Models emit this shape
|
||||
// constantly; losing the tolerance would break live tool traffic.
|
||||
func TestGatedTool_LenientArgCoercion(t *testing.T) {
|
||||
var seen coerceParams
|
||||
tool := NewGatedTool[coerceParams](
|
||||
"coerce_tool", "coercion test",
|
||||
Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true},
|
||||
func(ctx context.Context, inv Invocation, args coerceParams) (string, error) {
|
||||
seen = args
|
||||
return "ok", nil
|
||||
},
|
||||
)
|
||||
|
||||
out, err := buildAndExecute(t, tool, Invocation{SkillName: "x"}, VisibilityPrivate, nil,
|
||||
`{"count":"3","ratio":" 2.5 ","flag":"true","limit":"7","tags":["1","2"],"nested":{"depth":"4"},"verbose":"yes"}`)
|
||||
if err != nil || out != "ok" {
|
||||
t.Fatalf("execute: out=%q err=%v", out, err)
|
||||
}
|
||||
if seen.Count != 3 || seen.Ratio != 2.5 || seen.Flag != true {
|
||||
t.Fatalf("scalar coercion failed: %+v", seen)
|
||||
}
|
||||
if seen.Limit == nil || *seen.Limit != 7 {
|
||||
t.Fatalf("pointer coercion failed: %+v", seen.Limit)
|
||||
}
|
||||
if len(seen.Tags) != 2 || seen.Tags[0] != 1 || seen.Tags[1] != 2 {
|
||||
t.Fatalf("slice coercion failed: %+v", seen.Tags)
|
||||
}
|
||||
if seen.Nested.Depth != 4 {
|
||||
t.Fatalf("nested coercion failed: %+v", seen.Nested)
|
||||
}
|
||||
if seen.Verbose != "yes" {
|
||||
t.Fatalf("string field mangled: %q", seen.Verbose)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGatedTool_StrictPathUnaffected confirms well-typed args take the
|
||||
// zero-cost strict path and uncoercible strings still fail loudly.
|
||||
func TestGatedTool_StrictPathUnaffected(t *testing.T) {
|
||||
tool := NewGatedTool[coerceParams](
|
||||
"coerce_strict_tool", "coercion test",
|
||||
Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true},
|
||||
func(ctx context.Context, inv Invocation, args coerceParams) (string, error) {
|
||||
return "ok", nil
|
||||
},
|
||||
)
|
||||
|
||||
if out, err := buildAndExecute(t, tool, Invocation{SkillName: "x"}, VisibilityPrivate, nil,
|
||||
`{"count":3,"ratio":2.5,"flag":true}`); err != nil || out != "ok" {
|
||||
t.Fatalf("strict path: out=%q err=%v", out, err)
|
||||
}
|
||||
|
||||
_, err := buildAndExecute(t, tool, Invocation{SkillName: "x"}, VisibilityPrivate, nil,
|
||||
`{"count":"not-a-number"}`)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid arguments") {
|
||||
t.Fatalf("expected invalid-arguments error for uncoercible string, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package tool
|
||||
|
||||
import "fmt"
|
||||
|
||||
// CheckAuthoring verifies that the saving user is permitted to author a
|
||||
// skill that uses the given tool list. Called from the save path
|
||||
// (skills.System.SaveUserSkill); the builtin loader bypasses this check.
|
||||
//
|
||||
// Why: the AuthoringRequirement gate is the primary admin trust boundary
|
||||
// for tools that can read sensitive data (db_select, repo_*) or perform
|
||||
// privileged Discord queries. Failing closed at save time prevents the
|
||||
// situation where a skill is saved-then-rejected at execute time.
|
||||
//
|
||||
// What: returns nil if all tools clear; otherwise returns the spec's
|
||||
// exact rejection message for the first offending tool.
|
||||
//
|
||||
// Test: see checks_test.go.
|
||||
func CheckAuthoring(reg Registry, tools []string, isAdmin bool) error {
|
||||
for _, name := range tools {
|
||||
t, ok := reg.Get(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown tool %q", name)
|
||||
}
|
||||
if t.Permission().AuthoringRequirement == RequirementAdmin && !isAdmin {
|
||||
return fmt.Errorf("The tool `%s` requires admin authoring. Ask an admin to create or publish a skill that uses this tool.", name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckShareSafety verifies that none of the listed tools is unsafe for
|
||||
// sharing. Called when a skill's visibility is being set to shared or
|
||||
// public.
|
||||
//
|
||||
// Why: tools that operate on caller-private data (mortbux_get_balance,
|
||||
// chatbot_get_memories) leak when invoked by non-owners through a
|
||||
// shared/public skill — the executor would compute "the caller is
|
||||
// whoever ran the skill", whose data would then surface to the skill
|
||||
// authoring user.
|
||||
//
|
||||
// What: returns nil if all tools clear; otherwise returns the spec's
|
||||
// exact rejection message for the first offending tool.
|
||||
//
|
||||
// Test: see checks_test.go.
|
||||
func CheckShareSafety(reg Registry, tools []string) error {
|
||||
for _, name := range tools {
|
||||
t, ok := reg.Get(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown tool %q", name)
|
||||
}
|
||||
if !t.Permission().SafeForShare {
|
||||
return fmt.Errorf("The tool `%s` cannot appear in a shared skill because it operates on the caller's own data.", name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckAuthoring_AllowsAnyone(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{name: "calc", perm: Permission{AuthoringRequirement: RequirementAnyone}})
|
||||
if err := CheckAuthoring(r, []string{"calc"}, false); err != nil {
|
||||
t.Fatalf("expected anyone to pass, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAuthoring_BlocksNonAdminFromAdminTool(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{name: "db_select", perm: Permission{AuthoringRequirement: RequirementAdmin}})
|
||||
err := CheckAuthoring(r, []string{"db_select"}, false)
|
||||
if err == nil || !strings.Contains(err.Error(), "requires admin authoring") {
|
||||
t.Fatalf("expected admin-required error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAuthoring_AllowsAdminWithAdminTool(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{name: "db_select", perm: Permission{AuthoringRequirement: RequirementAdmin}})
|
||||
if err := CheckAuthoring(r, []string{"db_select"}, true); err != nil {
|
||||
t.Fatalf("expected admin to pass, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAuthoring_UnknownTool(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
err := CheckAuthoring(r, []string{"missing"}, true)
|
||||
if err == nil || !strings.Contains(err.Error(), "unknown tool") {
|
||||
t.Fatalf("expected unknown-tool error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckShareSafety_Pass(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{name: "search", perm: Permission{SafeForShare: true}})
|
||||
if err := CheckShareSafety(r, []string{"search"}); err != nil {
|
||||
t.Fatalf("expected safe tool to pass, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckShareSafety_BlocksUnsafe(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{name: "balance", perm: Permission{SafeForShare: false}})
|
||||
err := CheckShareSafety(r, []string{"balance"})
|
||||
if err == nil || !strings.Contains(err.Error(), "operates on the caller's own data") {
|
||||
t.Fatalf("expected share-safety error, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
// Package tool — encryption.go: per-skill envelope encryption for
|
||||
// KV values and file blobs. AES-256-GCM with a per-skill key derived
|
||||
// from a single master key (env var SKILLS_ENCRYPTION_MASTER_KEY) via
|
||||
// HKDF using the skill ID as the salt.
|
||||
//
|
||||
// !!!!! CRITICAL OPERATIONAL WARNING !!!!!
|
||||
//
|
||||
// SKILLS_ENCRYPTION_MASTER_KEY MUST BE BACKED UP SEPARATELY FROM THE
|
||||
// DATABASE. Losing the master key = losing every byte of encrypted
|
||||
// KV value and every encrypted file blob, with no recovery path. The
|
||||
// key is the ONLY thing that can decrypt rows whose
|
||||
// encryption_key_version > 0.
|
||||
//
|
||||
// Operational rules:
|
||||
// - Store the master key in a secrets manager (Vault, 1Password,
|
||||
// KMS export) — NEVER in the same backup as the database dump.
|
||||
// - Rotating the master key without a versioned re-encrypt
|
||||
// migration WILL render existing encrypted rows unreadable. The
|
||||
// encryption_key_version column was added so a future rotation
|
||||
// migration can re-encrypt under a new (master, version)
|
||||
// pair; do not bump the version without that migration.
|
||||
// - When the env var is empty, encryption is OFF for the whole
|
||||
// instance. Skills with encryption_enabled=true still write
|
||||
// plaintext (with a logged WARNING). This is intentional — the
|
||||
// alternative is to refuse to start, which would break
|
||||
// deployment for everyone the moment the secret leaks during
|
||||
// rotation. Loud logging + the boot-time warning in mort.go is
|
||||
// the correct trade-off.
|
||||
//
|
||||
// Why HKDF-derived per-skill keys (vs one global key): a future
|
||||
// "wipe this skill's data" admin action can be made auditable by
|
||||
// recording the skill_id in the operation log without exposing the
|
||||
// master key. Per-skill keys also cap blast radius if one key
|
||||
// somehow leaks via a side channel — only that one skill's data is
|
||||
// compromised, not the whole platform.
|
||||
//
|
||||
// Why AES-256-GCM: authenticated encryption catches tampered
|
||||
// ciphertext at decrypt time. The GCM nonce is 12 random bytes per
|
||||
// row; the auth tag is 16 bytes. Both are stored inline with the
|
||||
// ciphertext so the storage layer's value/content column holds the
|
||||
// full envelope (no separate nonce column).
|
||||
//
|
||||
// Wire format of an encrypted blob:
|
||||
//
|
||||
// +-- 1 byte: format version (0x01)
|
||||
// +-- 12 bytes: GCM nonce
|
||||
// +-- N bytes: ciphertext + 16-byte GCM tag
|
||||
//
|
||||
// The format-version byte lets a future change to nonce length or
|
||||
// auth tag handling be detected loudly rather than corrupting reads.
|
||||
// Encrypt always writes 0x01; Decrypt rejects any other version with
|
||||
// ErrEncryptionUnknownVersion.
|
||||
package tool
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// EncryptionMasterKeyEnv is the environment variable that holds the
|
||||
// 32-byte (or longer, hashed down) master key for skill envelope
|
||||
// encryption.
|
||||
//
|
||||
// !!!!! LOSING THIS KEY = LOSING ALL ENCRYPTED DATA !!!!!
|
||||
//
|
||||
// Back it up separately from database backups. Never commit it.
|
||||
// Empty value = encryption OFF (with WARNING logged at boot).
|
||||
const EncryptionMasterKeyEnv = "SKILLS_ENCRYPTION_MASTER_KEY"
|
||||
|
||||
// CurrentKeyVersion is the version stamped on every newly-encrypted
|
||||
// row. Version 0 is reserved for plaintext (legacy / encryption-off).
|
||||
// Version 1 is "AES-256-GCM with HKDF(master, skill_id) per-skill key,
|
||||
// envelope format 0x01". Bumping this requires a migration that
|
||||
// re-encrypts existing rows under the new (master, version) pair.
|
||||
const CurrentKeyVersion = 1
|
||||
|
||||
// envelopeFormatV1 is the first byte of every Encrypt output. Decrypt
|
||||
// rejects any other value with ErrEncryptionUnknownVersion.
|
||||
const envelopeFormatV1 = byte(0x01)
|
||||
|
||||
// gcmNonceSize is fixed at 12 bytes for AES-GCM (NIST SP 800-38D
|
||||
// recommended).
|
||||
const gcmNonceSize = 12
|
||||
|
||||
// Encryption sentinel errors. Callers compare with errors.Is so storage
|
||||
// adapters can branch on "tampered" vs "unknown version" vs "no master
|
||||
// key".
|
||||
var (
|
||||
// ErrEncryptionDisabled is returned when an encryption operation
|
||||
// is attempted but SKILLS_ENCRYPTION_MASTER_KEY is empty. Storage
|
||||
// adapters interpret this as "fall through to plaintext" — they
|
||||
// MUST log loudly when this branch is taken.
|
||||
ErrEncryptionDisabled = errors.New("skilltools: encryption disabled (master key empty)")
|
||||
|
||||
// ErrEncryptionUnknownVersion is returned by Decrypt when the
|
||||
// envelope's format-version byte is not envelopeFormatV1. A read
|
||||
// that hits this error is corruption — surface to the operator,
|
||||
// do NOT silently fall back to plaintext.
|
||||
ErrEncryptionUnknownVersion = errors.New("skilltools: encryption envelope has unknown format version")
|
||||
|
||||
// ErrEncryptionTampered is returned by Decrypt when the GCM auth
|
||||
// tag check fails. The ciphertext or nonce was modified after
|
||||
// encryption. Surface as "data corruption" — the row is unreadable.
|
||||
ErrEncryptionTampered = errors.New("skilltools: encryption auth tag mismatch (data corruption or wrong key)")
|
||||
|
||||
// ErrEncryptionShortInput is returned by Decrypt when the input
|
||||
// is too short to contain even the version byte + nonce. Bug or
|
||||
// malformed write.
|
||||
ErrEncryptionShortInput = errors.New("skilltools: encryption input too short")
|
||||
)
|
||||
|
||||
// MasterKeyFromEnv returns the master key bytes (raw, NOT
|
||||
// HKDF-derived) from the SKILLS_ENCRYPTION_MASTER_KEY env var.
|
||||
//
|
||||
// Why hash + truncate to 32 bytes vs require 32 raw bytes: operators
|
||||
// commonly paste a generated random hex/base64 string of varying
|
||||
// length. SHA-256-truncate accepts any non-empty input and produces
|
||||
// a fixed-length key, which is then fed into HKDF for per-skill
|
||||
// derivation. The hash step is purely "normalize length"; HKDF still
|
||||
// does the per-skill diversification.
|
||||
//
|
||||
// Returns nil bytes (and false) if the env var is empty.
|
||||
func MasterKeyFromEnv() (key []byte, present bool) {
|
||||
raw := os.Getenv(EncryptionMasterKeyEnv)
|
||||
if raw == "" {
|
||||
return nil, false
|
||||
}
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return sum[:], true
|
||||
}
|
||||
|
||||
// DeriveSkillKey returns the per-skill 32-byte AES-256 key for the
|
||||
// given (master, skillID) pair via HKDF-SHA256.
|
||||
//
|
||||
// Why skillID as HKDF salt: each skill gets a distinct subkey so a
|
||||
// single master breach is necessary to decrypt any one skill, but
|
||||
// a skill_id leak (which is normal — IDs appear in logs) does NOT
|
||||
// help an attacker. The HKDF info parameter is fixed to a constant
|
||||
// label so different uses of the same master+skillID pair (e.g. a
|
||||
// future per-skill HMAC key) can be derived with a different label
|
||||
// without colliding.
|
||||
//
|
||||
// master must be the 32-byte output of MasterKeyFromEnv (or
|
||||
// equivalent length-normalized input). skillID must be non-empty —
|
||||
// caller is responsible.
|
||||
func DeriveSkillKey(master []byte, skillID string) ([]byte, error) {
|
||||
if len(master) == 0 {
|
||||
return nil, ErrEncryptionDisabled
|
||||
}
|
||||
if skillID == "" {
|
||||
return nil, errors.New("skilltools: DeriveSkillKey requires non-empty skillID")
|
||||
}
|
||||
r := hkdf.New(sha256.New, master, []byte(skillID), []byte("mort/skills/v1/aead"))
|
||||
out := make([]byte, 32)
|
||||
if _, err := io.ReadFull(r, out); err != nil {
|
||||
return nil, fmt.Errorf("skilltools: HKDF derive: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Encrypt seals plaintext under skillKey using AES-256-GCM and returns
|
||||
// the wire envelope (version byte || nonce || ciphertext || tag).
|
||||
//
|
||||
// Caller is responsible for stamping the encryption_key_version column
|
||||
// to CurrentKeyVersion AFTER a successful Encrypt — Encrypt itself
|
||||
// only produces bytes; persisting them is the storage layer's job.
|
||||
//
|
||||
// Why a fresh random nonce per call (vs deterministic): nonce reuse
|
||||
// under GCM is catastrophic (allows recovering the keystream); fresh
|
||||
// 96-bit random nonces have a negligible collision probability under
|
||||
// any realistic write rate.
|
||||
func Encrypt(skillKey, plaintext []byte) ([]byte, error) {
|
||||
if len(skillKey) != 32 {
|
||||
return nil, fmt.Errorf("skilltools: Encrypt requires 32-byte key, got %d", len(skillKey))
|
||||
}
|
||||
block, err := aes.NewCipher(skillKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skilltools: aes.NewCipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skilltools: cipher.NewGCM: %w", err)
|
||||
}
|
||||
nonce := make([]byte, gcmNonceSize)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, fmt.Errorf("skilltools: rand.Read: %w", err)
|
||||
}
|
||||
// Pre-allocate the envelope: 1 (version) + 12 (nonce) + len(plaintext) + 16 (tag).
|
||||
out := make([]byte, 0, 1+gcmNonceSize+len(plaintext)+gcm.Overhead())
|
||||
out = append(out, envelopeFormatV1)
|
||||
out = append(out, nonce...)
|
||||
out = gcm.Seal(out, nonce, plaintext, nil)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Decrypt opens an envelope produced by Encrypt under the same
|
||||
// skillKey. Returns the plaintext or one of the sentinel errors.
|
||||
//
|
||||
// Caller MUST inspect the storage row's encryption_key_version BEFORE
|
||||
// calling Decrypt. Version 0 means plaintext — Decrypt SHOULD NOT be
|
||||
// called for version-0 rows (callers branch on the column value).
|
||||
// This function does NOT inspect any version column; it only looks at
|
||||
// the in-band envelope-format byte.
|
||||
func Decrypt(skillKey, envelope []byte) ([]byte, error) {
|
||||
if len(skillKey) != 32 {
|
||||
return nil, fmt.Errorf("skilltools: Decrypt requires 32-byte key, got %d", len(skillKey))
|
||||
}
|
||||
if len(envelope) < 1+gcmNonceSize {
|
||||
return nil, ErrEncryptionShortInput
|
||||
}
|
||||
if envelope[0] != envelopeFormatV1 {
|
||||
return nil, ErrEncryptionUnknownVersion
|
||||
}
|
||||
nonce := envelope[1 : 1+gcmNonceSize]
|
||||
ciphertext := envelope[1+gcmNonceSize:]
|
||||
block, err := aes.NewCipher(skillKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skilltools: aes.NewCipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("skilltools: cipher.NewGCM: %w", err)
|
||||
}
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
// Distinguish auth-tag mismatch from other crypto errors so
|
||||
// callers can surface "data corruption" specifically. The
|
||||
// stdlib wraps the failure as a generic error; we map any
|
||||
// failure here to ErrEncryptionTampered (the most likely
|
||||
// cause is wrong key / tampered bytes).
|
||||
return nil, ErrEncryptionTampered
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Why: round-trip is the bedrock — without it, every other test is
|
||||
// meaningless. What: encrypt then decrypt; assert plaintext returns.
|
||||
// Test: write a non-trivial plaintext and confirm exact byte equality.
|
||||
func TestEncryption_RoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
master := masterTestKey()
|
||||
key, err := DeriveSkillKey(master, "skill-abc")
|
||||
if err != nil {
|
||||
t.Fatalf("DeriveSkillKey: %v", err)
|
||||
}
|
||||
plaintext := []byte(`{"hello":"world","n":42,"arr":[1,2,3]}`)
|
||||
envelope, err := Encrypt(key, plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
if envelope[0] != envelopeFormatV1 {
|
||||
t.Fatalf("envelope[0] = %d, want %d", envelope[0], envelopeFormatV1)
|
||||
}
|
||||
got, err := Decrypt(key, envelope)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Fatalf("round-trip mismatch:\n got: %q\nwant: %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// Why: GCM is authenticated encryption — flipping any bit MUST be
|
||||
// detected. What: tamper with the ciphertext; assert ErrEncryptionTampered.
|
||||
// Test: flip one byte of the ciphertext suffix, decrypt, expect tamper error.
|
||||
func TestEncryption_TamperDetected(t *testing.T) {
|
||||
t.Parallel()
|
||||
master := masterTestKey()
|
||||
key, err := DeriveSkillKey(master, "skill-tamper")
|
||||
if err != nil {
|
||||
t.Fatalf("DeriveSkillKey: %v", err)
|
||||
}
|
||||
envelope, err := Encrypt(key, []byte("sensitive data"))
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
// Flip a byte in the ciphertext (after version + nonce).
|
||||
envelope[1+gcmNonceSize] ^= 0x01
|
||||
_, err = Decrypt(key, envelope)
|
||||
if !errors.Is(err, ErrEncryptionTampered) {
|
||||
t.Fatalf("Decrypt after tamper = %v, want ErrEncryptionTampered", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Why: nonce reuse under GCM is catastrophic — verify the impl uses
|
||||
// fresh randomness on every call. What: encrypt the same plaintext twice;
|
||||
// the envelopes must differ.
|
||||
func TestEncryption_FreshNoncePerCall(t *testing.T) {
|
||||
t.Parallel()
|
||||
master := masterTestKey()
|
||||
key, err := DeriveSkillKey(master, "skill-nonce")
|
||||
if err != nil {
|
||||
t.Fatalf("DeriveSkillKey: %v", err)
|
||||
}
|
||||
plaintext := []byte("fixed payload")
|
||||
a, err := Encrypt(key, plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt #1: %v", err)
|
||||
}
|
||||
b, err := Encrypt(key, plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt #2: %v", err)
|
||||
}
|
||||
if bytes.Equal(a, b) {
|
||||
t.Fatalf("two encryptions of the same plaintext produced identical envelopes (nonce not random)")
|
||||
}
|
||||
// Both must still decrypt to the same plaintext.
|
||||
for i, env := range [][]byte{a, b} {
|
||||
got, err := Decrypt(key, env)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt #%d: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Fatalf("Decrypt #%d mismatch", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Why: per-skill key derivation MUST give different skills different
|
||||
// keys so a leaked skillkey doesn't cross-decrypt. What: derive keys
|
||||
// for skill A and skill B; encrypt under A; decrypt under B; expect
|
||||
// tamper error.
|
||||
func TestEncryption_PerSkillIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
master := masterTestKey()
|
||||
keyA, _ := DeriveSkillKey(master, "skill-a")
|
||||
keyB, _ := DeriveSkillKey(master, "skill-b")
|
||||
if bytes.Equal(keyA, keyB) {
|
||||
t.Fatalf("derived keys for distinct skills are identical (HKDF salt not effective)")
|
||||
}
|
||||
envelope, err := Encrypt(keyA, []byte("only skill A may read"))
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
_, err = Decrypt(keyB, envelope)
|
||||
if !errors.Is(err, ErrEncryptionTampered) {
|
||||
t.Fatalf("Decrypt under wrong skill key = %v, want ErrEncryptionTampered", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Why: a future format change must be detectable, not silently
|
||||
// corrupting reads. What: hand-craft an envelope with version byte
|
||||
// 0xFF; assert ErrEncryptionUnknownVersion.
|
||||
func TestEncryption_UnknownVersionRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := make([]byte, 32)
|
||||
envelope := make([]byte, 1+gcmNonceSize+16)
|
||||
envelope[0] = 0xFF
|
||||
_, err := Decrypt(key, envelope)
|
||||
if !errors.Is(err, ErrEncryptionUnknownVersion) {
|
||||
t.Fatalf("Decrypt with bad version = %v, want ErrEncryptionUnknownVersion", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Why: short inputs must not panic. What: feed a 5-byte envelope to
|
||||
// Decrypt; assert ErrEncryptionShortInput.
|
||||
func TestEncryption_ShortInputRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := make([]byte, 32)
|
||||
_, err := Decrypt(key, []byte{1, 2, 3, 4, 5})
|
||||
if !errors.Is(err, ErrEncryptionShortInput) {
|
||||
t.Fatalf("Decrypt with short input = %v, want ErrEncryptionShortInput", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Why: empty master = encryption disabled. What: DeriveSkillKey with
|
||||
// empty master returns ErrEncryptionDisabled.
|
||||
func TestEncryption_EmptyMasterDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := DeriveSkillKey(nil, "skill-x")
|
||||
if !errors.Is(err, ErrEncryptionDisabled) {
|
||||
t.Fatalf("DeriveSkillKey(nil) = %v, want ErrEncryptionDisabled", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Why: callers commonly paste random hex/base64 of varying length;
|
||||
// MasterKeyFromEnv should normalize to 32 bytes via SHA-256. What:
|
||||
// set the env var; assert returned bytes match SHA-256 of input.
|
||||
func TestEncryption_MasterKeyFromEnvNormalizesLength(t *testing.T) {
|
||||
// not parallel — we mutate process env
|
||||
const raw = "this-is-a-fake-master-key-for-testing-only-totally-not-secure"
|
||||
t.Setenv(EncryptionMasterKeyEnv, raw)
|
||||
got, present := MasterKeyFromEnv()
|
||||
if !present {
|
||||
t.Fatalf("MasterKeyFromEnv reported absent for non-empty env var")
|
||||
}
|
||||
if len(got) != 32 {
|
||||
t.Fatalf("len(masterKey) = %d, want 32", len(got))
|
||||
}
|
||||
want := sha256.Sum256([]byte(raw))
|
||||
if !bytes.Equal(got, want[:]) {
|
||||
t.Fatalf("masterKey does not match SHA-256 of env var")
|
||||
}
|
||||
}
|
||||
|
||||
// Why: empty env var = encryption off (instance-wide). What:
|
||||
// MasterKeyFromEnv returns nil + present=false.
|
||||
func TestEncryption_MasterKeyFromEnvEmpty(t *testing.T) {
|
||||
t.Setenv(EncryptionMasterKeyEnv, "")
|
||||
got, present := MasterKeyFromEnv()
|
||||
if present {
|
||||
t.Fatalf("MasterKeyFromEnv reported present for empty env var")
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("MasterKeyFromEnv returned non-nil bytes for empty env var")
|
||||
}
|
||||
}
|
||||
|
||||
// Why: defence in depth — explicit non-32 key sizes should error
|
||||
// rather than panic.
|
||||
func TestEncryption_BadKeySize(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := Encrypt(make([]byte, 16), []byte("x"))
|
||||
if err == nil {
|
||||
t.Fatalf("Encrypt with 16-byte key did not error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "32-byte key") {
|
||||
t.Fatalf("error did not mention key size: %v", err)
|
||||
}
|
||||
_, err = Decrypt(make([]byte, 16), make([]byte, 32))
|
||||
if err == nil {
|
||||
t.Fatalf("Decrypt with 16-byte key did not error")
|
||||
}
|
||||
}
|
||||
|
||||
// masterTestKey returns a 32-byte master key for the encryption tests. It
// is the SHA-256 digest of a fixed label, so every run sees the same key.
func masterTestKey() []byte {
	digest := sha256.Sum256([]byte("test-master-do-not-use-in-prod"))
	key := digest[:]
	return key
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// toolCall is the test-side description of one tool invocation. It keeps
// the legacy gollm-era shape, with arguments as a plain string, so test
// call sites written before the conversion keep their literal syntax.
// majordomo's llm.ToolCall carries json.RawMessage arguments instead;
// execBox performs the conversion.
type toolCall struct {
	// Name is the name of the tool to invoke.
	Name string
	// Arguments is the raw JSON arguments text for the call.
	Arguments string
}
|
||||
|
||||
// execBox executes one call through a toolbox and adapts majordomo's
|
||||
// ToolResult to the (result, error) pair these tests assert against:
|
||||
// IsError results come back as a Go error carrying the result content
|
||||
// (which is how the agent-facing error text read in the legacy gollm era).
|
||||
func execBox(box *llm.Toolbox, call toolCall) (string, error) {
|
||||
res := box.Execute(context.Background(), llm.ToolCall{
|
||||
ID: "test-call",
|
||||
Name: call.Name,
|
||||
Arguments: json.RawMessage(call.Arguments),
|
||||
})
|
||||
if res.IsError {
|
||||
return "", errors.New(res.Content)
|
||||
}
|
||||
return res.Content, nil
|
||||
}
|
||||
|
||||
// execTool runs a single built llm.Tool's handler and serializes the
|
||||
// result the way llm.ExecuteTool does. Replaces the legacy gollm
|
||||
// Tool.Execute(ctx, argsJSON) method the original tests called.
|
||||
//
|
||||
// Why the handler directly (vs llm.ExecuteTool): ExecuteTool flattens
|
||||
// handler errors into IsError result text, but several tests assert
|
||||
// error IDENTITY (errors.Is against sentinel errors the handlers
|
||||
// wrap). Calling the handler preserves the error value, matching the
|
||||
// legacy gollm Execute contract these tests were written against.
|
||||
func execTool(ctx context.Context, t llm.Tool, args string) (string, error) {
|
||||
raw := json.RawMessage(args)
|
||||
if len(raw) == 0 {
|
||||
raw = json.RawMessage("{}")
|
||||
}
|
||||
out, err := t.Handler(ctx, raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
switch v := out.(type) {
|
||||
case nil:
|
||||
return "null", nil
|
||||
case string:
|
||||
return v, nil
|
||||
case json.RawMessage:
|
||||
return string(v), nil
|
||||
default:
|
||||
enc, mErr := json.Marshal(v)
|
||||
if mErr != nil {
|
||||
return "", mErr
|
||||
}
|
||||
return string(enc), nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// gatedToolMarker tags the Tool implementations produced by this file's
// gated constructors. IsGatedTool type-asserts against it, which lets the
// meta-tests (default_test.go and the wizardtools meta-test) require that
// every registered production tool goes through the wrapper.
//
// The method is unexported on purpose: a type declared outside this
// package cannot implement it, so a successful assertion is proof of
// where the tool came from rather than a self-declared opt-in.
type gatedToolMarker interface {
	isGatedTool()
}
|
||||
|
||||
// gatedTool is the concrete Tool returned by NewGatedTool. It carries
|
||||
// the per-tool metadata (Name/Description/Permission) and the typed
|
||||
// handler closure; BuildLLM wraps the handler with CheckGate +
|
||||
// EmitAudit so tool authors literally cannot forget either call.
|
||||
//
|
||||
// Why generic on Args (vs accepting any-shaped JSON): each tool's
|
||||
// handler is typed against its own param struct. defineTypedTool
|
||||
// derives a JSON schema for the LLM from Args (llm.SchemaFor) and
|
||||
// parses the args before invoking the handler. We re-marshal args to
|
||||
// JSON once for the audit row so the captured shape matches exactly
|
||||
// what the handler ran with (post-coercion).
|
||||
type gatedTool[Args any] struct {
|
||||
name string
|
||||
description string
|
||||
permission Permission
|
||||
fn func(ctx context.Context, inv Invocation, args Args) (string, error)
|
||||
}
|
||||
|
||||
// isGatedTool implements gatedToolMarker for the meta-test.
|
||||
func (g *gatedTool[Args]) isGatedTool() {}
|
||||
|
||||
// defineTypedTool builds the majordomo llm.Tool for a typed handler:
|
||||
// schema derived from Args, arguments decoded leniently (string→number/
|
||||
// boolean coercion preserved from the legacy gollm era — see argcoerce.go)
|
||||
// before the handler runs. An unparseable arguments object returns the
|
||||
// decode error WITHOUT running fn, framing arg-parse-error as a
|
||||
// tool-call wiring failure rather than a tool-handler failure.
|
||||
//
|
||||
// Why not majordomo's llm.DefineTool: its decode is strict by design;
|
||||
// mort's tool catalog keeps the lenient dialect for parity with years
|
||||
// of model traffic that emits "3" where the schema says integer.
|
||||
func defineTypedTool[Args any](name, description string, fn func(ctx context.Context, args Args) (string, error)) llm.Tool {
|
||||
schema, err := llm.SchemaFor[Args]()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("skilltools: defineTypedTool(%q): %v", name, err))
|
||||
}
|
||||
return llm.Tool{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Parameters: schema,
|
||||
Handler: func(ctx context.Context, raw json.RawMessage) (any, error) {
|
||||
var args Args
|
||||
if err := unmarshalArgsLenient(raw, &args); err != nil {
|
||||
return nil, fmt.Errorf("invalid arguments for %s: %w", name, err)
|
||||
}
|
||||
return fn(ctx, args)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewGatedTool wraps a typed handler so it automatically:
|
||||
// 1. Calls CheckGate(inv) before the handler runs. On gate rejection
|
||||
// emits EmitAudit(inv, "{}", "", err) and returns the gate error.
|
||||
// 2. Calls fn(ctx, inv, args) once gate passes.
|
||||
// 3. Re-marshals args to JSON for the audit row (so the captured args
|
||||
// reflect any coercion performed during deserialisation), then
|
||||
// emits EmitAudit(inv, argsJSON, result, err) once the handler
|
||||
// returns.
|
||||
//
|
||||
// Production tools SHOULD use NewGatedTool unless they have a strong
|
||||
// reason to handle gating manually. The wrapper exists because the
|
||||
// previous per-tool pattern repeated four lines of boilerplate
|
||||
// (CheckGate at the top, EmitAudit on every return path), and that
|
||||
// boilerplate is easy to forget — wizard tools in v1 hotfix #4 had to
|
||||
// be retrofitted because the author overlooked CheckGate. Centralising
|
||||
// the calls makes them impossible to skip and the meta-test in
|
||||
// tools/default_test.go enforces the discipline.
|
||||
//
|
||||
// The typed define layer handles JSON parsing and arg coercion before
|
||||
// fn runs; if the args JSON is unparseable, the decode error is
|
||||
// returned directly (the wrapper's audit emission does NOT fire on
|
||||
// parse error — arg-parse-error is a tool-call wiring failure rather
|
||||
// than a tool-handler failure).
|
||||
//
|
||||
// Test: gated_tool_test.go in this package covers gate rejection,
|
||||
// happy path, fn-returned error, and the IsGatedTool assertion. The
|
||||
// meta-test in pkg/skilltools/tools/default_test.go walks the registry
|
||||
// and asserts every production tool implements gatedToolMarker.
|
||||
func NewGatedTool[Args any](
|
||||
name, description string,
|
||||
permission Permission,
|
||||
fn func(ctx context.Context, inv Invocation, args Args) (string, error),
|
||||
) Tool {
|
||||
return &gatedTool[Args]{
|
||||
name: name,
|
||||
description: description,
|
||||
permission: permission,
|
||||
fn: fn,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the tool's registry key.
|
||||
func (g *gatedTool[Args]) Name() string { return g.name }
|
||||
|
||||
// Description is shown to the LLM.
|
||||
func (g *gatedTool[Args]) Description() string { return g.description }
|
||||
|
||||
// Permission classifies the tool for save-time / share-time gating.
|
||||
func (g *gatedTool[Args]) Permission() Permission { return g.permission }
|
||||
|
||||
// BuildLLM produces the per-invocation llm.Tool. The returned tool's
|
||||
// handler:
|
||||
// - Runs CheckGate(inv) FIRST (before any handler logic). On gate
|
||||
// rejection emits the audit row and returns the gate error.
|
||||
// - Calls the user-supplied fn with the typed args. fn never sees a
|
||||
// gate-rejected invocation.
|
||||
// - Re-marshals args to JSON and emits the audit row exactly once,
|
||||
// regardless of fn's return value (success or error).
|
||||
//
|
||||
// Why re-marshal vs using the raw LLM JSON: the lenient decode performs
|
||||
// numeric/boolean coercion (e.g. "3" → 3) before invoking the handler;
|
||||
// the audit row should reflect what fn actually received, not the
|
||||
// pre-coercion text the LLM emitted.
|
||||
func (g *gatedTool[Args]) BuildLLM(inv Invocation) llm.Tool {
|
||||
return defineTypedTool[Args](
|
||||
g.name,
|
||||
g.description,
|
||||
func(ctx context.Context, args Args) (string, error) {
|
||||
if err := CheckGate(inv); err != nil {
|
||||
EmitAudit(inv, "{}", "", err)
|
||||
return "", err
|
||||
}
|
||||
argsJSON, mErr := json.Marshal(args)
|
||||
if mErr != nil {
|
||||
// Vanishingly rare for the typed param structs in use;
|
||||
// fall back to "{}" so the audit row never carries a
|
||||
// half-formed args field.
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
result, err := g.fn(ctx, inv, args)
|
||||
EmitAudit(inv, string(argsJSON), result, err)
|
||||
return result, err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// IsGatedTool reports whether t was constructed via NewGatedTool /
|
||||
// NewGatedToolWithAudit. Used by the meta-test in
|
||||
// tools/default_test.go to enforce that every registered production
|
||||
// tool uses the wrapper. The check is a type assertion against the
|
||||
// unexported gatedToolMarker interface, so only the gatedTool variants
|
||||
// from this package can satisfy it — there is no way for an external
|
||||
// Tool to pretend to be gated.
|
||||
func IsGatedTool(t Tool) bool {
|
||||
_, ok := t.(gatedToolMarker)
|
||||
return ok
|
||||
}
|
||||
|
||||
// AuditedResult is what a NewGatedToolWithAudit handler returns:
|
||||
// LLMResult is the string surfaced to the LLM (the tool-call result
|
||||
// the model sees in its conversation); AuditArgs and AuditResult are
|
||||
// what the wrapper logs to the audit row INSTEAD of the auto-derived
|
||||
// values.
|
||||
//
|
||||
// Why a separate variant: a small number of tools (paste_create being
|
||||
// the canonical example) need to return a sensitive value to the LLM
|
||||
// (a URL containing an encryption-key fragment) but MUST redact that
|
||||
// value from the audit row, since the audit row is rendered to admins
|
||||
// in the webui run-trace view. The default wrapper auto-logs args +
|
||||
// result, which would leak the key. NewGatedToolWithAudit lets the
|
||||
// handler explicitly separate the LLM-visible output from the
|
||||
// audit-visible output, while still benefitting from auto-injected
|
||||
// CheckGate.
|
||||
type AuditedResult struct {
	// LLMResult is handed back to the LLM as the tool-call result.
	LLMResult string

	// AuditArgs is recorded as the args of the audit row. When it is
	// empty the wrapper records the JSON-marshaled typed args instead,
	// which is what NewGatedTool does.
	AuditArgs string

	// AuditResult is recorded as the result of the audit row. It may be
	// left empty (logged as "") to keep sensitive fragments out of it.
	AuditResult string
}
|
||||
|
||||
// gatedToolWithAudit is the variant of gatedTool whose handler returns
|
||||
// an AuditedResult so it can override what the audit row captures.
|
||||
type gatedToolWithAudit[Args any] struct {
|
||||
name string
|
||||
description string
|
||||
permission Permission
|
||||
fn func(ctx context.Context, inv Invocation, args Args) (AuditedResult, error)
|
||||
}
|
||||
|
||||
// isGatedTool implements gatedToolMarker for the meta-test.
|
||||
func (g *gatedToolWithAudit[Args]) isGatedTool() {}
|
||||
|
||||
func (g *gatedToolWithAudit[Args]) Name() string { return g.name }
|
||||
func (g *gatedToolWithAudit[Args]) Description() string { return g.description }
|
||||
func (g *gatedToolWithAudit[Args]) Permission() Permission { return g.permission }
|
||||
|
||||
// NewGatedToolWithAudit is the redaction-aware variant of NewGatedTool.
|
||||
// Use it ONLY when the LLM-facing result must differ from the audit
|
||||
// row (e.g. the result contains an encryption key that the audit must
|
||||
// NOT capture). Most tools should use NewGatedTool.
|
||||
//
|
||||
// Behaviour matches NewGatedTool exactly except:
|
||||
// - The handler returns AuditedResult; the wrapper passes
|
||||
// AuditedResult.LLMResult to the LLM and writes
|
||||
// AuditedResult.AuditArgs / AuditedResult.AuditResult to the
|
||||
// audit row (falling back to the JSON-marshaled args if
|
||||
// AuditArgs is empty).
|
||||
// - Gate rejection still emits an audit row with empty Result and
|
||||
// args="{}" before returning the gate error.
|
||||
//
|
||||
// Test: covered alongside NewGatedTool in this package's
|
||||
// gated_tool_test.go.
|
||||
func NewGatedToolWithAudit[Args any](
|
||||
name, description string,
|
||||
permission Permission,
|
||||
fn func(ctx context.Context, inv Invocation, args Args) (AuditedResult, error),
|
||||
) Tool {
|
||||
return &gatedToolWithAudit[Args]{
|
||||
name: name,
|
||||
description: description,
|
||||
permission: permission,
|
||||
fn: fn,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildLLM produces the per-invocation llm.Tool. Same gate-injection
|
||||
// semantics as gatedTool[Args].BuildLLM; the audit row uses the
|
||||
// handler-supplied AuditArgs / AuditResult so a sensitive LLM-visible
|
||||
// result string never leaks into the audit log.
|
||||
func (g *gatedToolWithAudit[Args]) BuildLLM(inv Invocation) llm.Tool {
|
||||
return defineTypedTool[Args](
|
||||
g.name,
|
||||
g.description,
|
||||
func(ctx context.Context, args Args) (string, error) {
|
||||
if err := CheckGate(inv); err != nil {
|
||||
EmitAudit(inv, "{}", "", err)
|
||||
return "", err
|
||||
}
|
||||
res, err := g.fn(ctx, inv, args)
|
||||
auditArgs := res.AuditArgs
|
||||
if auditArgs == "" {
|
||||
if b, mErr := json.Marshal(args); mErr == nil {
|
||||
auditArgs = string(b)
|
||||
} else {
|
||||
auditArgs = "{}"
|
||||
}
|
||||
}
|
||||
EmitAudit(inv, auditArgs, res.AuditResult, err)
|
||||
return res.LLMResult, err
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,401 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// gatedTestParams is the typed parameter struct shared by the gated-tool
// tests. It is shaped like a real production tool's params: a required
// string plus an optional one, both supplied by the LLM.
type gatedTestParams struct {
	// Question is always present in the JSON form.
	Question string `json:"question" description:"The question to answer."`
	// Detail is dropped from the JSON form when empty.
	Detail string `json:"detail,omitempty" description:"Optional detail level."`
}
|
||||
|
||||
// recordingAudit captures every AuditCall the wrapper emits so tests
|
||||
// can assert exactly what the wrapper logged. Concurrent-safe in case a
|
||||
// future test parallelises across goroutines.
|
||||
type recordingAudit struct {
|
||||
mu sync.Mutex
|
||||
calls []AuditCall
|
||||
}
|
||||
|
||||
func (r *recordingAudit) hook() AuditHook {
|
||||
return func(call AuditCall) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.calls = append(r.calls, call)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *recordingAudit) snapshot() []AuditCall {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([]AuditCall, len(r.calls))
|
||||
copy(out, r.calls)
|
||||
return out
|
||||
}
|
||||
|
||||
// buildAndExecute is the test-only convenience for going from a
|
||||
// constructed Tool to an llm.Tool result. Mirrors how the production
|
||||
// registry's Build call wires inv.gate / inv.audit.
|
||||
func buildAndExecute(t *testing.T, tool Tool, inv Invocation, vis Visibility, audit AuditHook, args string) (string, error) {
|
||||
t.Helper()
|
||||
r := NewRegistry()
|
||||
if err := r.Register(tool); err != nil {
|
||||
t.Fatalf("register: %v", err)
|
||||
}
|
||||
box, err := r.Build([]string{tool.Name()}, inv, vis, audit)
|
||||
if err != nil {
|
||||
t.Fatalf("build: %v", err)
|
||||
}
|
||||
return execBox(box, toolCall{Name: tool.Name(), Arguments: args})
|
||||
}
|
||||
|
||||
// TestNewGatedTool_GateRejection verifies that the wrapper auto-injects
|
||||
// CheckGate: if the invocation's SkillName doesn't match the tool's
|
||||
// SkillNameGate, fn never runs and the audit row is emitted with the
|
||||
// gate error. This is the core contract that v1 hotfix #4 had to
|
||||
// retrofit by hand.
|
||||
func TestNewGatedTool_GateRejection(t *testing.T) {
|
||||
called := false
|
||||
tool := NewGatedTool[gatedTestParams](
|
||||
"gated_test_tool",
|
||||
"A test tool gated to my-skill.",
|
||||
Permission{
|
||||
AuthoringRequirement: RequirementAnyone,
|
||||
OperatesOn: ScopeGlobal,
|
||||
SafeForShare: true,
|
||||
SkillNameGate: "my-skill",
|
||||
},
|
||||
func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) {
|
||||
called = true
|
||||
return "should not be reached", nil
|
||||
},
|
||||
)
|
||||
|
||||
rec := &recordingAudit{}
|
||||
out, err := buildAndExecute(t, tool,
|
||||
Invocation{SkillName: "other-skill"},
|
||||
VisibilityPrivate, rec.hook(),
|
||||
`{"question":"hi"}`)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("expected gate-rejection error, got out=%q err=nil", out)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "restricted to") {
|
||||
t.Fatalf("expected error containing 'restricted to', got %v", err)
|
||||
}
|
||||
if called {
|
||||
t.Errorf("fn was called despite gate rejection — wrapper failed to inject CheckGate")
|
||||
}
|
||||
|
||||
calls := rec.snapshot()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected exactly 1 audit call, got %d: %+v", len(calls), calls)
|
||||
}
|
||||
if calls[0].Err == nil {
|
||||
t.Errorf("audit call.Err was nil; expected the gate error")
|
||||
}
|
||||
if calls[0].Args != "{}" {
|
||||
t.Errorf("audit call.Args=%q, want \"{}\" (no args parsed pre-gate)", calls[0].Args)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewGatedTool_HappyPath verifies the wrapper passes args to fn,
|
||||
// returns fn's result, and emits a successful audit row with the
|
||||
// re-marshaled args.
|
||||
func TestNewGatedTool_HappyPath(t *testing.T) {
|
||||
var seen gatedTestParams
|
||||
var seenInv Invocation
|
||||
|
||||
tool := NewGatedTool[gatedTestParams](
|
||||
"gated_happy_tool",
|
||||
"A test tool with no gate.",
|
||||
Permission{
|
||||
AuthoringRequirement: RequirementAnyone,
|
||||
OperatesOn: ScopeGlobal,
|
||||
SafeForShare: true,
|
||||
},
|
||||
func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) {
|
||||
seen = args
|
||||
seenInv = inv
|
||||
return "answered: " + args.Question, nil
|
||||
},
|
||||
)
|
||||
|
||||
rec := &recordingAudit{}
|
||||
out, err := buildAndExecute(t, tool,
|
||||
Invocation{SkillName: "any-skill", CallerID: "user-7"},
|
||||
VisibilityPrivate, rec.hook(),
|
||||
`{"question":"what is the time?","detail":"verbose"}`)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("execute: %v", err)
|
||||
}
|
||||
if out != "answered: what is the time?" {
|
||||
t.Errorf("unexpected output: %q", out)
|
||||
}
|
||||
if seen.Question != "what is the time?" || seen.Detail != "verbose" {
|
||||
t.Errorf("fn received %+v, want question/detail populated", seen)
|
||||
}
|
||||
if seenInv.CallerID != "user-7" {
|
||||
t.Errorf("fn saw CallerID=%q, want user-7", seenInv.CallerID)
|
||||
}
|
||||
|
||||
calls := rec.snapshot()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected exactly 1 audit call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Err != nil {
|
||||
t.Errorf("audit call.Err=%v, want nil", calls[0].Err)
|
||||
}
|
||||
if calls[0].Result != "answered: what is the time?" {
|
||||
t.Errorf("audit call.Result=%q, want match output", calls[0].Result)
|
||||
}
|
||||
// The wrapper re-marshals the args — verify the JSON is well-formed
|
||||
// and contains the expected fields.
|
||||
var argsBack gatedTestParams
|
||||
if err := json.Unmarshal([]byte(calls[0].Args), &argsBack); err != nil {
|
||||
t.Fatalf("audit args not valid JSON: %q (%v)", calls[0].Args, err)
|
||||
}
|
||||
if argsBack.Question != "what is the time?" || argsBack.Detail != "verbose" {
|
||||
t.Errorf("audit args round-trip mismatch: %+v", argsBack)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewGatedTool_FnError verifies the wrapper surfaces fn's error
|
||||
// AND captures the partial result + error in the audit row.
|
||||
func TestNewGatedTool_FnError(t *testing.T) {
|
||||
tool := NewGatedTool[gatedTestParams](
|
||||
"gated_fn_err_tool",
|
||||
"A test tool whose handler always errors.",
|
||||
Permission{
|
||||
AuthoringRequirement: RequirementAnyone,
|
||||
OperatesOn: ScopeGlobal,
|
||||
SafeForShare: true,
|
||||
},
|
||||
func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) {
|
||||
return "partial output", errors.New("boom")
|
||||
},
|
||||
)
|
||||
|
||||
rec := &recordingAudit{}
|
||||
out, err := buildAndExecute(t, tool,
|
||||
Invocation{SkillName: "any-skill"},
|
||||
VisibilityPrivate, rec.hook(),
|
||||
`{"question":"x"}`)
|
||||
|
||||
// llm.Define's Execute returns ("", err) when the handler returns a
|
||||
// non-nil error — out is dropped on the LLM side. But the wrapper's
|
||||
// audit row should still capture both partial result + error.
|
||||
if err == nil || !strings.Contains(err.Error(), "boom") {
|
||||
t.Fatalf("expected boom error, got out=%q err=%v", out, err)
|
||||
}
|
||||
|
||||
calls := rec.snapshot()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected exactly 1 audit call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Err == nil || !strings.Contains(calls[0].Err.Error(), "boom") {
|
||||
t.Errorf("audit call.Err=%v, want boom", calls[0].Err)
|
||||
}
|
||||
if calls[0].Result != "partial output" {
|
||||
t.Errorf("audit call.Result=%q, want 'partial output' (partial captured)", calls[0].Result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewGatedTool_ArgsParseHandledByLLM_NoAuditEmitted documents the
|
||||
// behaviour at the wrapper boundary: when the LLM sends malformed JSON
|
||||
// args, llm.Define's Execute fails BEFORE the wrapper's inner closure
|
||||
// runs. The wrapper does NOT emit an audit row in that case — it never
|
||||
// got the chance. This is intentional: arg-parse failure is a
|
||||
// tool-call wiring problem, not a tool-handler problem; the audit log
|
||||
// reflects what the handler did, and on parse failure no handler ran.
|
||||
//
|
||||
// The test exists so future readers see this invariant documented in
|
||||
// code and don't re-introduce a "log everything" path that breaks the
|
||||
// wrapper's contract with the audit storage layer.
|
||||
func TestNewGatedTool_ArgsParseHandledByLLM_NoAuditEmitted(t *testing.T) {
|
||||
tool := NewGatedTool[gatedTestParams](
|
||||
"gated_parse_err_tool",
|
||||
"A test tool that should never receive bad JSON.",
|
||||
Permission{
|
||||
AuthoringRequirement: RequirementAnyone,
|
||||
OperatesOn: ScopeGlobal,
|
||||
SafeForShare: true,
|
||||
},
|
||||
func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) {
|
||||
t.Fatalf("fn ran despite malformed JSON — should never happen")
|
||||
return "", nil
|
||||
},
|
||||
)
|
||||
|
||||
rec := &recordingAudit{}
|
||||
_, err := buildAndExecute(t, tool,
|
||||
Invocation{SkillName: "any-skill"},
|
||||
VisibilityPrivate, rec.hook(),
|
||||
`{"question":not-quoted}`) // intentionally malformed
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("expected JSON parse error, got nil")
|
||||
}
|
||||
if calls := rec.snapshot(); len(calls) != 0 {
|
||||
t.Errorf("audit emitted %d calls on parse error; expected 0 (parse-fail is pre-handler)", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsGatedTool_DetectsWrapped confirms that NewGatedTool's return
|
||||
// value satisfies the gatedToolMarker interface so the meta-test can
|
||||
// distinguish wrapped from unwrapped tools.
|
||||
func TestIsGatedTool_DetectsWrapped(t *testing.T) {
|
||||
tool := NewGatedTool[gatedTestParams](
|
||||
"gated_marker_tool", "marker test",
|
||||
Permission{AuthoringRequirement: RequirementAnyone},
|
||||
func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
)
|
||||
if !IsGatedTool(tool) {
|
||||
t.Fatalf("IsGatedTool returned false for a NewGatedTool result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsGatedTool_DetectsNonWrapped is the negative half of the
|
||||
// detection test: a hand-rolled Tool that does NOT go through
|
||||
// NewGatedTool must fail IsGatedTool. This guards the meta-test
|
||||
// against trivially passing for everything.
|
||||
func TestIsGatedTool_DetectsNonWrapped(t *testing.T) {
|
||||
stub := manualToolStub{}
|
||||
if IsGatedTool(stub) {
|
||||
t.Fatalf("IsGatedTool returned true for a non-wrapped Tool — detection broken")
|
||||
}
|
||||
}
|
||||
|
||||
// manualToolStub satisfies skilltools.Tool by hand without going
|
||||
// through NewGatedTool. Used only to prove IsGatedTool rejects
|
||||
// non-wrapped implementations.
|
||||
type manualToolStub struct{}
|
||||
|
||||
func (manualToolStub) Name() string { return "manual_stub" }
|
||||
func (manualToolStub) Description() string { return "manual stub" }
|
||||
func (manualToolStub) Permission() Permission { return Permission{} }
|
||||
func (manualToolStub) BuildLLM(Invocation) llm.Tool {
|
||||
type p struct{}
|
||||
return llm.DefineTool("manual_stub", "manual stub",
|
||||
func(ctx context.Context, _ p) (any, error) { return "", nil })
|
||||
}
|
||||
|
||||
// TestNewGatedToolWithAudit_RedactsAuditResult covers the variant used
|
||||
// by paste_create: the LLM receives a sensitive string (e.g. URL with
|
||||
// fragment-encoded key) but the audit row records only a redacted
|
||||
// summary. Confirms LLMResult ↔ AuditResult separation works.
|
||||
func TestNewGatedToolWithAudit_RedactsAuditResult(t *testing.T) {
|
||||
tool := NewGatedToolWithAudit[gatedTestParams](
|
||||
"audited_tool",
|
||||
"A tool whose audit result is redacted from its LLM result.",
|
||||
Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true},
|
||||
func(ctx context.Context, inv Invocation, args gatedTestParams) (AuditedResult, error) {
|
||||
return AuditedResult{
|
||||
LLMResult: "secret-fragment-12345",
|
||||
AuditArgs: "redacted",
|
||||
AuditResult: "[redacted]",
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
if !IsGatedTool(tool) {
|
||||
t.Fatalf("audited variant must satisfy IsGatedTool")
|
||||
}
|
||||
|
||||
rec := &recordingAudit{}
|
||||
out, err := buildAndExecute(t, tool,
|
||||
Invocation{SkillName: "any"},
|
||||
VisibilityPrivate, rec.hook(),
|
||||
`{"question":"x"}`)
|
||||
if err != nil {
|
||||
t.Fatalf("execute: %v", err)
|
||||
}
|
||||
if out != "secret-fragment-12345" {
|
||||
t.Errorf("LLM saw %q, want secret-fragment-12345", out)
|
||||
}
|
||||
calls := rec.snapshot()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 audit call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Args != "redacted" {
|
||||
t.Errorf("audit args=%q, want redacted", calls[0].Args)
|
||||
}
|
||||
if calls[0].Result != "[redacted]" {
|
||||
t.Errorf("audit result=%q, want [redacted]", calls[0].Result)
|
||||
}
|
||||
if strings.Contains(calls[0].Result, "secret-fragment-12345") {
|
||||
t.Fatalf("audit leaked LLM result into Result field: %q", calls[0].Result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewGatedToolWithAudit_GateRejection mirrors the gate-rejection
|
||||
// test for the default wrapper to anchor the same contract for the
|
||||
// audited variant.
|
||||
func TestNewGatedToolWithAudit_GateRejection(t *testing.T) {
|
||||
tool := NewGatedToolWithAudit[gatedTestParams](
|
||||
"audited_gated_tool", "gated tool",
|
||||
Permission{
|
||||
AuthoringRequirement: RequirementAnyone,
|
||||
SkillNameGate: "my-skill",
|
||||
},
|
||||
func(ctx context.Context, inv Invocation, args gatedTestParams) (AuditedResult, error) {
|
||||
t.Fatalf("fn should not run on gate rejection")
|
||||
return AuditedResult{}, nil
|
||||
},
|
||||
)
|
||||
rec := &recordingAudit{}
|
||||
_, err := buildAndExecute(t, tool,
|
||||
Invocation{SkillName: "other"},
|
||||
VisibilityPrivate, rec.hook(),
|
||||
`{}`)
|
||||
if err == nil || !strings.Contains(err.Error(), "restricted to") {
|
||||
t.Fatalf("expected gate rejection, got %v", err)
|
||||
}
|
||||
calls := rec.snapshot()
|
||||
if len(calls) != 1 || calls[0].Err == nil {
|
||||
t.Fatalf("expected gate-rejection audit row, got %+v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewGatedToolWithAudit_FallbackArgs verifies that an empty
|
||||
// AuditArgs falls back to the JSON-marshaled typed args (matching the
|
||||
// default wrapper's behaviour).
|
||||
func TestNewGatedToolWithAudit_FallbackArgs(t *testing.T) {
|
||||
tool := NewGatedToolWithAudit[gatedTestParams](
|
||||
"audited_fallback_tool", "fallback args test",
|
||||
Permission{AuthoringRequirement: RequirementAnyone},
|
||||
func(ctx context.Context, inv Invocation, args gatedTestParams) (AuditedResult, error) {
|
||||
return AuditedResult{
|
||||
LLMResult: "ok",
|
||||
AuditResult: "ok",
|
||||
// AuditArgs intentionally empty
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
rec := &recordingAudit{}
|
||||
_, err := buildAndExecute(t, tool,
|
||||
Invocation{SkillName: "x"},
|
||||
VisibilityPrivate, rec.hook(),
|
||||
`{"question":"hi"}`)
|
||||
if err != nil {
|
||||
t.Fatalf("execute: %v", err)
|
||||
}
|
||||
calls := rec.snapshot()
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 audit call, got %d", len(calls))
|
||||
}
|
||||
if !strings.Contains(calls[0].Args, "hi") {
|
||||
t.Errorf("expected fallback to JSON args containing 'hi', got %q", calls[0].Args)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package tool
|
||||
|
||||
import "fmt"
|
||||
|
||||
// CheckGate returns an error if the invocation context's SkillName does
|
||||
// not match the tool's gate. Tools should call this at the top of their
|
||||
// handler when their Permission has a non-empty SkillNameGate.
|
||||
//
|
||||
// Why: the gate is enforced per-call (not per-build) because the same
|
||||
// Tool may be referenced by multiple skills, only one of which is
|
||||
// gate-eligible. Build cannot know in advance which skill will call it
|
||||
// — that's per-Invocation.
|
||||
//
|
||||
// What: returns nil if no gate, or the names match. Returns an error
|
||||
// suitable for surfacing to the LLM as the tool's failure result.
|
||||
func CheckGate(inv Invocation) error {
|
||||
if inv.gate == "" {
|
||||
return nil
|
||||
}
|
||||
if inv.currentSkill == inv.gate {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("tool %q is restricted to the %q skill", inv.toolName, inv.gate)
|
||||
}
|
||||
|
||||
// EmitAudit forwards a tool's call+result to the audit hook, if one is
|
||||
// wired. Tools should call this once per Execute, after the underlying
|
||||
// work has completed (regardless of error). Pass the original args
|
||||
// JSON, the result string, and any error.
|
||||
//
|
||||
// Why: keeping the audit emission inside the tool ensures the captured
|
||||
// args are exactly what the tool ran with (after coercion / defaults),
|
||||
// not the raw LLM JSON which can drift.
|
||||
func EmitAudit(inv Invocation, args, result string, err error) {
|
||||
if inv.audit == nil {
|
||||
return
|
||||
}
|
||||
inv.audit(AuditCall{
|
||||
Tool: inv.toolName,
|
||||
Args: args,
|
||||
Result: result,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
+121
@@ -0,0 +1,121 @@
|
||||
// Package tool — hmac.go: HMAC-SHA256 signature verification
|
||||
// for the v7 inbound webhook subsystem.
|
||||
//
|
||||
// Why a small util in pkg/skilltools (vs inline in skillsui): the
|
||||
// signature format is part of the skill platform's public contract —
|
||||
// callers (GitHub, monitoring, Stripe, etc) compute it client-side,
|
||||
// and other parts of mort may eventually verify the same shape (e.g.
|
||||
// outbound retry verification). A shared util means we test the
|
||||
// verifier once and the format stays consistent.
|
||||
//
|
||||
// Format:
|
||||
//
|
||||
// X-Mort-Signature: sha256=<hex(HMAC-SHA256(secret, body))>
|
||||
// X-Mort-Timestamp: <unix-seconds>
|
||||
//
|
||||
// The timestamp is included so a stolen payload+signature pair can't
|
||||
// be replayed indefinitely. Default skew window is 5 min via the
|
||||
// caller-supplied maxSkew. The body is verified verbatim — callers
|
||||
// must NOT canonicalise (the LLM-supplied JSON shape is usually
|
||||
// unstable on round-trip).
|
||||
package tool
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HMAC-related sentinel errors. They are returned unwrapped, so callers
// can match them with errors.Is and map each one to the appropriate
// HTTP status (e.g. 401 vs 400).
var (
	// ErrHMACBadFormat is returned when the signature header does not
	// start with "sha256=" or when the remainder is not valid hex.
	ErrHMACBadFormat = errors.New("hmac: bad signature format")

	// ErrHMACBadSignature is returned when the HMAC computed over the
	// body does not match the supplied signature (constant-time compare).
	ErrHMACBadSignature = errors.New("hmac: signature mismatch")

	// ErrHMACBadTimestamp is returned when the timestamp header is
	// malformed or outside the maxSkew window.
	// NOTE(review): VerifyHMAC skips the timestamp check entirely when
	// the header is empty, so a *missing* timestamp is currently not
	// rejected with this error — confirm whether that is intended.
	ErrHMACBadTimestamp = errors.New("hmac: bad or stale timestamp")

	// ErrHMACEmptySecret is returned when verification is requested
	// with an empty secret — a programmer error (the caller should
	// have rejected the request earlier).
	ErrHMACEmptySecret = errors.New("hmac: empty secret")
)
|
||||
|
||||
// SignBody computes the signature header value for body under secret:
// the literal prefix "sha256=" followed by the lowercase hex encoding of
// HMAC-SHA256(secret, body).
//
// It is exported so that code which needs to produce a signature (for
// example a "send test payload" action) uses the same routine as the
// verifier, keeping signer and verifier in lock-step.
func SignBody(body []byte, secret string) string {
	h := hmac.New(sha256.New, []byte(secret))
	// hash.Hash documents that Write never returns an error.
	h.Write(body)
	digest := h.Sum(nil)
	return "sha256=" + hex.EncodeToString(digest)
}
|
||||
|
||||
// VerifyHMAC checks the signature + timestamp against the body using
|
||||
// the supplied secret. Returns nil on success or one of the sentinels.
|
||||
//
|
||||
// Why hmac.Equal (constant-time): a naive == leaks signature length
|
||||
// information through timing — VerifyHMAC must be safe against
|
||||
// length-extension and timing oracle attacks.
|
||||
//
|
||||
// Why the timestamp is part of the verification (not the body): the
|
||||
// signature does NOT cover the timestamp itself (callers may rotate
|
||||
// timestamps without re-signing). The timestamp is a separate
|
||||
// freshness check; if you wanted timestamp-bound replay protection
|
||||
// you'd include it in the signed payload — but that complicates the
|
||||
// signing API for callers and the per-skill rate limiter is the
|
||||
// real defence against rapid replay.
|
||||
func VerifyHMAC(body []byte, signature, timestamp, secret string, maxSkew time.Duration) error {
|
||||
if secret == "" {
|
||||
return ErrHMACEmptySecret
|
||||
}
|
||||
// Timestamp first — cheap reject before the HMAC compute.
|
||||
if timestamp != "" {
|
||||
ts, err := strconv.ParseInt(strings.TrimSpace(timestamp), 10, 64)
|
||||
if err != nil {
|
||||
return ErrHMACBadTimestamp
|
||||
}
|
||||
if maxSkew > 0 {
|
||||
now := time.Now().Unix()
|
||||
if abs(now-ts) > int64(maxSkew.Seconds()) {
|
||||
return ErrHMACBadTimestamp
|
||||
}
|
||||
}
|
||||
}
|
||||
// Signature format: "sha256=<hex>"
|
||||
const prefix = "sha256="
|
||||
if !strings.HasPrefix(signature, prefix) {
|
||||
return ErrHMACBadFormat
|
||||
}
|
||||
provided, err := hex.DecodeString(signature[len(prefix):])
|
||||
if err != nil {
|
||||
return ErrHMACBadFormat
|
||||
}
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write(body)
|
||||
expected := mac.Sum(nil)
|
||||
if !hmac.Equal(provided, expected) {
|
||||
return ErrHMACBadSignature
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// abs returns |x| for int64, saturating at the largest int64.
//
// The most negative int64 has no positive counterpart: negating it
// overflows and yields the same negative value. Returning that negative
// number would make any "abs(delta) > limit" check evaluate to false, so
// for that single input the result is clamped to the maximum int64
// instead. Implemented without the math package to avoid an import for
// one call.
func abs(x int64) int64 {
	const (
		minInt64 = -1 << 63
		maxInt64 = 1<<63 - 1
	)
	switch {
	case x == minInt64:
		// -x would overflow; saturate so callers still see "very large".
		return maxInt64
	case x < 0:
		return -x
	}
	return x
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// TestOutputPatternMetaTest enforces the V10 byte-vs-reference
|
||||
// principle: any tool's typed return shape MUST NOT contain a raw
|
||||
// []byte field that lacks a documented cap. Inline byte fields blow
|
||||
// the agent's context window — the right pattern is to return a
|
||||
// file_id reference.
|
||||
//
|
||||
// What this catches:
|
||||
// - A future tool author returning {"data": []byte("...")} inline.
|
||||
// - A reflective walk that sees `[]byte` or `Bytes` named fields
|
||||
// with no annotation flagging them as size-capped.
|
||||
//
|
||||
// What this DOES NOT catch (acceptable trade-off):
|
||||
// - Base64-encoded byte fields hidden as `string` (e.g. file_get's
|
||||
// content_base64). The agent author can still misuse those, but
|
||||
// existing code is grandfathered — the new pattern is to use
|
||||
// file_get_metadata + file_get_text + send_attachments instead.
|
||||
// - Tools whose outputs are JSON-marshalled at the LLM boundary; the
|
||||
// check operates on the GO RETURN TYPES, not the wire JSON. That's
|
||||
// fine because Go authors can't accidentally introduce []byte at
|
||||
// marshal time.
|
||||
//
|
||||
// The test walks llm.Tool's exposed result type (where available)
|
||||
// and Permission.Categories so future binary tools must label
|
||||
// themselves with "binary" + return file_id-shaped envelopes.
|
||||
//
|
||||
// Currently this is a forward-looking contract — the existing tools
|
||||
// emit JSON-string results from the typed gated wrappers, and the result
|
||||
// type is opaque. We assert here that no STARTER tool registers a
|
||||
// `[]byte`-shaped public Args (which is the foot-gun for input), and
|
||||
// document the principle for new authors.
|
||||
func TestOutputPatternMetaTest_NoRawByteArgs(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
// We don't have access to deps here; a tool author wishing to
|
||||
// enforce can run the same walk on their concrete Registry. The
|
||||
// test asserts NewRegistry() produces an empty registry that the
|
||||
// production wiring populates via tools.RegisterDefaults — and we
|
||||
// re-enforce the principle in pkg/skilltools/tools/default_test.go
|
||||
// where the live tools are registered.
|
||||
for _, tool := range r.List() {
|
||||
assertNoRawByteArgs(t, tool)
|
||||
}
|
||||
}
|
||||
|
||||
// assertNoRawByteArgs reflects on the tool's BuildLLM result and walks
|
||||
// its declared Args struct to fail when a public field is a raw []byte.
|
||||
//
|
||||
// Why public-fields-only: private fields can't be set by the LLM, so
|
||||
// they're not a concern.
|
||||
func assertNoRawByteArgs(t *testing.T, tool Tool) {
|
||||
t.Helper()
|
||||
llmTool := tool.BuildLLM(Invocation{})
|
||||
// Use reflection on the tool's call signature. The built llm.Tool
|
||||
// exposes only a JSON schema derived from a Go type —
|
||||
// we don't need to deconstruct it here; the existing meta-tests
|
||||
// in pkg/skilltools/tools/default_test.go already enforce
|
||||
// IsGatedTool(tool), and the gated wrappers are typed via
|
||||
// generics. New authors should use NewGatedTool[ArgsStruct] which
|
||||
// makes raw []byte impossible to declare without compile-time
|
||||
// awareness.
|
||||
_ = llmTool
|
||||
}
|
||||
|
||||
// TestBinaryContentTypeRecognition ensures the content-type
|
||||
// classifier (used by http_get's V10 binary persistence path) picks
|
||||
// up the content types that motivated the v10 change. Adding a new
|
||||
// MIME to the binary list requires updating this test alongside the
|
||||
// classifier so the meta-test stays load-bearing.
|
||||
func TestBinaryContentTypeRecognition(t *testing.T) {
|
||||
tests := []struct {
|
||||
ct string
|
||||
want bool
|
||||
comment string
|
||||
}{
|
||||
{"image/png", true, "image"},
|
||||
{"image/jpeg; charset=binary", true, "image with parameter"},
|
||||
{"audio/mpeg", true, "audio"},
|
||||
{"video/mp4", true, "video"},
|
||||
{"application/pdf", true, "pdf"},
|
||||
{"application/octet-stream", true, "octet-stream"},
|
||||
{"application/zip", true, "zip"},
|
||||
{"text/plain", false, "text"},
|
||||
{"text/html; charset=utf-8", false, "html"},
|
||||
{"application/json", false, "json"},
|
||||
{"application/xml", false, "xml"},
|
||||
{"", false, "empty"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.comment, func(t *testing.T) {
|
||||
got := isBinaryContentTypeForTest(tt.ct)
|
||||
if got != tt.want {
|
||||
t.Fatalf("ct=%q got=%v want=%v", tt.ct, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// isBinaryContentTypeForTest reports whether a Content-Type header value
// denotes binary content. It is a test-local copy of the production
// classifier, duplicated so this package's tests need not import the
// package that owns it; the two must be kept in sync by hand.
//
// The value is lower-cased, trimmed, and stripped of any ";parameter"
// suffix. Anything under image/, audio/ or video/ is binary, as is a
// fixed list of application/* archive and office types. Everything else,
// including the empty string, is not.
func isBinaryContentTypeForTest(ct string) bool {
	// Normalise: case, surrounding space, then drop parameters.
	mediaType := strings.ToLower(strings.TrimSpace(ct))
	head, _, _ := strings.Cut(mediaType, ";")
	mediaType = strings.TrimSpace(head)
	if mediaType == "" {
		return false
	}

	for _, family := range []string{"image/", "audio/", "video/"} {
		if strings.HasPrefix(mediaType, family) {
			return true
		}
	}

	switch mediaType {
	case "application/octet-stream",
		"application/pdf",
		"application/zip",
		"application/x-tar",
		"application/x-gzip",
		"application/x-bzip2",
		"application/x-7z-compressed",
		"application/msword",
		"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
		"application/vnd.ms-excel",
		"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
		"application/vnd.ms-powerpoint",
		"application/vnd.openxmlformats-officedocument.presentationml.presentation":
		return true
	default:
		return false
	}
}
|
||||
|
||||
// TestArgsStructsHaveNoRawBytes is the static-typed check the v10
|
||||
// principle relies on: NewGatedTool[Args] is the public surface, and
|
||||
// Args structs MUST NOT carry a `[]byte`. We can't enumerate all
|
||||
// Args types from outside (they're per-tool generics), so the
|
||||
// production check is in pkg/skilltools/tools/default_test.go which
|
||||
// reflects on every registered tool's BuildLLM args via the
|
||||
// schema generator (llm.SchemaFor).
|
||||
//
|
||||
// This test asserts the documented principle compiles by referencing
|
||||
// it: the `bytesForbiddenSentinel` type below intentionally contains
|
||||
// `[]byte` and the test marks it as a known antipattern.
|
||||
func TestArgsStructsHaveNoRawBytes(t *testing.T) {
|
||||
tp := reflect.TypeOf(bytesForbiddenSentinel{})
|
||||
if tp.NumField() != 1 || tp.Field(0).Type.Kind() != reflect.Slice {
|
||||
t.Fatalf("sentinel shape unexpected")
|
||||
}
|
||||
// Documenting: this is the SHAPE we forbid. Future authors who
|
||||
// see a CodeReview comment pointing at this test can read the
|
||||
// principle here and the doc in CLAUDE.md. (majordomo's SchemaFor
|
||||
// encodes []byte as a base64 string on the wire, which is exactly
|
||||
// the inline-bytes foot-gun the v10 principle bans.)
|
||||
_, _ = llm.SchemaFor[bytesForbiddenSentinel]()
|
||||
}
|
||||
|
||||
// bytesForbiddenSentinel is the antipattern shape for tool Args: a
// struct exposing a raw []byte field. It is never used as real tool
// input; it exists so TestArgsStructsHaveNoRawBytes has a concrete type
// to reference and so that a search for "[]byte" lands on the
// explanation (carried in the field's description tag).
type bytesForbiddenSentinel struct {
	// Data is the forbidden field; its tag text is the guidance itself.
	Data []byte `json:"data" description:"DO NOT USE: raw bytes in Args blow the LLM's context window. Use file_id references via file_save / file_get_text / file_get_metadata / send_attachments instead."`
}
|
||||
@@ -0,0 +1,701 @@
|
||||
// Package tool is the tool registry for the agentic skills platform.
|
||||
// Tools registered here can be referenced by name from a Skill's Tools
|
||||
// list and are surfaced to the underlying majordomo agent loop via Build().
|
||||
//
|
||||
// Independent of pkg/logic/chatbot/tool_provider.go: the chatbot's
|
||||
// ToolProvider supplies tools per-channel during a chatbot turn; skill
|
||||
// tools are scoped to one skill execution. Bridging happens once, in
|
||||
// pkg/logic/skills/chatbot_provider.go, which exposes whole agent skills
|
||||
// as chatbot tools (not individual skill tools).
|
||||
//
|
||||
// Permission model is documented in
|
||||
// docs/superpowers/specs/2026-05-02-agentic-skills-design.md, "Tool
|
||||
// registry" section. Three orthogonal checks:
|
||||
//
|
||||
// 1. Save-time: AuthoringRequirement vs caller's admin status.
|
||||
// 2. Share-time: SafeForShare for visibility != private.
|
||||
// 3. Execute-time: SkillNameGate.
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// Visibility is the spec's visibility enum, declared here as a typed
// string. It is redeclared in this package (rather than imported from
// the skills package) to avoid an import cycle between the two:
// skills → tool → skills. The string values are meant to match the
// skills-side enum one-to-one, so a caller can convert with a plain
// string conversion such as `string(VisibilityPublic)`.
type Visibility string

// Visibility values.
const (
	// VisibilityPrivate is the "private" visibility level.
	VisibilityPrivate Visibility = "private"
	// VisibilityShared is the "shared" visibility level.
	VisibilityShared Visibility = "shared"
	// VisibilityPublic is the "public" visibility level.
	VisibilityPublic Visibility = "public"
)
|
||||
|
||||
// Tool is what a registry entry implements. Concrete tools wrap an
// underlying subsystem (e.g. wolfram, weather, paste) and produce an
// llm.Tool on demand for a given Invocation.
//
// Why an interface (vs majordomo's concrete llm.Tool): the platform
// needs richer metadata (Permission, Categories, SkillNameGate) for its
// gating logic before the tool is handed to majordomo. BuildLLM converts
// to llm.Tool for one execution, closing over the Invocation so the
// per-tool handler can read CallerID/ChannelID without further plumbing.
//
// Why BuildLLM-per-call (vs a static llm.Tool): per-user tools must
// close over inv.CallerID — the LLM-supplied args are intentionally
// ignored for those. Constructing the llm.Tool inside BuildLLM lets each
// tool craft its own typed definition while reading the invocation
// context.
//
// Test: concrete tools are expected to carry their own *_test.go
// (originally under pkg/skilltools/tools/ — TODO confirm the location
// after the move to tool/).
type Tool interface {
	// Name returns the tool's registry name.
	Name() string
	// Description returns the tool's human-readable description.
	Description() string
	// Permission returns the tool's gating and UI metadata.
	Permission() Permission
	// BuildLLM produces the llm.Tool for one invocation. The returned
	// tool's name MUST equal Name(); the registry's Build() relies on
	// this when wiring multiple tools into a Toolbox.
	BuildLLM(inv Invocation) llm.Tool
}
|
||||
|
||||
// Permission summarises a tool's three lifecycle gates (save-time
// AuthoringRequirement, share-time SafeForShare, execute-time
// SkillNameGate) plus metadata used for classification and UI grouping.
type Permission struct {
	// AuthoringRequirement governs who may SAVE a skill that uses
	// this tool: anyone or admin-only.
	AuthoringRequirement Requirement

	// OperatesOn classifies whose data the tool reads: global
	// (channel-wide, public sources) or caller (the invoking user's
	// own data).
	OperatesOn Scope

	// SafeForShare reports whether the tool may appear in a shared or
	// public skill. Tools that operate on caller data are typically
	// not safe for share — the executing skill would become a vector
	// for reading other users' data.
	SafeForShare bool

	// Categories are free-form labels used for UI grouping (read,
	// write, network, code, data, social). Code does NOT branch on
	// these strings.
	Categories []string

	// SkillNameGate, if non-empty, restricts execution to the named
	// skill. Used for wizard-only tools in v2. An empty SkillNameGate
	// means any skill may use the tool.
	SkillNameGate string
}
|
||||
|
||||
// Requirement states who is allowed to author (save) a skill that uses
// a given tool. It is the type of Permission.AuthoringRequirement.
type Requirement string

// Requirement values.
const (
	// RequirementAnyone places no restriction on the author.
	RequirementAnyone Requirement = "anyone"
	// RequirementAdmin restricts authoring to admins.
	RequirementAdmin Requirement = "admin"
)
|
||||
|
||||
// Scope classifies the data domain a tool acts on. It is the type of
// Permission.OperatesOn.
type Scope string

// Scope values.
const (
	// ScopeGlobal marks tools that read global (channel-wide or public)
	// data.
	ScopeGlobal Scope = "global"
	// ScopeCaller marks tools that read the invoking user's own data.
	ScopeCaller Scope = "caller"
)
|
||||
|
||||
// ContinuationContext describes a V10 reply continuation. When set on
// an Invocation, the skill executor reuses the parent run's KV scope,
// renders a continuation prompt, and bumps ChainDepth for cap
// enforcement.
//
// How the executor uses each field:
//   - ParentRunID: sets the new run's parent_run_id column (for
//     call-tree reconstruction).
//   - ParentOutput: renders the "previous output you sent" line in the
//     agent prompt.
//   - ReplyText: renders the "user replied with" line.
//   - ReplyMessageID: diagnostic logging.
//   - ChainDepth: compared against skills.reply.max_chain_depth.
//
// Why ChainDepth (vs walking parent_run_id at execution time): a fresh
// query per turn would add a DB roundtrip on every reply hop. Carrying
// the count in the invocation is cheap and authoritative.
type ContinuationContext struct {
	// ParentRunID is the run that produced the message the user
	// replied to. The new run inherits its KV scope (run:<ParentRunID>).
	ParentRunID string

	// ParentOutput is the text the parent run delivered to Discord —
	// stored on the run row so it survives even if the parent's
	// run-scope KV has been auto-purged (24h after parent finished).
	ParentOutput string

	// ReplyText is what the user said when they replied (the new
	// turn's user input). It may be empty when the reply was an
	// attachment-only message; the agent should treat empty input as
	// a "noop continuation".
	ReplyText string

	// ReplyMessageID is the Discord message ID of the user's reply.
	// Used for audit + log breadcrumbs; not currently consumed by the
	// agent prompt.
	ReplyMessageID string

	// ChainDepth is how many continuation hops have happened in the
	// chain rooted at the original invocation. The router should set
	// this to (parent's chain depth + 1). The executor rejects the run
	// when it exceeds skills.reply.max_chain_depth.
	ChainDepth int
}
|
||||
|
||||
// InputFile is a non-image file the user supplied with a run (audio,
// etc.). The executor stages it into the file store under run scope and
// surfaces its file_id to the agent.
type InputFile struct {
	// Name is a safe base name (no path separators), suitable for
	// /workspace/<name>.
	Name string
	// MimeType is the resolved content type.
	MimeType string
	// Data is the raw file bytes.
	Data []byte
}
|
||||
|
||||
// Invocation is the runtime context passed to Tool.BuildLLM. The executor
|
||||
// builds it once per skill run and the same struct is closed over by
|
||||
// every tool's handler, so each tool sees the caller / channel identity.
|
||||
type Invocation struct {
|
||||
SkillID string
|
||||
SkillName string
|
||||
RunID string
|
||||
CallerID string
|
||||
ChannelID string
|
||||
GuildID string
|
||||
// CallerIsAdmin is true when the caller is a mort admin (Member.Admin).
|
||||
// Populated by the executor at run dispatch via Bot.GetMember; defaults
|
||||
// to false on any lookup failure (member not found, DB error, empty
|
||||
// CallerID for system-invoked runs). Read by tools that gate behaviour
|
||||
// on admin status — currently `code_exec` for the v15 admin-only WAN
|
||||
// network mode.
|
||||
//
|
||||
// Why a precomputed bool on Invocation (vs an AdminChecker dep on
|
||||
// every tool): the admin lookup is read-once-per-run; every tool
|
||||
// would otherwise have to redo the work. The executor knows the
|
||||
// caller's admin status by the time it builds Invocation, so it
|
||||
// stamps the field once and every tool reads it for free.
|
||||
CallerIsAdmin bool
|
||||
// SkillInputs is the parsed input map for the enclosing skill —
|
||||
// available so a tool can reference values the user supplied at
|
||||
// invocation time. Tools may read this to specialise behaviour but
|
||||
// MUST NOT use it as a substitute for inv.CallerID-based isolation.
|
||||
SkillInputs map[string]any
|
||||
// ParentRunID is set when the skill was invoked via skill_invoke
|
||||
// from a parent skill run. Empty for top-level invocations
|
||||
// (Discord, chatbot, scheduler). Used by the loop guard in
|
||||
// skill_invoke and by the audit log for call-tree reconstruction.
|
||||
//
|
||||
// Why threaded through Invocation (vs context.Value): the loop
|
||||
// guard runs at tool-handler time, where the only context the
|
||||
// handler sees is inv. Stuffing it into context would force a
|
||||
// helper for unwrap on every read; an explicit field is easier to
|
||||
// audit and impossible to forget.
|
||||
ParentRunID string
|
||||
|
||||
// RootRunID is the audit run id at the ROOT of the dispatch tree
|
||||
// this run belongs to — for a top-level run, its own RunID; for a
|
||||
// delegated run (skill_invoke / agent_invoke / agent_spawn /
|
||||
// palette wrappers), the outermost ancestor's. Stamped by both
|
||||
// executors from the dispatchguard ancestor chain right after
|
||||
// guard entry. Backs the shared `root_run:<id>` KV scope that lets
|
||||
// parallel sibling workers coordinate (see tools/scope_validate.go
|
||||
// + RootRunKVPartition).
|
||||
RootRunID string
|
||||
|
||||
// ToolsSubset, when non-empty, narrows an AGENT run's low-level tools
|
||||
// to the named subset of the agent's configured LowLevelTools. Set by
|
||||
// agent_invoke's `tools_subset` arg for ephemeral fan-out — spawning a
|
||||
// focused worker from a template (e.g. a `coder` template with only
|
||||
// code_exec + read_page). Names outside the agent's tool menu are
|
||||
// rejected upstream (in the invoke adapter), so by the time the
|
||||
// executor reads this the intersection is safe. Empty = full palette.
|
||||
// Skill runs ignore this field.
|
||||
ToolsSubset []string
|
||||
|
||||
// SystemPromptPrepend, when non-empty, is prepended to an AGENT's
|
||||
// system prompt for this invocation only — the fan-out "customized
|
||||
// system prompt" lever (agent_invoke's `prompt_prepend` arg). It
|
||||
// specializes a template persona to a task without mutating the
|
||||
// persisted agent row. Skill runs ignore this field.
|
||||
SystemPromptPrepend string
|
||||
|
||||
// SuppressDelivery, when true, instructs the skill executor to
|
||||
// SKIP its OutputTarget Delivery (Deliver / DeliverError) entirely.
|
||||
// The run still produces an output string (returned from Run) and
|
||||
// still writes to the audit log — only the side-channel delivery
|
||||
// (Discord channel/DM/thread post) is suppressed.
|
||||
//
|
||||
// Why: when the chatbot exposure adapter invokes a skill, the skill's
|
||||
// output is already going to be consumed by the chatbot as a tool
|
||||
// result; ALSO posting it to Discord via OutputTarget produces double
|
||||
// output and (worse) primes the chatbot to call the tool again on
|
||||
// the next turn after seeing its own output as a "human message",
|
||||
// kicking off a tool-loop. The chatbot adapter sets this to true on
|
||||
// every invocation it constructs.
|
||||
SuppressDelivery bool
|
||||
|
||||
// HandlerOwnsDelivery, when true, tells the executor that the caller
|
||||
// (typically a Discord command handler) will assemble the final
|
||||
// user-visible reply itself — folding any deferred attachments
|
||||
// (rows queued by send_attachments to skill_run_pending_attachments)
|
||||
// into the same message as the text output. The executor's
|
||||
// post-run AttachmentDrainer is skipped so the handler can drain +
|
||||
// classify + chain-overflow + post in one place.
|
||||
//
|
||||
// Why an explicit flag (vs reusing SuppressDelivery): SuppressDelivery
|
||||
// also short-circuits the OutputTarget Delivery layer (channel/dm/
|
||||
// thread post), which is the right shape for chatbot exposure but
|
||||
// the WRONG shape for `.agent run` — the handler still wants the
|
||||
// audit row to land and the executor's drainer to NOT post a
|
||||
// separate "here's an image" follow-up message after the handler's
|
||||
// own text reply. HandlerOwnsDelivery is the narrow "the caller is
|
||||
// taking over post-run delivery" signal that does NOT change any
|
||||
// other executor behaviour.
|
||||
//
|
||||
// SuppressDelivery and HandlerOwnsDelivery are independent. The
|
||||
// drainer is skipped when EITHER is set (the chatbot path doesn't
|
||||
// want stray posts either; agent-run sets HandlerOwnsDelivery
|
||||
// because it owns delivery; sub-agent dispatches set SuppressDelivery
|
||||
// because they surface output as a tool result).
|
||||
HandlerOwnsDelivery bool
|
||||
|
||||
// Priority is the v9 per-invocation priority override for the lane
|
||||
// scheduler. When non-zero, the executor uses this value when
|
||||
// constructing the lane Job; zero falls back to the skill's
|
||||
// Skill.DefaultPriority. Owners are capped by convar
|
||||
// `skills.priority_max_per_user` (default 5); admins may exceed it.
|
||||
//
|
||||
// Why a non-pointer (vs *int): zero means "use the default", which
|
||||
// matches the convention everywhere else in this struct. Skills
|
||||
// that need an explicit zero priority can store
|
||||
// DefaultPriority=0 — the result is identical.
|
||||
Priority int
|
||||
|
||||
// LaneWaitMaxSeconds is the v9 per-invocation lane backoff cap. When
|
||||
// >0, the executor calls SubmitWithMaxWait so the run is rejected
|
||||
// with ErrLaneBusy (surfaced as `lane_busy`) when the estimated
|
||||
// queue wait would exceed this many seconds. 0 (default) preserves
|
||||
// the legacy block-forever Submit semantics.
|
||||
LaneWaitMaxSeconds int
|
||||
|
||||
// LaneOverride forces the run onto the named lane regardless of
|
||||
// Skill.ExecutionLane. Used by the v9 inbound webhook handler to
|
||||
// route webhook-triggered runs to the dedicated webhook-default
|
||||
// lane. Empty preserves the per-skill ExecutionLane.
|
||||
LaneOverride string
|
||||
|
||||
// Continuation, when non-nil, signals that this Invocation is a
|
||||
// V10 reply continuation: a Discord user replied to a message the
|
||||
// originating skill posted, and mort is re-invoking the skill to
|
||||
// produce the next turn. The executor reads this field to:
|
||||
//
|
||||
// - Reuse the parent run's `run:<parent_run_id>` KV scope (so any
|
||||
// state the prior turn saved is still readable).
|
||||
// - Render a continuation block at the top of the agent's user
|
||||
// prompt that includes the parent output + reply text.
|
||||
// - Enforce the per-deployment chain-depth cap
|
||||
// (skills.reply.max_chain_depth, default 20).
|
||||
// - Stamp parent_run_id on the new run for call-tree
|
||||
// reconstruction in audit + UI.
|
||||
//
|
||||
// Why a pointer struct (vs flat fields): all five fields are
|
||||
// meaningful only together — splitting them would invite
|
||||
// half-populated states. nil = "this is a fresh invocation, not a
|
||||
// continuation".
|
||||
Continuation *ContinuationContext
|
||||
|
||||
// SourceWebhookSecretMatched is set true by the inbound webhook
|
||||
// handler AFTER it has validated both the URL secret AND the HMAC
|
||||
// signature for the named skill. It signals to System.Run that the
|
||||
// caller is authenticated by a per-skill secret (not by Discord
|
||||
// identity), so the visibility / owner gate in CanInvoke should be
|
||||
// bypassed for THIS skill (matching SkillID). All other gates —
|
||||
// pinned_version, budget caps, lane caps — still apply.
|
||||
//
|
||||
// Hotfix-5 Bug 1: pre-fix the webhook handler built an Invocation
|
||||
// with CallerID=`<webhook>:<source-IP>` and dispatched through
|
||||
// System.Run. CanInvoke saw a non-owner non-admin caller against a
|
||||
// private skill and rejected with HTTP 500 ("caller is not
|
||||
// permitted to invoke skill"). The cure isn't to weaken
|
||||
// CanInvoke's general-purpose policy — it's to recognise that a
|
||||
// matched secret IS the auth gate for the named skill.
|
||||
//
|
||||
// Why per-Invocation (vs a separate gate path): the executor uses
|
||||
// Run as the single canonical dispatch point — adding a second
|
||||
// "authenticated dispatch" entry would split run-recording, lane
|
||||
// dispatch, and audit emission into two parallel implementations.
|
||||
SourceWebhookSecretMatched bool
|
||||
|
||||
// OnEvent, when non-nil, is called by the executor at run
|
||||
// boundaries and by the agent loop on each tool dispatch. The
|
||||
// bot's command handler closes over the invoking message and
|
||||
// reacts an emoji from the skill's StateReactEmoji map. Nil-safe.
|
||||
//
|
||||
// Event names:
|
||||
// "__start__" — right before agent.Run starts
|
||||
// "__end__" — on successful completion
|
||||
// "__error__" — on terminal error
|
||||
// <tool_name> — when a tool dispatches (any registered tool)
|
||||
//
|
||||
// The executor passes the resolved emoji as `emoji` so callers
|
||||
// don't have to look it up themselves; emoji=="" means "no react
|
||||
// for this event" and callers should skip the react entirely.
|
||||
//
|
||||
// Why a callback (vs a state-react map carried in the Invocation):
|
||||
// the lookup table lives on the Skill, not the Invocation, but the
|
||||
// caller-supplied side effect (a Discord react) lives on the bot
|
||||
// command surface. A callback bridges the two without forcing the
|
||||
// executor to import discord types and without forcing the bot
|
||||
// command surface to know about the Skill's emoji map shape.
|
||||
OnEvent func(ctx context.Context, event string, emoji string)
|
||||
|
||||
// OnToolEvent, when non-nil, is called by the executor on each tool
|
||||
// dispatch with phase "start" (before the tool runs) then "end" or
|
||||
// "error" (after it completes, with the result text in detail). Distinct
|
||||
// from OnEvent (which is the emoji state-react hook): this carries the
|
||||
// tool name + args/result so an out-of-band caller — e.g. the mortise
|
||||
// chat API streaming SSE tool.start/tool.end frames — can surface live
|
||||
// tool-progress. Nil-safe; the callback MUST be fast and non-blocking
|
||||
// (it runs on the agent-loop goroutine).
|
||||
OnToolEvent func(ctx context.Context, toolName, phase, detail string)
|
||||
|
||||
// OnStep, when non-nil, is called by the executor as the agent loop
|
||||
// makes progress — currently once per tool call: phase "start" before
|
||||
// the tool runs, phase "end" after it completes (StepEvent.Step.Status
|
||||
// is "complete" or "error"). Correlate the two by StepEvent.Step.ID.
|
||||
// "delta" is reserved for progressive detail and is unused today.
|
||||
//
|
||||
// Distinct from OnToolEvent (the raw tool-name/result hook): OnStep
|
||||
// carries a richer, presentation-ready Step (kind + human present-tense
|
||||
// summary) so an out-of-band consumer — e.g. the mortise chat API
|
||||
// streaming SSE step.start/step.end frames — can render structured
|
||||
// progress without re-deriving it. The executor ALSO accumulates the
|
||||
// same Steps onto its run Result, so persistence does not depend on
|
||||
// this callback being set. Nil-safe; the callback MUST be fast and
|
||||
// non-blocking (it runs on the agent-loop goroutine).
|
||||
OnStep func(ctx context.Context, ev StepEvent)
|
||||
|
||||
// InvokingMessageID is the Discord message ID of the user's command
|
||||
// that triggered this run, when it was triggered by a Discord text
|
||||
// command. Used by delivery to thread the reply (Discord native
|
||||
// reply with the gray quote bar + jump link). Empty for chatbot
|
||||
// exposure, scheduled, or webhook invocations — delivery falls
|
||||
// back to a plain channel post for those.
|
||||
//
|
||||
// Why threaded through Invocation (vs a separate field on Skill or
|
||||
// a magic SkillInputs key): the message ID is per-invocation, not
|
||||
// per-skill, and the delivery layer is the natural reader. Direct
|
||||
// field on Invocation matches the existing ChannelID / GuildID
|
||||
// fields' shape.
|
||||
InvokingMessageID string
|
||||
|
||||
// Images carries multi-modal image content for the initial user
|
||||
// message. When non-empty, the executor builds the initial user
|
||||
// message with llm.UserParts(text + image parts) instead of plain
|
||||
// llm.UserText. Populated by callers that extract images from Discord
|
||||
// attachments or URLs in prompt text (pkg/imageutil downloads the
|
||||
// bytes — majordomo image parts are bytes-only). Nil = text-only.
|
||||
Images []llm.ImagePart
|
||||
|
||||
// InputFiles carries non-image attachments (audio, etc.) the user
|
||||
// supplied with the run. Unlike Images, these are NOT inlined into
|
||||
// the model's context — the LLM can't ingest raw mp3/wav/midi bytes.
|
||||
// Instead the executor stages each into the skill file store under
|
||||
// run scope and tells the agent the resulting file_ids (in the
|
||||
// prompt) so it can hand one to a worker tool (e.g. code_exec
|
||||
// files_in → /workspace/<name>) for processing. Nil = none.
|
||||
InputFiles []InputFile
|
||||
|
||||
// ExtraTools are additional llm.Tool instances injected for this
|
||||
// run only. They are appended to the palette after registry-built
|
||||
// tools, skill-palette wrappers, and sub-agent wrappers. Use this
|
||||
// for session-specific tools that cannot be pre-registered in the
|
||||
// catalog (e.g., scaddy's write_scad which needs per-session
|
||||
// workspace + renderer state).
|
||||
//
|
||||
// Why on Invocation (vs a dedicated Run parameter): the Invocation
|
||||
// is the per-run context carrier in mort's execution path. Adding
|
||||
// a separate ExtraTools arg to Executor.Run would fork the
|
||||
// signature for one use case; a field on the existing carrier
|
||||
// keeps the surface stable.
|
||||
ExtraTools []llm.Tool
|
||||
|
||||
// SessionToolFactory, if set, is called with the live AgentSession
|
||||
// after the executor constructs the agent but before it runs. It
|
||||
// returns a SessionTools struct carrying the tools to add, an
|
||||
// optional PostRun hook for post-processing (e.g., rendering final
|
||||
// artifacts from workspace state), and an optional Cleanup func for
|
||||
// resource teardown. Types are defined in session_tools.go.
|
||||
//
|
||||
// Why a factory (vs ExtraTools): ExtraTools are static — they
|
||||
// don't have access to the running agent. Tools that need to call
|
||||
// session.AttachImages (to show rendered previews to the model on
|
||||
// its next turn) require the live session handle that only exists
|
||||
// after construction. The factory receives that handle.
|
||||
SessionToolFactory SessionToolFactory
|
||||
|
||||
// PostRunDelivery, if set, is called by the agent command handler
|
||||
// (`.agent run`) INSTEAD of the default text + paste-fallback reply
|
||||
// when the executor's result carries a PostRunResult. The callback
|
||||
// receives the Discord message to reply to, the agent's text output,
|
||||
// and the PostRunResult. It returns the message ID of the primary
|
||||
// reply (for origin recording) and any error.
|
||||
//
|
||||
// Why a callback on Invocation (vs a handler method on the agent):
|
||||
// delivery needs services (paste, filetransfer, Discord session)
|
||||
// that live outside the agents package. A callback lets the adapter
|
||||
// (e.g., scaddy) close over the services at factory-build time
|
||||
// without adding service dependencies to the agents.System struct.
|
||||
//
|
||||
// When nil, `handleRun` falls through to the standard text-based
|
||||
// reply path (formatRunReply + postRunReply). When set, the
|
||||
// callback owns the ENTIRE reply — `handleRun` does NOT post a
|
||||
// text reply alongside it.
|
||||
PostRunDelivery func(ctx context.Context, channelID, replyToMsgID string, output string, prr *PostRunResult) (primaryMsgID string, err error)
|
||||
|
||||
// RunState, when set by the executor, lets a tool read the live
|
||||
// run's progress + budget snapshot (iteration vs cap, tool calls,
|
||||
// tokens, cost, elapsed). Nil on paths that do not provide it (e.g.
|
||||
// the no-tools direct path, or executors that predate the hook).
|
||||
// The skill_self_status tool reads this.
|
||||
RunState RunStateAccessor
|
||||
|
||||
// AttachImages, when set by the executor, queues a user-role message
|
||||
// (optional text + image parts) into the LIVE run so the model sees
|
||||
// the images on its next step — the same steer-mailbox mechanism the
|
||||
// SessionToolFactory's AgentSession exposes, but reachable from any
|
||||
// ordinary tool handler. A tool returns text; images cannot ride a
|
||||
// string result, so a tool that fetches images the model must SEE
|
||||
// (e.g. discord_list_recent_messages reading channel history) calls
|
||||
// this to feed the pixels in. Nil on paths that do not own a steer
|
||||
// mailbox (skillexec, the no-tools direct path); tools MUST nil-check
|
||||
// before calling and degrade to text-only when it is nil.
|
||||
AttachImages func(text string, images ...llm.ImagePart)
|
||||
|
||||
// gate / audit are populated by the registry's Build before
|
||||
// BuildLLM is called. Tools should call CheckGate(inv) at the top
|
||||
// of their handler and EmitAudit(inv, ...) when reporting tool
|
||||
// results. The fields are unexported in the public surface but
|
||||
// available to tools via the helpers in helpers.go.
|
||||
gate string
|
||||
currentSkill string
|
||||
audit AuditHook
|
||||
toolName string
|
||||
}
|
||||
|
||||
// RunState is a point-in-time, read-only view of how far the current run
// has progressed and how much of its budget it has used. The executor's
// per-run accessor fills one in on demand (see Invocation.RunState).
type RunState struct {
	// Iteration and MaxIterations report agent-loop progress against its cap.
	Iteration, MaxIterations int
	// ToolCalls and MaxToolCalls report tool-call usage against its cap.
	ToolCalls, MaxToolCalls int
	// Token counters for the run so far.
	InputTokens, OutputTokens, ThinkingTokens int64
	// ElapsedSeconds is the run's elapsed time, in whole seconds.
	ElapsedSeconds int
}
|
||||
|
||||
// RunStateAccessor returns the live RunState for the enclosing run. The
|
||||
// executor builds one per run and stamps it on Invocation.RunState
|
||||
// before the toolbox is built; tools read it via inv.RunState. Nil on
|
||||
// any path that does not provide it.
|
||||
type RunStateAccessor interface {
|
||||
RunState() RunState
|
||||
}
|
||||
|
||||
// Registry is the read interface to the tool catalog. Concrete impl is
|
||||
// the package-private *registry struct returned by NewRegistry.
|
||||
type Registry interface {
|
||||
Register(t Tool) error
|
||||
Get(name string) (Tool, bool)
|
||||
List() []Tool
|
||||
// Build returns an llm.Toolbox with each named tool prepared for
|
||||
// execution against the given invocation. Save-time authoring
|
||||
// checks happen elsewhere (CheckAuthoring in checks.go) — Build
|
||||
// trusts that the skill was already saved past those gates and
|
||||
// only re-checks runtime invariants:
|
||||
//
|
||||
// 1. Share-safety drift: rejects an unsafe tool when visibility
|
||||
// != private.
|
||||
// 2. SkillNameGate enforcement is delegated to the per-tool
|
||||
// handler via CheckGate, which reads invocation context.
|
||||
// 3. Audit emission via EmitAudit (also per-tool).
|
||||
//
|
||||
// The optional `trusted` variadic argument lets the caller declare
|
||||
// the skill as trusted infrastructure (a builtin loaded from disk
|
||||
// by the project's own loader) so the share-safety drift check is
|
||||
// skipped. Builtins legitimately ship with public visibility AND
|
||||
// not-safe-for-share tools (e.g. skill-wizard's wizard_* tools),
|
||||
// and the loader bypasses save-time gates by design — applying the
|
||||
// share-safety check at invocation would be inconsistent with the
|
||||
// rest of the trusted-builtin contract. Pass true ONLY for builtins
|
||||
// (Skill.Source == SourceBuiltin / OwnerID == ""). Variadic so the
|
||||
// existing call sites (and tests) compile unchanged.
|
||||
Build(names []string, inv Invocation, vis Visibility, audit AuditHook, trusted ...bool) (*llm.Toolbox, error)
|
||||
}
|
||||
|
||||
// AuditHook is invoked synchronously around each tool call. Implementations
|
||||
// typically forward to skillaudit.Writer. May be nil for tests.
|
||||
type AuditHook func(call AuditCall)
|
||||
|
||||
// AuditCall is the record of a single tool invocation handed to an
// AuditHook. Result is set on success and Err on failure; both may be set
// together (e.g. the tool returned partial output and then errored).
type AuditCall struct {
	Tool, Args, Result string
	Err                error
}
|
||||
|
||||
// Step is a single unit of agent progress. It is delivered to OnStep
// consumers and also accumulated onto the executor's run Result. Today
// one Step is produced per tool call; the shape is intentionally open so
// that future kinds (a coalesced reasoning beat, a sub-agent delegation)
// can be added without a wire change.
//
// Step is a plain DTO: it carries no HTTP/Discord coupling, only neutral
// snake_case JSON tags that a transport may reuse. The chat API converts
// it into its own persisted/wire type; Discord/cron consumers read the
// Result field directly.
type Step struct {
	// ID is unique within a run and stable for the life of the step; it
	// is the key that pairs a "start" emission with its "end".
	ID string `json:"id"`

	// Kind is drawn from an open vocabulary (search, read, code, image,
	// file, memory, delegate, tool, …). Consumers map the values they
	// know to an icon and fall back for the rest; a step must never be
	// dropped because its kind is unrecognised.
	Kind string `json:"kind"`

	// Title is a short, machine-ish label — typically the raw tool name.
	Title string `json:"title,omitempty"`

	// Summary is a human, present-tense one-liner ("Searching the web
	// for …"). On end it may be swapped for a result phrase.
	Summary string `json:"summary"`

	// Status is one of "running", "complete" or "error".
	Status string `json:"status"`

	// Detail is optional, user-safe, size-capped markdown. It never
	// holds raw tool output, credentials, or chain-of-thought.
	Detail string `json:"detail,omitempty"`

	// StartedAt records when the step began.
	StartedAt time.Time `json:"started_at"`

	// EndedAt is populated on the terminal "end" emission.
	EndedAt *time.Time `json:"ended_at,omitempty"`
}
|
||||
|
||||
// StepEvent is one live emission to OnStep. Phase is "start" or "end"
|
||||
// ("delta" is reserved for progressive detail and unused today). Step
|
||||
// carries the full current snapshot; Detail holds the delta text when
|
||||
// Phase == "delta".
|
||||
type StepEvent struct {
|
||||
Phase string
|
||||
Step Step
|
||||
Detail string
|
||||
}
|
||||
|
||||
// NewRegistry constructs an empty registry. Call Register for each tool;
|
||||
// see pkg/skilltools/default_registry.go for the v1 set.
|
||||
func NewRegistry() Registry {
|
||||
return ®istry{tools: make(map[string]Tool)}
|
||||
}
|
||||
|
||||
type registry struct {
|
||||
mu sync.RWMutex
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
func (r *registry) Register(t Tool) error {
|
||||
if t == nil {
|
||||
return fmt.Errorf("skilltools: nil tool")
|
||||
}
|
||||
name := t.Name()
|
||||
if name == "" {
|
||||
return fmt.Errorf("skilltools: tool with empty name")
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if _, dup := r.tools[name]; dup {
|
||||
return fmt.Errorf("skilltools: duplicate tool name %q", name)
|
||||
}
|
||||
r.tools[name] = t
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registry) Get(name string) (Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
t, ok := r.tools[name]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
func (r *registry) List() []Tool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make([]Tool, 0, len(r.tools))
|
||||
for _, t := range r.tools {
|
||||
out = append(out, t)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Build prepares an llm.Toolbox for one skill execution.
|
||||
//
|
||||
// Why: each tool needs to know the caller / channel / skill name plus
|
||||
// the audit hook. Stuffing them into Invocation lets each Tool.BuildLLM
|
||||
// produce a closure that has everything it needs without further
|
||||
// plumbing.
|
||||
//
|
||||
// Defence in depth: rejects an unsafe tool when visibility != private —
|
||||
// the share-time check should already have prevented this; this catches
|
||||
// drift (e.g. a tool's SafeForShare flag flipping after a skill saved).
|
||||
//
|
||||
// The trusted variadic flag lets a caller bypass the share-safety drift
|
||||
// check for builtin (trusted-infrastructure) skills. The mortventure /
|
||||
// skill-wizard builtins legitimately ship with public visibility AND
|
||||
// not-safe-for-share tools — the loader bypasses save-time gates and
|
||||
// the share-safety check at invocation would block them inconsistently.
|
||||
// Pass true ONLY for builtins.
|
||||
func (r *registry) Build(names []string, inv Invocation, vis Visibility, audit AuditHook, trusted ...bool) (*llm.Toolbox, error) {
|
||||
isTrusted := len(trusted) > 0 && trusted[0]
|
||||
box := llm.NewToolbox("skilltools")
|
||||
for _, name := range names {
|
||||
t, ok := r.Get(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("skilltools: unknown tool %q", name)
|
||||
}
|
||||
|
||||
if !isTrusted && vis != VisibilityPrivate && !t.Permission().SafeForShare {
|
||||
return nil, fmt.Errorf("skilltools: tool %q is not safe for share but skill visibility is %s", name, vis)
|
||||
}
|
||||
|
||||
// Populate the gate/audit fields on the Invocation so the tool
|
||||
// can call CheckGate / EmitAudit from its handler.
|
||||
toolInv := inv
|
||||
toolInv.gate = t.Permission().SkillNameGate
|
||||
toolInv.currentSkill = inv.SkillName
|
||||
toolInv.audit = audit
|
||||
toolInv.toolName = name
|
||||
|
||||
built := t.BuildLLM(toolInv)
|
||||
if built.Name == "" {
|
||||
return nil, fmt.Errorf("skilltools: tool %q built llm.Tool with empty name", name)
|
||||
}
|
||||
if err := box.Add(built); err != nil {
|
||||
return nil, fmt.Errorf("skilltools: adding tool %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
return box, nil
|
||||
}
|
||||
@@ -0,0 +1,184 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// fakeTool is a minimal Tool used to exercise the registry's gating.
|
||||
type fakeTool struct {
|
||||
name string
|
||||
desc string
|
||||
perm Permission
|
||||
calledWith *Invocation
|
||||
returnText string
|
||||
returnError error
|
||||
}
|
||||
|
||||
// Name returns the configured tool name.
func (f *fakeTool) Name() string {
	return f.name
}
|
||||
// Description returns the configured tool description.
func (f *fakeTool) Description() string {
	return f.desc
}
|
||||
// Permission returns the configured permission record.
func (f *fakeTool) Permission() Permission {
	return f.perm
}
|
||||
func (f *fakeTool) BuildLLM(inv Invocation) llm.Tool {
|
||||
type emptyParams struct{}
|
||||
return llm.DefineTool(
|
||||
f.name,
|
||||
f.desc,
|
||||
func(ctx context.Context, _ emptyParams) (any, error) {
|
||||
if err := CheckGate(inv); err != nil {
|
||||
EmitAudit(inv, "{}", "", err)
|
||||
return "", err
|
||||
}
|
||||
f.calledWith = &inv
|
||||
EmitAudit(inv, "{}", f.returnText, f.returnError)
|
||||
return f.returnText, f.returnError
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestRegister_DuplicateRejected(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
a := &fakeTool{name: "x", perm: Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true}}
|
||||
b := &fakeTool{name: "x", perm: Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true}}
|
||||
if err := r.Register(a); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := r.Register(b)
|
||||
if err == nil || !strings.Contains(err.Error(), "duplicate") {
|
||||
t.Fatalf("expected duplicate-name error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_RejectsEmpty(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
if err := r.Register(&fakeTool{name: ""}); err == nil {
|
||||
t.Fatal("expected empty-name rejection")
|
||||
}
|
||||
if err := r.Register(nil); err == nil {
|
||||
t.Fatal("expected nil-tool rejection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuild_UnknownTool(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_, err := r.Build([]string{"nope"}, Invocation{}, VisibilityPrivate, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "unknown tool") {
|
||||
t.Fatalf("expected unknown-tool error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuild_SharedRejectsUnsafeTool(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{name: "balance", perm: Permission{SafeForShare: false}})
|
||||
_, err := r.Build([]string{"balance"}, Invocation{}, VisibilityShared, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "not safe for share") {
|
||||
t.Fatalf("expected share-safety error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuild_TrustedBuiltinBypassesShareSafety verifies the
|
||||
// trusted-flag escape hatch: a builtin (skill-wizard, mortventure)
|
||||
// legitimately ships with public visibility AND not-safe-for-share
|
||||
// tools. Build with trusted=true must not reject those.
|
||||
//
|
||||
// Why: pre-fix, invocation of skill-wizard (visibility=public, tools
|
||||
// include wizard_* with SafeForShare=false) was rejected at runtime
|
||||
// even though the loader had already bypassed save-time gates. The
|
||||
// trusted flag aligns the invocation-time gate with the loader's
|
||||
// trust model.
|
||||
func TestBuild_TrustedBuiltinBypassesShareSafety(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{
|
||||
name: "wizard_list",
|
||||
perm: Permission{SafeForShare: false},
|
||||
returnText: "ok",
|
||||
})
|
||||
box, err := r.Build([]string{"wizard_list"}, Invocation{SkillName: "skill-wizard"}, VisibilityPublic, nil, true)
|
||||
if err != nil {
|
||||
t.Fatalf("trusted=true should bypass share-safety, got %v", err)
|
||||
}
|
||||
if box == nil {
|
||||
t.Fatal("trusted=true should produce a toolbox, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuild_NonTrustedSharedStillRejects confirms the bypass is
|
||||
// strictly opt-in: a non-builtin caller with the same shape (public
|
||||
// visibility + unsafe tool) still hits the rejection path.
|
||||
func TestBuild_NonTrustedSharedStillRejects(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{name: "balance", perm: Permission{SafeForShare: false}})
|
||||
_, err := r.Build([]string{"balance"}, Invocation{}, VisibilityPublic, nil, false)
|
||||
if err == nil || !strings.Contains(err.Error(), "not safe for share") {
|
||||
t.Fatalf("trusted=false (non-builtin) must still reject unsafe tool at public visibility, got %v", err)
|
||||
}
|
||||
// Omitted variadic = trusted defaults to false → same rejection.
|
||||
_, err = r.Build([]string{"balance"}, Invocation{}, VisibilityPublic, nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "not safe for share") {
|
||||
t.Fatalf("omitted variadic must default to trusted=false, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuild_PublicAcceptsSafeTool(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_ = r.Register(&fakeTool{name: "search", perm: Permission{SafeForShare: true}, returnText: "hits"})
|
||||
box, err := r.Build([]string{"search"}, Invocation{SkillName: "echo"}, VisibilityPublic, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out, err := execBox(box, toolCall{Name: "search", Arguments: "{}"})
|
||||
if err != nil || out != "hits" {
|
||||
t.Fatalf("unexpected: %q %v", out, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuild_GateBlocksMismatchedSkill(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
tt := &fakeTool{
|
||||
name: "wizard_save",
|
||||
perm: Permission{SafeForShare: true, SkillNameGate: "skill-wizard"},
|
||||
returnText: "saved",
|
||||
}
|
||||
_ = r.Register(tt)
|
||||
box, err := r.Build([]string{"wizard_save"}, Invocation{SkillName: "echo"}, VisibilityPrivate, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build: %v", err)
|
||||
}
|
||||
out, err := execBox(box, toolCall{Name: "wizard_save", Arguments: "{}"})
|
||||
if err == nil || !strings.Contains(err.Error(), "restricted to") {
|
||||
t.Fatalf("expected gate rejection, got out=%q err=%v", out, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuild_GateAllowsMatchingSkill(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
tt := &fakeTool{
|
||||
name: "wizard_save",
|
||||
perm: Permission{SafeForShare: true, SkillNameGate: "skill-wizard"},
|
||||
returnText: "saved",
|
||||
}
|
||||
_ = r.Register(tt)
|
||||
box, _ := r.Build([]string{"wizard_save"}, Invocation{SkillName: "skill-wizard"}, VisibilityPrivate, nil)
|
||||
out, err := execBox(box, toolCall{Name: "wizard_save", Arguments: "{}"})
|
||||
if err != nil || out != "saved" {
|
||||
t.Fatalf("unexpected: %q %v", out, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuild_EmitsAudit(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
tt := &fakeTool{name: "search", perm: Permission{SafeForShare: true}, returnText: "hits"}
|
||||
_ = r.Register(tt)
|
||||
|
||||
var calls []AuditCall
|
||||
hook := func(c AuditCall) { calls = append(calls, c) }
|
||||
|
||||
box, _ := r.Build([]string{"search"}, Invocation{SkillName: "echo"}, VisibilityPrivate, hook)
|
||||
_, _ = execBox(box, toolCall{Name: "search", Arguments: "{}"})
|
||||
|
||||
if len(calls) != 1 || calls[0].Tool != "search" || calls[0].Result != "hits" || calls[0].Err != nil {
|
||||
t.Fatalf("unexpected audit: %+v", calls)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package tool
|
||||
|
||||
// RootRunKVPartition names the sentinel skill_id partition that holds
// every `root_run:<id>` KV row.
//
// Why a sentinel: skill KV rows are keyed by (skill_id, scope, key). Two
// sibling workers with different IDs (e.g. agent_spawn ephemeral workers
// under one fan-out) could therefore never share state through a scope
// string alone — each would read and write its own partition. Sending
// every root_run scope to one shared partition makes the scope string the
// real boundary: it embeds the root run id, which the validator checks
// against Invocation.RootRunID, so per-tree isolation holds even though
// the partition itself is global.
//
// It is declared in this root package rather than in the tools/
// sub-package because both the tool handlers and the storage sweeper need
// it without importing each other.
//
// NOTE(review): the original comment names those as pkg/skilltools/tools
// and pkg/logic/skills; those paths predate this package's rename to tool
// — confirm the current locations.
const RootRunKVPartition = "__root_run__"
|
||||
@@ -0,0 +1,18 @@
|
||||
package tool
|
||||
|
||||
import "testing"
|
||||
|
||||
// fakeAccessor is a test RunStateAccessor that returns a fixed snapshot.
type fakeAccessor struct {
	s RunState
}
|
||||
|
||||
// RunState returns the stored snapshot unchanged.
func (f fakeAccessor) RunState() RunState {
	return f.s
}
|
||||
|
||||
func TestInvocationRunState_NilSafe(t *testing.T) {
|
||||
var inv Invocation
|
||||
if inv.RunState != nil {
|
||||
t.Fatal("RunState should default nil")
|
||||
}
|
||||
inv.RunState = fakeAccessor{s: RunState{Iteration: 3, MaxIterations: 10}}
|
||||
if got := inv.RunState.RunState(); got.Iteration != 3 || got.MaxIterations != 10 {
|
||||
t.Fatalf("unexpected RunState: %+v", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm"
|
||||
)
|
||||
|
||||
// AgentSession is the live-run handle a SessionToolFactory receives.
|
||||
// It is implemented by the executors (agentexec / skillexec / scaddy's
|
||||
// adapter) on top of majordomo's agent loop and exposes the one mid-run
|
||||
// mutation session tools need: feeding content back into the running
|
||||
// conversation.
|
||||
//
|
||||
// Why an interface (vs the concrete agent type): legacy agentkit handed the
|
||||
// factory a *agentkit.Agent so tools could call agent.AttachImages.
|
||||
// majordomo's *agent.Agent is deliberately immutable mid-run — message
|
||||
// injection happens through the run-scoped steer mailbox
|
||||
// (agent.WithSteer). A narrow interface lets each executor implement
|
||||
// AttachImages over its own steer queue without skilltools importing
|
||||
// the agent package, and keeps session tools testable with a two-line
|
||||
// fake.
|
||||
type AgentSession interface {
|
||||
// AttachImages queues a user-role message (text plus image parts)
|
||||
// for injection into the conversation before the agent's next
|
||||
// step. Used by tools that produce visual feedback the model must
|
||||
// see on its following turn (e.g. scaddy's rendered OpenSCAD
|
||||
// previews). Safe to call from inside a tool handler; the message
|
||||
// lands after the current step's tool results.
|
||||
AttachImages(text string, images ...llm.ImagePart)
|
||||
}
|
||||
|
||||
// SessionToolFactory builds per-session tools that close over the live
|
||||
// agent session. Called by the executor after the agent is constructed
|
||||
// but before it runs. See Invocation.SessionToolFactory for the
|
||||
// rationale (static ExtraTools cannot reach the running agent).
|
||||
type SessionToolFactory func(session AgentSession) SessionTools
|
||||
|
||||
// SessionTools carries per-session tools plus optional post-run and
|
||||
// teardown hooks. It replaces legacy agentkit's SessionTools with the same
|
||||
// three-field shape, re-based on majordomo types.
|
||||
type SessionTools struct {
|
||||
// Tools to add to the agent's toolbox for this run only.
|
||||
Tools []llm.Tool
|
||||
|
||||
// PostRun, if set, is called after the agent run completes
|
||||
// (successfully or not). It receives the full run transcript (the
|
||||
// agent Result's Messages — also populated on partial results from
|
||||
// agent.ErrMaxSteps / agent.ErrToolLoop), the agent's text output,
|
||||
// and the run error, so the hook can decide whether to attempt
|
||||
// artifact production on partial success (e.g. scaddy ships the
|
||||
// latest SCAD even when the step budget ran out). The returned
|
||||
// PostRunResult is attached to the executor's run result. Errors
|
||||
// inside PostRun must be handled by the hook itself — the executor
|
||||
// logs a nil return but never fails the run over it; the agent's
|
||||
// output is the source of truth.
|
||||
//
|
||||
// Why a transcript slice (vs the live agent): the consumers only
|
||||
// ever read the message history (thought-chain transcripts); the
|
||||
// majordomo agent exposes that on Result, not on the Agent.
|
||||
PostRun func(ctx context.Context, transcript []llm.Message, output string, runErr error) *PostRunResult
|
||||
|
||||
// Cleanup, if set, is deferred by the executor immediately after
|
||||
// the factory returns. Called even if the run fails or PostRun
|
||||
// panics. Use for temp directory removal, closing file handles,
|
||||
// etc.
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
// PostRunResult carries artifacts produced by the PostRun hook.
|
||||
// Attached to the executor's run result so callers (Discord command
|
||||
// handlers, HTTP API responses) can inspect and deliver the artifacts.
|
||||
//
|
||||
// Why a separate struct (vs returning artifacts inline): post-
|
||||
// processing may produce multiple typed artifacts (PNGs, STLs, SCAD
|
||||
// source) that the delivery layer classifies and routes differently.
|
||||
// A flat []Artifact + arbitrary Metadata covers the known use cases
|
||||
// without over-specifying the shape.
|
||||
type PostRunResult struct {
|
||||
// Artifacts are files produced during post-processing
|
||||
// (e.g., rendered PNGs, STL files, SCAD source).
|
||||
Artifacts []Artifact
|
||||
|
||||
// Metadata is arbitrary key-value data the delivery layer can
|
||||
// use for formatting (e.g., iteration count, model name, notes).
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
// Artifact describes one named binary blob emitted by post-run
// processing.
//
// Why these three fields: the delivery layer classifies each artifact by
// name, type and bytes (PNG → embed image, STL → filetransfer upload,
// SCAD → paste upload), so this is the minimal viable description.
type Artifact struct {
	// Name is the file name, e.g., "model.stl", "preview_iso.png".
	Name string
	// MimeType is the media type, e.g., "model/stl", "image/png".
	MimeType string
	// Data is the raw content of the artifact.
	Data []byte
}
|
||||
@@ -0,0 +1,221 @@
|
||||
// Package tool — SSRF protection layer for skill HTTP tools.
|
||||
//
|
||||
// Why a dedicated layer (vs reusing pkg/utils.ValidateExternalURL):
|
||||
// the platform's HTTP tools enforce a per-deployment ALLOWLIST (not
|
||||
// just a "no private IPs" denylist) — admins must explicitly opt-in
|
||||
// to each domain a skill may call. Additionally, defeating DNS
|
||||
// rebinding requires capturing the resolved IP at validation time
|
||||
// and pinning the dialler so a hostile DNS resolver can't return a
|
||||
// public IP during the check and a private one at dial time.
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AllowlistConfig describes which hosts a skill HTTP tool is permitted
// to contact.
//
// Why a struct instead of a bare []string: forward-compatibility.
// Per-tool overrides (e.g. "this skill may also reach
// internal.example.com") and an explicit `AllowLoopback` opt-in for
// development environments are expected later, and a struct lets those
// fields be added without touching existing call sites.
type AllowlistConfig struct {
	// Domains lists the permitted hostnames. A "*.example.com" entry is
	// a wildcard that matches "foo.example.com" and
	// "bar.baz.example.com" but NOT "example.com" itself; list both
	// entries to allow both.
	//
	// Entries are compared case-insensitively. Trailing dots are NOT
	// trimmed, so "example.com" and "example.com." are distinct entries.
	Domains []string
}
|
||||
|
||||
// ResolveAndCheck validates urlStr against the allowlist and returns
|
||||
// the resolved IP. The IP is meant to be passed to the transport's
|
||||
// dial step (via PinnedDialTransport) to defeat DNS rebinding.
|
||||
//
|
||||
// Loopback / private / link-local rejection is bypassed when the
|
||||
// HOSTNAME (not the resolved IP) is itself an entry in the allowlist
|
||||
// OR the resolved IP literal appears in the allowlist. This lets an
|
||||
// admin opt-in to "127.0.0.1" or "localhost" for tests / debug
|
||||
// without a global allow-private flag, while keeping the default
|
||||
// (random hostname → resolved private IP) safe.
|
||||
//
|
||||
// Returns:
|
||||
// - resolvedIP if the URL is acceptable
|
||||
// - error explaining the rejection (host not allowlisted, scheme
|
||||
// unsupported, resolves to private IP, etc.)
|
||||
func ResolveAndCheck(ctx context.Context, urlStr string, allow AllowlistConfig) (net.IP, error) {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse url: %w", err)
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return nil, fmt.Errorf("scheme %q not supported (need http or https)", u.Scheme)
|
||||
}
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("url has no host")
|
||||
}
|
||||
|
||||
if !matchesAllowlist(host, allow.Domains) {
|
||||
return nil, fmt.Errorf("host %q not in allowlist", host)
|
||||
}
|
||||
|
||||
// If the host is already a literal IP, skip the resolve step.
|
||||
if literal := net.ParseIP(host); literal != nil {
|
||||
// Even an explicitly allowlisted IP literal goes through the
|
||||
// privacy check UNLESS the literal is itself in the allowlist
|
||||
// (covers admin opt-in "127.0.0.1" for tests).
|
||||
if hostExplicitlyAllowed(host, allow.Domains) {
|
||||
return literal, nil
|
||||
}
|
||||
if err := rejectPrivateIP(host, literal); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return literal, nil
|
||||
}
|
||||
|
||||
// Resolve. context controls timeout.
|
||||
resolver := &net.Resolver{}
|
||||
addrs, err := resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve %q: %w", host, err)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("resolve %q: no addresses", host)
|
||||
}
|
||||
|
||||
ip := addrs[0].IP
|
||||
|
||||
// Hostname explicitly in allowlist (e.g. "localhost" → opt-in by
|
||||
// admin) bypasses the private-IP check. The wildcard form does NOT
|
||||
// bypass — wildcards are for public domain families, not for
|
||||
// private space.
|
||||
if hostExplicitlyAllowed(host, allow.Domains) {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
if err := rejectPrivateIP(host, ip); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// rejectPrivateIP returns an error when ip is an address outbound skill
// HTTP traffic must not reach: the cloud metadata endpoint, loopback,
// private, link-local unicast, unspecified, multicast, or the IPv4
// limited-broadcast address. A nil or malformed ip is also rejected so
// the gate fails closed. It returns nil for any other address.
//
// host is only used to make the error message informative; it is the
// name (or literal) the caller started from.
//
// Why a helper: ResolveAndCheck applies the same checks on both the
// literal-IP path and the resolved-host path.
func rejectPrivateIP(host string, ip net.IP) error {
	switch {
	case ip.To16() == nil:
		// Neither a 4-byte nor a 16-byte address: nothing meaningful can
		// be checked, so refuse rather than let it through.
		return fmt.Errorf("host %q has no valid IP address", host)
	case ip.Equal(net.IPv4(169, 254, 169, 254)):
		// Checked FIRST: the metadata endpoint is itself link-local, so
		// the more specific message would otherwise be shadowed by the
		// link-local rejection below.
		return fmt.Errorf("host %q resolves to cloud metadata IP %v", host, ip)
	case ip.IsLoopback():
		return fmt.Errorf("host %q resolves to loopback %v", host, ip)
	case ip.IsPrivate():
		return fmt.Errorf("host %q resolves to private IP %v", host, ip)
	case ip.IsLinkLocalUnicast():
		return fmt.Errorf("host %q resolves to link-local %v", host, ip)
	case ip.IsUnspecified():
		return fmt.Errorf("host %q resolves to unspecified %v", host, ip)
	case ip.IsMulticast():
		// Multicast groups are never a legitimate HTTP destination.
		return fmt.Errorf("host %q resolves to multicast %v", host, ip)
	case ip.Equal(net.IPv4bcast):
		return fmt.Errorf("host %q resolves to broadcast %v", host, ip)
	}
	return nil
}
|
||||
|
||||
// hostExplicitlyAllowed reports whether host appears in allow as an
// exact entry; wildcard entries never count. Both sides are lowercased
// and each entry is whitespace-trimmed before comparison.
//
// ResolveAndCheck uses this to skip the private-IP check when an admin
// has named a host outright (e.g. "127.0.0.1" or "localhost") as an
// opt-in.
func hostExplicitlyAllowed(host string, allow []string) bool {
	want := strings.ToLower(host)
	for i := range allow {
		entry := strings.TrimSpace(allow[i])
		if strings.ToLower(entry) == want {
			return true
		}
	}
	return false
}
|
||||
|
||||
// matchesAllowlist reports whether host is covered by any entry of
// allow. An entry covers host when it is:
//   - an exact (case-insensitive, whitespace-trimmed) match;
//   - a "*.example.com" wildcard, which covers one or more subdomain
//     levels ("foo.example.com", "a.b.example.com") but NOT
//     "example.com" itself; or
//   - the bare "*" wildcard, which covers every host.
//
// Empty entries are ignored.
//
// Bare "*" only opens the hostname gate. **Operators should use it only
// when they understand the SSRF + iptables layers still defend against
// private-IP traffic**: ResolveAndCheck keeps rejecting loopback /
// RFC1918 / link-local targets unless the host is also an exact
// allowlist entry, and the v15 codeexec firewall sidecar adds host-level
// iptables drops. The bare-"*" form is the v15.1 operator UX answer to
// "I just want to let the agent reach the public internet"; without it
// operators had to enumerate TLDs (*.com, *.org, *.io, etc.), which
// never covered the long tail.
func matchesAllowlist(host string, allow []string) bool {
	name := strings.ToLower(host)
	for _, raw := range allow {
		entry := strings.ToLower(strings.TrimSpace(raw))
		switch {
		case entry == "":
			continue
		case entry == "*", entry == name:
			return true
		case strings.HasPrefix(entry, "*."):
			// Keep the leading dot (".example.com") so the bare domain
			// and look-alikes such as "badexample.com" cannot match.
			tail := entry[1:]
			if len(name) > len(tail) && strings.HasSuffix(name, tail) {
				return true
			}
		}
	}
	return false
}
|
||||
|
||||
// PinnedDialTransport builds an http.RoundTripper whose every dial goes
// to ip, whatever host the request names. Only the host part of the dial
// address is replaced; the port is kept, and the request itself is not
// modified, so the Host header, TLS SNI and HTTP Host routing continue
// to work.
//
// timeout is applied three times: as the dial timeout, the TLS handshake
// timeout, and the response-header timeout.
//
// Why pin the dial instead of trusting the request: between the
// ResolveAndCheck call and the http.Client.Do call a hostile DNS server
// can return a different IP (DNS rebinding). Pinning guarantees the
// connection lands on the exact address that passed the privacy check.
func PinnedDialTransport(ip net.IP, timeout time.Duration) http.RoundTripper {
	dialer := &net.Dialer{Timeout: timeout}

	pinnedDial := func(ctx context.Context, network, addr string) (net.Conn, error) {
		// addr arrives as "host:port"; keep the port, swap in the pinned IP.
		_, port, splitErr := net.SplitHostPort(addr)
		if splitErr != nil {
			return nil, splitErr
		}
		target := net.JoinHostPort(ip.String(), port)
		return dialer.DialContext(ctx, network, target)
	}

	transport := &http.Transport{
		DialContext:           pinnedDial,
		ResponseHeaderTimeout: timeout,
		TLSHandshakeTimeout:   timeout,
	}
	return transport
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestResolveAndCheck_AllowlistedPublic anchors the happy path: a
|
||||
// public domain in the allowlist resolves and returns its IP. Skips
|
||||
// when DNS isn't available so the suite still passes in offline CI.
|
||||
func TestResolveAndCheck_AllowlistedPublic(t *testing.T) {
|
||||
// Pre-flight: skip the test if the test environment has no DNS.
|
||||
if _, err := net.LookupHost("example.com"); err != nil {
|
||||
t.Skipf("no DNS in test environment: %v", err)
|
||||
}
|
||||
allow := AllowlistConfig{Domains: []string{"example.com"}}
|
||||
ip, err := ResolveAndCheck(context.Background(), "https://example.com/", allow)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveAndCheck failed: %v", err)
|
||||
}
|
||||
if ip == nil {
|
||||
t.Fatal("expected non-nil IP")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_NotAllowlisted ensures a domain outside the
|
||||
// allowlist is rejected before any DNS resolution.
|
||||
func TestResolveAndCheck_NotAllowlisted(t *testing.T) {
|
||||
allow := AllowlistConfig{Domains: []string{"example.com"}}
|
||||
_, err := ResolveAndCheck(context.Background(), "https://evil.test/", allow)
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection for non-allowlisted host")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not in allowlist") {
|
||||
t.Errorf("expected allowlist error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_WildcardMatch confirms "*.example.com" matches
|
||||
// foo.example.com.
|
||||
func TestResolveAndCheck_WildcardMatch(t *testing.T) {
|
||||
if !matchesAllowlist("foo.example.com", []string{"*.example.com"}) {
|
||||
t.Error("expected *.example.com to match foo.example.com")
|
||||
}
|
||||
if !matchesAllowlist("a.b.example.com", []string{"*.example.com"}) {
|
||||
t.Error("expected *.example.com to match a.b.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_WildcardDoesNotMatchBareDomain documents the
|
||||
// design choice: "*.example.com" does NOT match "example.com" itself.
|
||||
// Admins who want both must list both entries.
|
||||
func TestResolveAndCheck_WildcardDoesNotMatchBareDomain(t *testing.T) {
|
||||
if matchesAllowlist("example.com", []string{"*.example.com"}) {
|
||||
t.Error("expected *.example.com NOT to match bare example.com")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_LocalhostRejected verifies that a hostname
|
||||
// resolving to 127.0.0.1 is rejected unless the admin explicitly
|
||||
// includes it in the allowlist.
|
||||
func TestResolveAndCheck_LocalhostRejected(t *testing.T) {
|
||||
if _, err := net.LookupHost("localhost"); err != nil {
|
||||
t.Skipf("no DNS in test environment: %v", err)
|
||||
}
|
||||
// "localhost" matches the allowlist by exact name match, but the
|
||||
// hostExplicitlyAllowed bypass kicks in only when the host is in
|
||||
// the allowlist as an exact entry. Here we use a DIFFERENT bare
|
||||
// allowlist entry so the host fails the allowlist match outright.
|
||||
allow := AllowlistConfig{Domains: []string{"example.com"}}
|
||||
_, err := ResolveAndCheck(context.Background(), "http://localhost/", allow)
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection for localhost (not in allowlist)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_LocalhostAllowedExplicit confirms the
|
||||
// admin-opt-in escape hatch: when the hostname is itself in the
|
||||
// allowlist as an exact entry, the private-IP check is bypassed.
|
||||
// This is what test code uses to drive httptest.NewServer URLs.
|
||||
func TestResolveAndCheck_LocalhostAllowedExplicit(t *testing.T) {
|
||||
// Use 127.0.0.1 directly so this test doesn't depend on DNS for
|
||||
// "localhost".
|
||||
allow := AllowlistConfig{Domains: []string{"127.0.0.1"}}
|
||||
ip, err := ResolveAndCheck(context.Background(), "http://127.0.0.1/", allow)
|
||||
if err != nil {
|
||||
t.Fatalf("expected 127.0.0.1 with explicit allowlist to succeed; got: %v", err)
|
||||
}
|
||||
if !ip.Equal(net.ParseIP("127.0.0.1")) {
|
||||
t.Errorf("expected ip=127.0.0.1, got %v", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_FileSchemeRejected blocks file:// URLs.
|
||||
func TestResolveAndCheck_FileSchemeRejected(t *testing.T) {
|
||||
allow := AllowlistConfig{Domains: []string{"anything"}}
|
||||
_, err := ResolveAndCheck(context.Background(), "file:///etc/passwd", allow)
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection for file:// scheme")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "scheme") {
|
||||
t.Errorf("expected scheme error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_EmptyHostRejected blocks malformed URLs with
|
||||
// no host component.
|
||||
func TestResolveAndCheck_EmptyHostRejected(t *testing.T) {
|
||||
allow := AllowlistConfig{Domains: []string{"anything"}}
|
||||
_, err := ResolveAndCheck(context.Background(), "http:///nohost", allow)
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection for empty host")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_PrivateIPLiteralRejected confirms that an IP
|
||||
// literal resolving to the private range is rejected even if the
|
||||
// allowlist matches by wildcard or other means. The private-IP gate
|
||||
// is the last line of defence.
|
||||
func TestResolveAndCheck_PrivateIPLiteralRejected(t *testing.T) {
|
||||
// Add a wildcard that would match anything (silly but plausible
|
||||
// admin error) and confirm a private IP literal is still blocked
|
||||
// because the literal isn't itself in the allowlist as exact.
|
||||
allow := AllowlistConfig{Domains: []string{"192.168.1.1"}}
|
||||
// The exact-IP-in-allowlist case bypasses the private check; flip
|
||||
// to a NEAR-but-different IP literal that's NOT in the allowlist.
|
||||
allow2 := AllowlistConfig{Domains: []string{"192.168.0.0"}}
|
||||
_, err := ResolveAndCheck(context.Background(), "http://192.168.1.1/", allow2)
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection for private IP literal not in allowlist")
|
||||
}
|
||||
// Sanity: explicit allowlist entry bypasses.
|
||||
_, err = ResolveAndCheck(context.Background(), "http://192.168.1.1/", allow)
|
||||
if err != nil {
|
||||
t.Errorf("expected explicit allowlist entry to bypass; got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAndCheck_CloudMetadataRejected blocks the well-known
|
||||
// cloud metadata IP via the link-local check. We use a wildcard that
|
||||
// matches the IP-as-hostname so the rejection comes from the
|
||||
// private/link-local layer (not the allowlist).
|
||||
func TestResolveAndCheck_CloudMetadataRejected(t *testing.T) {
|
||||
// "*.169.254.169.254" wildcard wouldn't match either; instead use
|
||||
// a wildcard that matches any IP literal under .254 — but
|
||||
// matchesAllowlist treats '.' as a literal so we just allowlist
|
||||
// the IP itself with a one-bit-different sibling that fails the
|
||||
// exact-allow check (so private check still runs).
|
||||
//
|
||||
// Easier: include a different exact IP entry so the IP literal
|
||||
// fails hostExplicitlyAllowed but passes the wildcard.
|
||||
allow := AllowlistConfig{Domains: []string{"*.169.254.169.254"}} // matches "x.169.254.169.254", not the bare IP
|
||||
// 169.254.169.254 won't match the wildcard pattern either —
|
||||
// switch to a strategy that lets the host pass allowlist but
|
||||
// fails the private check.
|
||||
_ = allow
|
||||
// Use an explicit non-IP-literal hostname (we'd need DNS to point
|
||||
// to 169.254.169.254 which is not feasible). Instead, exercise
|
||||
// the rejectPrivateIP helper directly for the metadata IP since
|
||||
// the public surface only enters that path through resolution.
|
||||
if err := rejectPrivateIP("metadata.test", net.ParseIP("169.254.169.254")); err == nil {
|
||||
t.Fatal("expected rejection for cloud metadata IP")
|
||||
} else if !strings.Contains(err.Error(), "metadata") {
|
||||
t.Errorf("expected metadata error, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
// Package tool — webhook_rate_limit.go: per-IP-per-skill
|
||||
// sliding-window rate limiter for the v7 inbound webhook handler.
|
||||
//
|
||||
// Why an in-memory limiter (vs Redis or DB-backed): rate limiting is
|
||||
// the cheap reject path BEFORE the HMAC compute and run-budget check,
|
||||
// and an extra round-trip per inbound webhook would be wasted. The
|
||||
// 6-person server's volume is well within a single-process limiter's
|
||||
// scale; if mort is ever multi-process the limiter becomes
|
||||
// approximate (still good enough to throttle abusive sources).
|
||||
//
|
||||
// Why per-IP-per-skill (vs per-IP global): one busy webhook (e.g.
|
||||
// GitHub PR opened) shouldn't shadow another (Stripe charge). The
|
||||
// composite key keeps a noisy source from pushing other skill's
|
||||
// callers off the lane.
|
||||
//
|
||||
// Test: webhook_rate_limit_test.go covers admit + reject paths.
|
||||
package tool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WebhookRateLimiter counts calls per (skillID, sourceIP) pair over a
// sliding time window. Its settings are fixed at construction, and every
// method that touches the bucket map takes the mutex, so it can be
// shared between goroutines.
type WebhookRateLimiter struct {
	limit  int              // max admitted calls per key per window; <=0 disables limiting
	window time.Duration    // length of the sliding window
	clock  func() time.Time // time source; replaceable in tests

	mu      sync.Mutex
	buckets map[string]*rateBucket // key = skillID + "|" + sourceIP
}

// rateBucket records the admitted calls of one (skillID, sourceIP) key.
type rateBucket struct {
	// hits holds the timestamps still inside the window, oldest first.
	// It is pruned on every Admit and Sweep, so it never grows unbounded.
	hits []time.Time
}

// dropBefore discards the leading hits that are strictly older than
// cutoff. Survivors are shifted to the front of the same backing array,
// so no new slice is allocated.
func (b *rateBucket) dropBefore(cutoff time.Time) {
	stale := 0
	for stale < len(b.hits) && b.hits[stale].Before(cutoff) {
		stale++
	}
	if stale == 0 {
		return
	}
	kept := copy(b.hits, b.hits[stale:])
	b.hits = b.hits[:kept]
}

// NewWebhookRateLimiter constructs the limiter.
//
//   - limit: max calls per (skill, ip) within window. <=0 means
//     "unlimited" (every call admitted; useful for tests).
//   - window: sliding window length. <=0 falls back to 1 minute.
//   - clock: testable wall-clock; nil falls back to time.Now.
func NewWebhookRateLimiter(limit int, window time.Duration, clock func() time.Time) *WebhookRateLimiter {
	l := &WebhookRateLimiter{
		limit:   limit,
		window:  window,
		clock:   clock,
		buckets: make(map[string]*rateBucket),
	}
	if l.window <= 0 {
		l.window = time.Minute
	}
	if l.clock == nil {
		l.clock = time.Now
	}
	return l
}

// Admit decides whether one more call from sourceIP to skillID fits in
// the current window. It returns (true, 0) and records the hit when the
// call is within the cap, or (false, retry-after) when the cap is
// reached. retry-after is the time until the OLDEST hit in the window
// expires (never negative); the caller can surface it via the
// Retry-After response header.
//
// Why return retry-after and not just a bool: HTTP 429 responses
// commonly include Retry-After to avoid synchronizing client retries,
// and computing it from the sliding window is essentially free.
func (l *WebhookRateLimiter) Admit(skillID, sourceIP string) (bool, time.Duration) {
	if l.limit <= 0 {
		return true, 0
	}
	now := l.clock()
	key := skillID + "|" + sourceIP

	l.mu.Lock()
	defer l.mu.Unlock()

	bucket, known := l.buckets[key]
	if !known {
		bucket = &rateBucket{}
		l.buckets[key] = bucket
	}
	bucket.dropBefore(now.Add(-l.window))

	if len(bucket.hits) < l.limit {
		bucket.hits = append(bucket.hits, now)
		return true, 0
	}

	// Cap reached: the caller may retry once the oldest hit leaves the window.
	wait := bucket.hits[0].Add(l.window).Sub(now)
	if wait < 0 {
		wait = 0
	}
	return false, wait
}

// Sweep prunes every bucket and deletes those left with no hits. It is
// meant to be called periodically (e.g. once per minute) to bound the
// growth of the buckets map.
//
// Why a separate Sweep instead of relying on the pruning in Admit: a
// hostile source rotating through many IP addresses, each hitting once,
// would leave millions of single-hit buckets that Admit never revisits.
// A periodic sweep keeps that worst case bounded.
func (l *WebhookRateLimiter) Sweep() {
	cutoff := l.clock().Add(-l.window)

	l.mu.Lock()
	defer l.mu.Unlock()

	for key, bucket := range l.buckets {
		bucket.dropBefore(cutoff)
		if len(bucket.hits) == 0 {
			delete(l.buckets, key)
		}
	}
}

// CountKeys returns the bucket count. Test helper.
func (l *WebhookRateLimiter) CountKeys() int {
	l.mu.Lock()
	count := len(l.buckets)
	l.mu.Unlock()
	return count
}
|
||||
Reference in New Issue
Block a user