diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 2dfd963..477ed6a 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -13,8 +13,24 @@ on: push: branches: [main] tags: ["v*"] + # Docs/example/meta-only changes don't affect the build — skip the run. + # (Path filters are not applied to tag pushes, so v* releases always run.) + paths-ignore: + - "**.md" + - "LICENSE" + - ".gitignore" + - ".dockerignore" + - "examples/**" + - "docs/**" pull_request: types: [opened, synchronize, reopened] + paths-ignore: + - "**.md" + - "LICENSE" + - ".gitignore" + - ".dockerignore" + - "examples/**" + - "docs/**" workflow_dispatch: {} concurrency: diff --git a/CLAUDE.md b/CLAUDE.md index be5d25b..dcc609e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -47,8 +47,9 @@ CORE (majordomo + stdlib): nil-safe Ports + RunnableAgent later [P2] dispatchguard/ loop/depth/fan-out caps [P0 ✓] pendingattach/ attachment dedupe [P0 ✓] - tool/ registry + 3-stage permissions + ssrf/llmmeta [P1] + tool/ registry + 3-stage permissions + ssrf [P1 ✓] model/ config-driven tier resolution over majordomo [P1] + llmmeta/ shared meta-LLM helper (moves with model/) [P1] compact/ context compactor (WithCompactor hook) [P2] tools/{web,net,store,compose,meta,comms} generic tools [P3] structured/ Generate[T] convenience over majordomo [P1] @@ -91,6 +92,11 @@ rewire mort + tag v0.1.0. The mort-side rewrite reuses mort's existing ## Conventions +- **Keep `README.md`, this `CLAUDE.md`, and `examples/` in sync with every change, + in the SAME commit.** No aspirational docs: when you add/rename a package, change + a seam or a default, or alter the public API, update the docs and the relevant + example so they always reflect reality (mirrors majordomo's house rule). The + status markers in the tier map above must track what's actually landed. - Mirror majordomo's house style: gofmt; check errors immediately and wrap with `fmt.Errorf("...: %w", err)`; `// Why:` comments where rationale isn't obvious; hermetic tests (majordomo's fake provider; no network in the default suite). diff --git a/go.mod b/go.mod index 814dfd4..c211aa8 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module gitea.stevedudenhoeffer.com/steve/executus go 1.26.2 + +require ( + gitea.stevedudenhoeffer.com/steve/majordomo v0.0.0-20260626223738-1fd7109a42f3 + golang.org/x/crypto v0.53.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..97cbe0a --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +gitea.stevedudenhoeffer.com/steve/majordomo v0.0.0-20260626223738-1fd7109a42f3 h1:KYKIFFRsXzbbBJVDa99+Fhy0zxl9G0xV/MCrLipsLL4= +gitea.stevedudenhoeffer.com/steve/majordomo v0.0.0-20260626223738-1fd7109a42f3/go.mod h1:UZLveG17SmENt4sne2RSLIbioix30RZbRIQUzBAnOyY= +golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= +golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= diff --git a/tool/argcoerce.go b/tool/argcoerce.go new file mode 100644 index 0000000..3ab3845 --- /dev/null +++ b/tool/argcoerce.go @@ -0,0 +1,161 @@ +package tool + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" +) + +// unmarshalArgsLenient decodes the raw JSON arguments the model supplied +// into the tool's typed Args struct, tolerating the classic LLM tool-call +// bug of emitting numbers and booleans as strings ("3" where the schema +// said integer, "true" where it said boolean). +// +// Why mort-side (vs relying on the library): legacy gollm's Define performed +// this coercion internally (tool_coerce.go), and several years of tool +// traffic depend on the tolerance. majordomo's DefineTool decodes +// strictly by design, so the gated wrappers re-create the leniency here +// — the strict path is tried first and coercion only runs on failure, +// which makes the happy path zero-cost. +func unmarshalArgsLenient(raw json.RawMessage, target any) error { + if len(raw) == 0 { + return nil + } + strictErr := json.Unmarshal(raw, target) + if strictErr == nil { + return nil + } + coerced, err := coerceArgsToType(raw, reflect.TypeOf(target).Elem()) + if err != nil { + // Malformed JSON: surface the original strict error, which + // names the real problem. + return strictErr + } + if err := json.Unmarshal(coerced, target); err != nil { + return strictErr + } + return nil +} + +// coerceArgsToType reparses argsJSON with leniency: where the target +// struct expects a numeric or boolean field but the JSON value is a +// string, it converts the string to the target kind. Recurses into +// nested structs, slices, maps, and pointer fields. Returns a freshly +// marshaled JSON byte slice that unmarshals strictly into the target. +func coerceArgsToType(argsJSON []byte, target reflect.Type) ([]byte, error) { + var raw any + if err := json.Unmarshal(argsJSON, &raw); err != nil { + return nil, err + } + raw = coerceValue(raw, target) + return json.Marshal(raw) +} + +func coerceValue(v any, t reflect.Type) any { + if t == nil { + return v + } + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + + switch t.Kind() { + case reflect.Struct: + m, ok := v.(map[string]any) + if !ok { + return v + } + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + name := jsonFieldName(f) + if name == "-" { + continue + } + if val, present := m[name]; present { + m[name] = coerceValue(val, f.Type) + } + } + return m + + case reflect.Slice, reflect.Array: + arr, ok := v.([]any) + if !ok { + return v + } + elemType := t.Elem() + for i := range arr { + arr[i] = coerceValue(arr[i], elemType) + } + return arr + + case reflect.Map: + m, ok := v.(map[string]any) + if !ok { + return v + } + valType := t.Elem() + for k := range m { + m[k] = coerceValue(m[k], valType) + } + return m + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "+") + if n, err := strconv.ParseInt(s, 10, 64); err == nil { + return n + } + if f, err := strconv.ParseFloat(s, 64); err == nil { + return int64(f) + } + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "+") + if n, err := strconv.ParseUint(s, 10, 64); err == nil { + return n + } + if f, err := strconv.ParseFloat(s, 64); err == nil && f >= 0 { + return uint64(f) + } + } + + case reflect.Float32, reflect.Float64: + if s, ok := v.(string); ok { + s = strings.TrimSpace(s) + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f + } + } + + case reflect.Bool: + if s, ok := v.(string); ok { + if b, err := strconv.ParseBool(strings.TrimSpace(s)); err == nil { + return b + } + } + } + return v +} + +// jsonFieldName returns the effective JSON key for a struct field: +// the json tag's name part when present, the Go field name otherwise, +// or "-" when the field is excluded. +func jsonFieldName(f reflect.StructField) string { + tag, ok := f.Tag.Lookup("json") + if !ok { + return f.Name + } + name, _, _ := strings.Cut(tag, ",") + if name == "" { + return f.Name + } + return name +} diff --git a/tool/argcoerce_test.go b/tool/argcoerce_test.go new file mode 100644 index 0000000..cf2f45c --- /dev/null +++ b/tool/argcoerce_test.go @@ -0,0 +1,83 @@ +package tool + +import ( + "context" + "strings" + "testing" +) + +// coerceParams mirrors the field kinds legacy gollm's coercion supported. +type coerceParams struct { + Count int `json:"count"` + Ratio float64 `json:"ratio"` + Flag bool `json:"flag"` + Limit *int `json:"limit"` + Tags []int `json:"tags"` + Nested coerceIn `json:"nested"` + Verbose string `json:"verbose"` +} + +type coerceIn struct { + Depth uint `json:"depth"` +} + +// TestGatedTool_LenientArgCoercion anchors the legacy gollm-era tolerance the +// conversion preserved: numeric and boolean fields supplied as strings +// by the model ("3", "true") decode into the typed Args, recursing into +// pointers, slices, and nested structs. Models emit this shape +// constantly; losing the tolerance would break live tool traffic. +func TestGatedTool_LenientArgCoercion(t *testing.T) { + var seen coerceParams + tool := NewGatedTool[coerceParams]( + "coerce_tool", "coercion test", + Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true}, + func(ctx context.Context, inv Invocation, args coerceParams) (string, error) { + seen = args + return "ok", nil + }, + ) + + out, err := buildAndExecute(t, tool, Invocation{SkillName: "x"}, VisibilityPrivate, nil, + `{"count":"3","ratio":" 2.5 ","flag":"true","limit":"7","tags":["1","2"],"nested":{"depth":"4"},"verbose":"yes"}`) + if err != nil || out != "ok" { + t.Fatalf("execute: out=%q err=%v", out, err) + } + if seen.Count != 3 || seen.Ratio != 2.5 || seen.Flag != true { + t.Fatalf("scalar coercion failed: %+v", seen) + } + if seen.Limit == nil || *seen.Limit != 7 { + t.Fatalf("pointer coercion failed: %+v", seen.Limit) + } + if len(seen.Tags) != 2 || seen.Tags[0] != 1 || seen.Tags[1] != 2 { + t.Fatalf("slice coercion failed: %+v", seen.Tags) + } + if seen.Nested.Depth != 4 { + t.Fatalf("nested coercion failed: %+v", seen.Nested) + } + if seen.Verbose != "yes" { + t.Fatalf("string field mangled: %q", seen.Verbose) + } +} + +// TestGatedTool_StrictPathUnaffected confirms well-typed args take the +// zero-cost strict path and uncoercible strings still fail loudly. +func TestGatedTool_StrictPathUnaffected(t *testing.T) { + tool := NewGatedTool[coerceParams]( + "coerce_strict_tool", "coercion test", + Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true}, + func(ctx context.Context, inv Invocation, args coerceParams) (string, error) { + return "ok", nil + }, + ) + + if out, err := buildAndExecute(t, tool, Invocation{SkillName: "x"}, VisibilityPrivate, nil, + `{"count":3,"ratio":2.5,"flag":true}`); err != nil || out != "ok" { + t.Fatalf("strict path: out=%q err=%v", out, err) + } + + _, err := buildAndExecute(t, tool, Invocation{SkillName: "x"}, VisibilityPrivate, nil, + `{"count":"not-a-number"}`) + if err == nil || !strings.Contains(err.Error(), "invalid arguments") { + t.Fatalf("expected invalid-arguments error for uncoercible string, got %v", err) + } +} diff --git a/tool/checks.go b/tool/checks.go new file mode 100644 index 0000000..99997a5 --- /dev/null +++ b/tool/checks.go @@ -0,0 +1,56 @@ +package tool + +import "fmt" + +// CheckAuthoring verifies that the saving user is permitted to author a +// skill that uses the given tool list. Called from the save path +// (skills.System.SaveUserSkill); the builtin loader bypasses this check. +// +// Why: the AuthoringRequirement gate is the primary admin trust boundary +// for tools that can read sensitive data (db_select, repo_*) or perform +// privileged Discord queries. Failing closed at save time prevents the +// situation where a skill is saved-then-rejected at execute time. +// +// What: returns nil if all tools clear; otherwise returns the spec's +// exact rejection message for the first offending tool. +// +// Test: see checks_test.go. +func CheckAuthoring(reg Registry, tools []string, isAdmin bool) error { + for _, name := range tools { + t, ok := reg.Get(name) + if !ok { + return fmt.Errorf("unknown tool %q", name) + } + if t.Permission().AuthoringRequirement == RequirementAdmin && !isAdmin { + return fmt.Errorf("The tool `%s` requires admin authoring. Ask an admin to create or publish a skill that uses this tool.", name) + } + } + return nil +} + +// CheckShareSafety verifies that none of the listed tools is unsafe for +// sharing. Called when a skill's visibility is being set to shared or +// public. +// +// Why: tools that operate on caller-private data (mortbux_get_balance, +// chatbot_get_memories) leak when invoked by non-owners through a +// shared/public skill — the executor would compute "the caller is +// whoever ran the skill", whose data would then surface to the skill +// authoring user. +// +// What: returns nil if all tools clear; otherwise returns the spec's +// exact rejection message for the first offending tool. +// +// Test: see checks_test.go. +func CheckShareSafety(reg Registry, tools []string) error { + for _, name := range tools { + t, ok := reg.Get(name) + if !ok { + return fmt.Errorf("unknown tool %q", name) + } + if !t.Permission().SafeForShare { + return fmt.Errorf("The tool `%s` cannot appear in a shared skill because it operates on the caller's own data.", name) + } + } + return nil +} diff --git a/tool/checks_test.go b/tool/checks_test.go new file mode 100644 index 0000000..9ed189b --- /dev/null +++ b/tool/checks_test.go @@ -0,0 +1,56 @@ +package tool + +import ( + "strings" + "testing" +) + +func TestCheckAuthoring_AllowsAnyone(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{name: "calc", perm: Permission{AuthoringRequirement: RequirementAnyone}}) + if err := CheckAuthoring(r, []string{"calc"}, false); err != nil { + t.Fatalf("expected anyone to pass, got %v", err) + } +} + +func TestCheckAuthoring_BlocksNonAdminFromAdminTool(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{name: "db_select", perm: Permission{AuthoringRequirement: RequirementAdmin}}) + err := CheckAuthoring(r, []string{"db_select"}, false) + if err == nil || !strings.Contains(err.Error(), "requires admin authoring") { + t.Fatalf("expected admin-required error, got %v", err) + } +} + +func TestCheckAuthoring_AllowsAdminWithAdminTool(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{name: "db_select", perm: Permission{AuthoringRequirement: RequirementAdmin}}) + if err := CheckAuthoring(r, []string{"db_select"}, true); err != nil { + t.Fatalf("expected admin to pass, got %v", err) + } +} + +func TestCheckAuthoring_UnknownTool(t *testing.T) { + r := NewRegistry() + err := CheckAuthoring(r, []string{"missing"}, true) + if err == nil || !strings.Contains(err.Error(), "unknown tool") { + t.Fatalf("expected unknown-tool error, got %v", err) + } +} + +func TestCheckShareSafety_Pass(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{name: "search", perm: Permission{SafeForShare: true}}) + if err := CheckShareSafety(r, []string{"search"}); err != nil { + t.Fatalf("expected safe tool to pass, got %v", err) + } +} + +func TestCheckShareSafety_BlocksUnsafe(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{name: "balance", perm: Permission{SafeForShare: false}}) + err := CheckShareSafety(r, []string{"balance"}) + if err == nil || !strings.Contains(err.Error(), "operates on the caller's own data") { + t.Fatalf("expected share-safety error, got %v", err) + } +} diff --git a/tool/encryption.go b/tool/encryption.go new file mode 100644 index 0000000..a37e782 --- /dev/null +++ b/tool/encryption.go @@ -0,0 +1,242 @@ +// Package skilltools — encryption.go: per-skill envelope encryption for +// KV values and file blobs. AES-256-GCM with a per-skill key derived +// from a single master key (env var SKILLS_ENCRYPTION_MASTER_KEY) via +// HKDF using the skill ID as the salt. +// +// !!!!! CRITICAL OPERATIONAL WARNING !!!!! +// +// SKILLS_ENCRYPTION_MASTER_KEY MUST BE BACKED UP SEPARATELY FROM THE +// DATABASE. Losing the master key = losing every byte of encrypted +// KV value and every encrypted file blob, with no recovery path. The +// key is the ONLY thing that can decrypt rows whose +// encryption_key_version > 0. +// +// Operational rules: +// - Store the master key in a secrets manager (Vault, 1Password, +// KMS export) — NEVER in the same backup as the database dump. +// - Rotating the master key without a versioned re-encrypt +// migration WILL render existing encrypted rows unreadable. The +// encryption_key_version column was added so a future rotation +// migration can re-encrypt under a new (master, version) +// pair; do not bump the version without that migration. +// - When the env var is empty, encryption is OFF for the whole +// instance. Skills with encryption_enabled=true still write +// plaintext (with a logged WARNING). This is intentional — the +// alternative is to refuse to start, which would break +// deployment for everyone the moment the secret leaks during +// rotation. Loud logging + the boot-time warning in mort.go is +// the correct trade-off. +// +// Why HKDF-derived per-skill keys (vs one global key): a future +// "wipe this skill's data" admin action can be made auditable by +// recording the skill_id in the operation log without exposing the +// master key. Per-skill keys also cap blast radius if one key +// somehow leaks via a side channel — only that one skill's data is +// compromised, not the whole platform. +// +// Why AES-256-GCM: authenticated encryption catches tampered +// ciphertext at decrypt time. The GCM nonce is 12 random bytes per +// row; the auth tag is 16 bytes. Both are stored inline with the +// ciphertext so the storage layer's value/content column holds the +// full envelope (no separate nonce column). +// +// Wire format of an encrypted blob: +// +// +-- 1 byte: format version (0x01) +// +-- 12 bytes: GCM nonce +// +-- N bytes: ciphertext + 16-byte GCM tag +// +// The format-version byte lets a future change to nonce length or +// auth tag handling be detected loudly rather than corrupting reads. +// Encrypt always writes 0x01; Decrypt rejects any other version with +// ErrEncryptionUnknownVersion. +package tool + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "io" + "os" + + "golang.org/x/crypto/hkdf" +) + +// EncryptionMasterKeyEnv is the environment variable that holds the +// 32-byte (or longer, hashed down) master key for skill envelope +// encryption. +// +// !!!!! LOSING THIS KEY = LOSING ALL ENCRYPTED DATA !!!!! +// +// Back it up separately from database backups. Never commit it. +// Empty value = encryption OFF (with WARNING logged at boot). +const EncryptionMasterKeyEnv = "SKILLS_ENCRYPTION_MASTER_KEY" + +// CurrentKeyVersion is the version stamped on every newly-encrypted +// row. Version 0 is reserved for plaintext (legacy / encryption-off). +// Version 1 is "AES-256-GCM with HKDF(master, skill_id) per-skill key, +// envelope format 0x01". Bumping this requires a migration that +// re-encrypts existing rows under the new (master, version) pair. +const CurrentKeyVersion = 1 + +// envelopeFormatV1 is the first byte of every Encrypt output. Decrypt +// rejects any other value with ErrEncryptionUnknownVersion. +const envelopeFormatV1 = byte(0x01) + +// gcmNonceSize is fixed at 12 bytes for AES-GCM (NIST SP 800-38D +// recommended). +const gcmNonceSize = 12 + +// Encryption sentinel errors. Callers compare with errors.Is so storage +// adapters can branch on "tampered" vs "unknown version" vs "no master +// key". +var ( + // ErrEncryptionDisabled is returned when an encryption operation + // is attempted but SKILLS_ENCRYPTION_MASTER_KEY is empty. Storage + // adapters interpret this as "fall through to plaintext" — they + // MUST log loudly when this branch is taken. + ErrEncryptionDisabled = errors.New("skilltools: encryption disabled (master key empty)") + + // ErrEncryptionUnknownVersion is returned by Decrypt when the + // envelope's format-version byte is not envelopeFormatV1. A read + // that hits this error is corruption — surface to the operator, + // do NOT silently fall back to plaintext. + ErrEncryptionUnknownVersion = errors.New("skilltools: encryption envelope has unknown format version") + + // ErrEncryptionTampered is returned by Decrypt when the GCM auth + // tag check fails. The ciphertext or nonce was modified after + // encryption. Surface as "data corruption" — the row is unreadable. + ErrEncryptionTampered = errors.New("skilltools: encryption auth tag mismatch (data corruption or wrong key)") + + // ErrEncryptionShortInput is returned by Decrypt when the input + // is too short to contain even the version byte + nonce. Bug or + // malformed write. + ErrEncryptionShortInput = errors.New("skilltools: encryption input too short") +) + +// MasterKeyFromEnv returns the master key bytes (raw, NOT +// HKDF-derived) from the SKILLS_ENCRYPTION_MASTER_KEY env var. +// +// Why hash + truncate to 32 bytes vs require 32 raw bytes: operators +// commonly paste a generated random hex/base64 string of varying +// length. SHA-256-truncate accepts any non-empty input and produces +// a fixed-length key, which is then fed into HKDF for per-skill +// derivation. The hash step is purely "normalize length"; HKDF still +// does the per-skill diversification. +// +// Returns nil bytes (and false) if the env var is empty. +func MasterKeyFromEnv() (key []byte, present bool) { + raw := os.Getenv(EncryptionMasterKeyEnv) + if raw == "" { + return nil, false + } + sum := sha256.Sum256([]byte(raw)) + return sum[:], true +} + +// DeriveSkillKey returns the per-skill 32-byte AES-256 key for the +// given (master, skillID) pair via HKDF-SHA256. +// +// Why skillID as HKDF salt: each skill gets a distinct subkey so a +// single master breach is necessary to decrypt any one skill, but +// a skill_id leak (which is normal — IDs appear in logs) does NOT +// help an attacker. The HKDF info parameter is fixed to a constant +// label so different uses of the same master+skillID pair (e.g. a +// future per-skill HMAC key) can be derived with a different label +// without colliding. +// +// master must be the 32-byte output of MasterKeyFromEnv (or +// equivalent length-normalized input). skillID must be non-empty — +// caller is responsible. +func DeriveSkillKey(master []byte, skillID string) ([]byte, error) { + if len(master) == 0 { + return nil, ErrEncryptionDisabled + } + if skillID == "" { + return nil, errors.New("skilltools: DeriveSkillKey requires non-empty skillID") + } + r := hkdf.New(sha256.New, master, []byte(skillID), []byte("mort/skills/v1/aead")) + out := make([]byte, 32) + if _, err := io.ReadFull(r, out); err != nil { + return nil, fmt.Errorf("skilltools: HKDF derive: %w", err) + } + return out, nil +} + +// Encrypt seals plaintext under skillKey using AES-256-GCM and returns +// the wire envelope (version byte || nonce || ciphertext || tag). +// +// Caller is responsible for stamping the encryption_key_version column +// to CurrentKeyVersion AFTER a successful Encrypt — Encrypt itself +// only produces bytes; persisting them is the storage layer's job. +// +// Why a fresh random nonce per call (vs deterministic): nonce reuse +// under GCM is catastrophic (allows recovering the keystream); fresh +// 96-bit random nonces have a negligible collision probability under +// any realistic write rate. +func Encrypt(skillKey, plaintext []byte) ([]byte, error) { + if len(skillKey) != 32 { + return nil, fmt.Errorf("skilltools: Encrypt requires 32-byte key, got %d", len(skillKey)) + } + block, err := aes.NewCipher(skillKey) + if err != nil { + return nil, fmt.Errorf("skilltools: aes.NewCipher: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("skilltools: cipher.NewGCM: %w", err) + } + nonce := make([]byte, gcmNonceSize) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, fmt.Errorf("skilltools: rand.Read: %w", err) + } + // Pre-allocate the envelope: 1 (version) + 12 (nonce) + len(plaintext) + 16 (tag). + out := make([]byte, 0, 1+gcmNonceSize+len(plaintext)+gcm.Overhead()) + out = append(out, envelopeFormatV1) + out = append(out, nonce...) + out = gcm.Seal(out, nonce, plaintext, nil) + return out, nil +} + +// Decrypt opens an envelope produced by Encrypt under the same +// skillKey. Returns the plaintext or one of the sentinel errors. +// +// Caller MUST inspect the storage row's encryption_key_version BEFORE +// calling Decrypt. Version 0 means plaintext — Decrypt SHOULD NOT be +// called for version-0 rows (callers branch on the column value). +// This function does NOT inspect any version column; it only looks at +// the in-band envelope-format byte. +func Decrypt(skillKey, envelope []byte) ([]byte, error) { + if len(skillKey) != 32 { + return nil, fmt.Errorf("skilltools: Decrypt requires 32-byte key, got %d", len(skillKey)) + } + if len(envelope) < 1+gcmNonceSize { + return nil, ErrEncryptionShortInput + } + if envelope[0] != envelopeFormatV1 { + return nil, ErrEncryptionUnknownVersion + } + nonce := envelope[1 : 1+gcmNonceSize] + ciphertext := envelope[1+gcmNonceSize:] + block, err := aes.NewCipher(skillKey) + if err != nil { + return nil, fmt.Errorf("skilltools: aes.NewCipher: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("skilltools: cipher.NewGCM: %w", err) + } + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + // Distinguish auth-tag mismatch from other crypto errors so + // callers can surface "data corruption" specifically. The + // stdlib wraps the failure as a generic error; we map any + // failure here to ErrEncryptionTampered (the most likely + // cause is wrong key / tampered bytes). + return nil, ErrEncryptionTampered + } + return plaintext, nil +} diff --git a/tool/encryption_test.go b/tool/encryption_test.go new file mode 100644 index 0000000..5a4a1a1 --- /dev/null +++ b/tool/encryption_test.go @@ -0,0 +1,205 @@ +package tool + +import ( + "bytes" + "crypto/sha256" + "errors" + "strings" + "testing" +) + +// Why: round-trip is the bedrock — without it, every other test is +// meaningless. What: encrypt then decrypt; assert plaintext returns. +// Test: write a non-trivial plaintext and confirm exact byte equality. +func TestEncryption_RoundTrip(t *testing.T) { + t.Parallel() + master := masterTestKey() + key, err := DeriveSkillKey(master, "skill-abc") + if err != nil { + t.Fatalf("DeriveSkillKey: %v", err) + } + plaintext := []byte(`{"hello":"world","n":42,"arr":[1,2,3]}`) + envelope, err := Encrypt(key, plaintext) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + if envelope[0] != envelopeFormatV1 { + t.Fatalf("envelope[0] = %d, want %d", envelope[0], envelopeFormatV1) + } + got, err := Decrypt(key, envelope) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("round-trip mismatch:\n got: %q\nwant: %q", got, plaintext) + } +} + +// Why: GCM is authenticated encryption — flipping any bit MUST be +// detected. What: tamper with the ciphertext; assert ErrEncryptionTampered. +// Test: flip one byte of the ciphertext suffix, decrypt, expect tamper error. +func TestEncryption_TamperDetected(t *testing.T) { + t.Parallel() + master := masterTestKey() + key, err := DeriveSkillKey(master, "skill-tamper") + if err != nil { + t.Fatalf("DeriveSkillKey: %v", err) + } + envelope, err := Encrypt(key, []byte("sensitive data")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + // Flip a byte in the ciphertext (after version + nonce). + envelope[1+gcmNonceSize] ^= 0x01 + _, err = Decrypt(key, envelope) + if !errors.Is(err, ErrEncryptionTampered) { + t.Fatalf("Decrypt after tamper = %v, want ErrEncryptionTampered", err) + } +} + +// Why: nonce reuse under GCM is catastrophic — verify the impl uses +// fresh randomness on every call. What: encrypt the same plaintext twice; +// the envelopes must differ. +func TestEncryption_FreshNoncePerCall(t *testing.T) { + t.Parallel() + master := masterTestKey() + key, err := DeriveSkillKey(master, "skill-nonce") + if err != nil { + t.Fatalf("DeriveSkillKey: %v", err) + } + plaintext := []byte("fixed payload") + a, err := Encrypt(key, plaintext) + if err != nil { + t.Fatalf("Encrypt #1: %v", err) + } + b, err := Encrypt(key, plaintext) + if err != nil { + t.Fatalf("Encrypt #2: %v", err) + } + if bytes.Equal(a, b) { + t.Fatalf("two encryptions of the same plaintext produced identical envelopes (nonce not random)") + } + // Both must still decrypt to the same plaintext. + for i, env := range [][]byte{a, b} { + got, err := Decrypt(key, env) + if err != nil { + t.Fatalf("Decrypt #%d: %v", i, err) + } + if !bytes.Equal(got, plaintext) { + t.Fatalf("Decrypt #%d mismatch", i) + } + } +} + +// Why: per-skill key derivation MUST give different skills different +// keys so a leaked skillkey doesn't cross-decrypt. What: derive keys +// for skill A and skill B; encrypt under A; decrypt under B; expect +// tamper error. +func TestEncryption_PerSkillIsolation(t *testing.T) { + t.Parallel() + master := masterTestKey() + keyA, _ := DeriveSkillKey(master, "skill-a") + keyB, _ := DeriveSkillKey(master, "skill-b") + if bytes.Equal(keyA, keyB) { + t.Fatalf("derived keys for distinct skills are identical (HKDF salt not effective)") + } + envelope, err := Encrypt(keyA, []byte("only skill A may read")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + _, err = Decrypt(keyB, envelope) + if !errors.Is(err, ErrEncryptionTampered) { + t.Fatalf("Decrypt under wrong skill key = %v, want ErrEncryptionTampered", err) + } +} + +// Why: a future format change must be detectable, not silently +// corrupting reads. What: hand-craft an envelope with version byte +// 0xFF; assert ErrEncryptionUnknownVersion. +func TestEncryption_UnknownVersionRejected(t *testing.T) { + t.Parallel() + key := make([]byte, 32) + envelope := make([]byte, 1+gcmNonceSize+16) + envelope[0] = 0xFF + _, err := Decrypt(key, envelope) + if !errors.Is(err, ErrEncryptionUnknownVersion) { + t.Fatalf("Decrypt with bad version = %v, want ErrEncryptionUnknownVersion", err) + } +} + +// Why: short inputs must not panic. What: feed a 5-byte envelope to +// Decrypt; assert ErrEncryptionShortInput. +func TestEncryption_ShortInputRejected(t *testing.T) { + t.Parallel() + key := make([]byte, 32) + _, err := Decrypt(key, []byte{1, 2, 3, 4, 5}) + if !errors.Is(err, ErrEncryptionShortInput) { + t.Fatalf("Decrypt with short input = %v, want ErrEncryptionShortInput", err) + } +} + +// Why: empty master = encryption disabled. What: DeriveSkillKey with +// empty master returns ErrEncryptionDisabled. +func TestEncryption_EmptyMasterDisabled(t *testing.T) { + t.Parallel() + _, err := DeriveSkillKey(nil, "skill-x") + if !errors.Is(err, ErrEncryptionDisabled) { + t.Fatalf("DeriveSkillKey(nil) = %v, want ErrEncryptionDisabled", err) + } +} + +// Why: callers commonly paste random hex/base64 of varying length; +// MasterKeyFromEnv should normalize to 32 bytes via SHA-256. What: +// set the env var; assert returned bytes match SHA-256 of input. +func TestEncryption_MasterKeyFromEnvNormalizesLength(t *testing.T) { + // not parallel — we mutate process env + const raw = "this-is-a-fake-master-key-for-testing-only-totally-not-secure" + t.Setenv(EncryptionMasterKeyEnv, raw) + got, present := MasterKeyFromEnv() + if !present { + t.Fatalf("MasterKeyFromEnv reported absent for non-empty env var") + } + if len(got) != 32 { + t.Fatalf("len(masterKey) = %d, want 32", len(got)) + } + want := sha256.Sum256([]byte(raw)) + if !bytes.Equal(got, want[:]) { + t.Fatalf("masterKey does not match SHA-256 of env var") + } +} + +// Why: empty env var = encryption off (instance-wide). What: +// MasterKeyFromEnv returns nil + present=false. +func TestEncryption_MasterKeyFromEnvEmpty(t *testing.T) { + t.Setenv(EncryptionMasterKeyEnv, "") + got, present := MasterKeyFromEnv() + if present { + t.Fatalf("MasterKeyFromEnv reported present for empty env var") + } + if got != nil { + t.Fatalf("MasterKeyFromEnv returned non-nil bytes for empty env var") + } +} + +// Why: defence in depth — explicit non-32 key sizes should error +// rather than panic. +func TestEncryption_BadKeySize(t *testing.T) { + t.Parallel() + _, err := Encrypt(make([]byte, 16), []byte("x")) + if err == nil { + t.Fatalf("Encrypt with 16-byte key did not error") + } + if !strings.Contains(err.Error(), "32-byte key") { + t.Fatalf("error did not mention key size: %v", err) + } + _, err = Decrypt(make([]byte, 16), make([]byte, 32)) + if err == nil { + t.Fatalf("Decrypt with 16-byte key did not error") + } +} + +func masterTestKey() []byte { + // fixed deterministic master so test runs are reproducible. + sum := sha256.Sum256([]byte("test-master-do-not-use-in-prod")) + return sum[:] +} diff --git a/tool/exec_helpers_test.go b/tool/exec_helpers_test.go new file mode 100644 index 0000000..0be1004 --- /dev/null +++ b/tool/exec_helpers_test.go @@ -0,0 +1,67 @@ +package tool + +import ( + "context" + "encoding/json" + "errors" + + llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// toolCall mirrors the legacy gollm-era test shape (string arguments) so the +// pre-conversion test call sites keep their literal syntax. majordomo's +// llm.ToolCall carries json.RawMessage arguments; execBox adapts. +type toolCall struct { + Name string + Arguments string +} + +// execBox executes one call through a toolbox and adapts majordomo's +// ToolResult to the (result, error) pair these tests assert against: +// IsError results come back as a Go error carrying the result content +// (which is how the agent-facing error text read in the legacy gollm era). +func execBox(box *llm.Toolbox, call toolCall) (string, error) { + res := box.Execute(context.Background(), llm.ToolCall{ + ID: "test-call", + Name: call.Name, + Arguments: json.RawMessage(call.Arguments), + }) + if res.IsError { + return "", errors.New(res.Content) + } + return res.Content, nil +} + +// execTool runs a single built llm.Tool's handler and serializes the +// result the way llm.ExecuteTool does. Replaces the legacy gollm +// Tool.Execute(ctx, argsJSON) method the original tests called. +// +// Why the handler directly (vs llm.ExecuteTool): ExecuteTool flattens +// handler errors into IsError result text, but several tests assert +// error IDENTITY (errors.Is against sentinel errors the handlers +// wrap). Calling the handler preserves the error value, matching the +// legacy gollm Execute contract these tests were written against. +func execTool(ctx context.Context, t llm.Tool, args string) (string, error) { + raw := json.RawMessage(args) + if len(raw) == 0 { + raw = json.RawMessage("{}") + } + out, err := t.Handler(ctx, raw) + if err != nil { + return "", err + } + switch v := out.(type) { + case nil: + return "null", nil + case string: + return v, nil + case json.RawMessage: + return string(v), nil + default: + enc, mErr := json.Marshal(v) + if mErr != nil { + return "", mErr + } + return string(enc), nil + } +} diff --git a/tool/gated_tool.go b/tool/gated_tool.go new file mode 100644 index 0000000..cc9e692 --- /dev/null +++ b/tool/gated_tool.go @@ -0,0 +1,272 @@ +package tool + +import ( + "context" + "encoding/json" + "fmt" + + llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// gatedToolMarker is the unexported interface implemented by every Tool +// constructed via NewGatedTool. The IsGatedTool helper performs a type +// assertion against this marker so the meta-test in default_test.go +// (and the wizardtools meta-test) can enforce that every registered +// production tool uses the wrapper. +// +// Why an unexported method (vs a public marker): the goal is to make it +// IMPOSSIBLE for an external caller to lie about being gated. Only the +// implementation in this file can satisfy the interface, so the +// type-assertion in IsGatedTool is a real proof of provenance, not a +// pinky-swear from a struct that opts in. +type gatedToolMarker interface { + isGatedTool() +} + +// gatedTool is the concrete Tool returned by NewGatedTool. It carries +// the per-tool metadata (Name/Description/Permission) and the typed +// handler closure; BuildLLM wraps the handler with CheckGate + +// EmitAudit so tool authors literally cannot forget either call. +// +// Why generic on Args (vs accepting any-shaped JSON): each tool's +// handler is typed against its own param struct. defineTypedTool +// derives a JSON schema for the LLM from Args (llm.SchemaFor) and +// parses the args before invoking the handler. We re-marshal args to +// JSON once for the audit row so the captured shape matches exactly +// what the handler ran with (post-coercion). +type gatedTool[Args any] struct { + name string + description string + permission Permission + fn func(ctx context.Context, inv Invocation, args Args) (string, error) +} + +// isGatedTool implements gatedToolMarker for the meta-test. +func (g *gatedTool[Args]) isGatedTool() {} + +// defineTypedTool builds the majordomo llm.Tool for a typed handler: +// schema derived from Args, arguments decoded leniently (string→number/ +// boolean coercion preserved from the legacy gollm era — see argcoerce.go) +// before the handler runs. An unparseable arguments object returns the +// decode error WITHOUT running fn, framing arg-parse-error as a +// tool-call wiring failure rather than a tool-handler failure. +// +// Why not majordomo's llm.DefineTool: its decode is strict by design; +// mort's tool catalog keeps the lenient dialect for parity with years +// of model traffic that emits "3" where the schema says integer. +func defineTypedTool[Args any](name, description string, fn func(ctx context.Context, args Args) (string, error)) llm.Tool { + schema, err := llm.SchemaFor[Args]() + if err != nil { + panic(fmt.Sprintf("skilltools: defineTypedTool(%q): %v", name, err)) + } + return llm.Tool{ + Name: name, + Description: description, + Parameters: schema, + Handler: func(ctx context.Context, raw json.RawMessage) (any, error) { + var args Args + if err := unmarshalArgsLenient(raw, &args); err != nil { + return nil, fmt.Errorf("invalid arguments for %s: %w", name, err) + } + return fn(ctx, args) + }, + } +} + +// NewGatedTool wraps a typed handler so it automatically: +// 1. Calls CheckGate(inv) before the handler runs. On gate rejection +// emits EmitAudit(inv, "{}", "", err) and returns the gate error. +// 2. Calls fn(ctx, inv, args) once gate passes. +// 3. Re-marshals args to JSON for the audit row (so the captured args +// reflect any coercion performed during deserialisation), then +// emits EmitAudit(inv, argsJSON, result, err) once the handler +// returns. +// +// Production tools SHOULD use NewGatedTool unless they have a strong +// reason to handle gating manually. The wrapper exists because the +// previous per-tool pattern repeated four lines of boilerplate +// (CheckGate at the top, EmitAudit on every return path), and that +// boilerplate is easy to forget — wizard tools in v1 hotfix #4 had to +// be retrofitted because the author overlooked CheckGate. Centralising +// the calls makes them impossible to skip and the meta-test in +// tools/default_test.go enforces the discipline. +// +// The typed define layer handles JSON parsing and arg coercion before +// fn runs; if the args JSON is unparseable, the decode error is +// returned directly (the wrapper's audit emission does NOT fire on +// parse error — arg-parse-error is a tool-call wiring failure rather +// than a tool-handler failure). +// +// Test: pkg/skilltools/gated_tool_test.go covers gate rejection, +// happy path, fn-returned error, and the IsGatedTool assertion. The +// meta-test in pkg/skilltools/tools/default_test.go walks the registry +// and asserts every production tool implements gatedToolMarker. +func NewGatedTool[Args any]( + name, description string, + permission Permission, + fn func(ctx context.Context, inv Invocation, args Args) (string, error), +) Tool { + return &gatedTool[Args]{ + name: name, + description: description, + permission: permission, + fn: fn, + } +} + +// Name returns the tool's registry key. +func (g *gatedTool[Args]) Name() string { return g.name } + +// Description is shown to the LLM. +func (g *gatedTool[Args]) Description() string { return g.description } + +// Permission classifies the tool for save-time / share-time gating. +func (g *gatedTool[Args]) Permission() Permission { return g.permission } + +// BuildLLM produces the per-invocation llm.Tool. The returned tool's +// handler: +// - Runs CheckGate(inv) FIRST (before any handler logic). On gate +// rejection emits the audit row and returns the gate error. +// - Calls the user-supplied fn with the typed args. fn never sees a +// gate-rejected invocation. +// - Re-marshals args to JSON and emits the audit row exactly once, +// regardless of fn's return value (success or error). +// +// Why re-marshal vs using the raw LLM JSON: the lenient decode performs +// numeric/boolean coercion (e.g. "3" → 3) before invoking the handler; +// the audit row should reflect what fn actually received, not the +// pre-coercion text the LLM emitted. +func (g *gatedTool[Args]) BuildLLM(inv Invocation) llm.Tool { + return defineTypedTool[Args]( + g.name, + g.description, + func(ctx context.Context, args Args) (string, error) { + if err := CheckGate(inv); err != nil { + EmitAudit(inv, "{}", "", err) + return "", err + } + argsJSON, mErr := json.Marshal(args) + if mErr != nil { + // Vanishingly rare for the typed param structs in use; + // fall back to "{}" so the audit row never carries a + // half-formed args field. + argsJSON = []byte("{}") + } + result, err := g.fn(ctx, inv, args) + EmitAudit(inv, string(argsJSON), result, err) + return result, err + }, + ) +} + +// IsGatedTool reports whether t was constructed via NewGatedTool / +// NewGatedToolWithAudit. Used by the meta-test in +// tools/default_test.go to enforce that every registered production +// tool uses the wrapper. The check is a type assertion against the +// unexported gatedToolMarker interface, so only the gatedTool variants +// from this package can satisfy it — there is no way for an external +// Tool to pretend to be gated. +func IsGatedTool(t Tool) bool { + _, ok := t.(gatedToolMarker) + return ok +} + +// AuditedResult is what a NewGatedToolWithAudit handler returns: +// LLMResult is the string surfaced to the LLM (the tool-call result +// the model sees in its conversation); AuditArgs and AuditResult are +// what the wrapper logs to the audit row INSTEAD of the auto-derived +// values. +// +// Why a separate variant: a small number of tools (paste_create being +// the canonical example) need to return a sensitive value to the LLM +// (a URL containing an encryption-key fragment) but MUST redact that +// value from the audit row, since the audit row is rendered to admins +// in the webui run-trace view. The default wrapper auto-logs args + +// result, which would leak the key. NewGatedToolWithAudit lets the +// handler explicitly separate the LLM-visible output from the +// audit-visible output, while still benefitting from auto-injected +// CheckGate. +type AuditedResult struct { + // LLMResult is the string returned to the LLM as the tool result. + LLMResult string + // AuditArgs is the args string written to the audit row. If empty, + // the wrapper falls back to the JSON-marshaled typed args (same + // behaviour as NewGatedTool). + AuditArgs string + // AuditResult is the result string written to the audit row. May + // be empty (logged as "") to suppress sensitive fragments. + AuditResult string +} + +// gatedToolWithAudit is the variant of gatedTool whose handler returns +// an AuditedResult so it can override what the audit row captures. +type gatedToolWithAudit[Args any] struct { + name string + description string + permission Permission + fn func(ctx context.Context, inv Invocation, args Args) (AuditedResult, error) +} + +// isGatedTool implements gatedToolMarker for the meta-test. +func (g *gatedToolWithAudit[Args]) isGatedTool() {} + +func (g *gatedToolWithAudit[Args]) Name() string { return g.name } +func (g *gatedToolWithAudit[Args]) Description() string { return g.description } +func (g *gatedToolWithAudit[Args]) Permission() Permission { return g.permission } + +// NewGatedToolWithAudit is the redaction-aware variant of NewGatedTool. +// Use it ONLY when the LLM-facing result must differ from the audit +// row (e.g. the result contains an encryption key that the audit must +// NOT capture). Most tools should use NewGatedTool. +// +// Behaviour matches NewGatedTool exactly except: +// - The handler returns AuditedResult; the wrapper passes +// AuditedResult.LLMResult to the LLM and writes +// AuditedResult.AuditArgs / AuditedResult.AuditResult to the +// audit row (falling back to the JSON-marshaled args if +// AuditArgs is empty). +// - Gate rejection still emits an audit row with empty Result and +// args="{}" before returning the gate error. +// +// Test: covered alongside NewGatedTool in pkg/skilltools/ +// gated_tool_test.go. +func NewGatedToolWithAudit[Args any]( + name, description string, + permission Permission, + fn func(ctx context.Context, inv Invocation, args Args) (AuditedResult, error), +) Tool { + return &gatedToolWithAudit[Args]{ + name: name, + description: description, + permission: permission, + fn: fn, + } +} + +// BuildLLM produces the per-invocation llm.Tool. Same gate-injection +// semantics as gatedTool[Args].BuildLLM; the audit row uses the +// handler-supplied AuditArgs / AuditResult so a sensitive LLM-visible +// result string never leaks into the audit log. +func (g *gatedToolWithAudit[Args]) BuildLLM(inv Invocation) llm.Tool { + return defineTypedTool[Args]( + g.name, + g.description, + func(ctx context.Context, args Args) (string, error) { + if err := CheckGate(inv); err != nil { + EmitAudit(inv, "{}", "", err) + return "", err + } + res, err := g.fn(ctx, inv, args) + auditArgs := res.AuditArgs + if auditArgs == "" { + if b, mErr := json.Marshal(args); mErr == nil { + auditArgs = string(b) + } else { + auditArgs = "{}" + } + } + EmitAudit(inv, auditArgs, res.AuditResult, err) + return res.LLMResult, err + }, + ) +} diff --git a/tool/gated_tool_test.go b/tool/gated_tool_test.go new file mode 100644 index 0000000..8cb322c --- /dev/null +++ b/tool/gated_tool_test.go @@ -0,0 +1,401 @@ +package tool + +import ( + "context" + "encoding/json" + "errors" + "strings" + "sync" + "testing" + + llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// gatedTestParams is a typed param struct used by the gated_tool tests. +// Mirrors a real production tool: a couple of strings the LLM supplies. +type gatedTestParams struct { + Question string `json:"question" description:"The question to answer."` + Detail string `json:"detail,omitempty" description:"Optional detail level."` +} + +// recordingAudit captures every AuditCall the wrapper emits so tests +// can assert exactly what the wrapper logged. Concurrent-safe in case a +// future test parallelises across goroutines. +type recordingAudit struct { + mu sync.Mutex + calls []AuditCall +} + +func (r *recordingAudit) hook() AuditHook { + return func(call AuditCall) { + r.mu.Lock() + defer r.mu.Unlock() + r.calls = append(r.calls, call) + } +} + +func (r *recordingAudit) snapshot() []AuditCall { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]AuditCall, len(r.calls)) + copy(out, r.calls) + return out +} + +// buildAndExecute is the test-only convenience for going from a +// constructed Tool to an llm.Tool result. Mirrors how the production +// registry's Build call wires inv.gate / inv.audit. +func buildAndExecute(t *testing.T, tool Tool, inv Invocation, vis Visibility, audit AuditHook, args string) (string, error) { + t.Helper() + r := NewRegistry() + if err := r.Register(tool); err != nil { + t.Fatalf("register: %v", err) + } + box, err := r.Build([]string{tool.Name()}, inv, vis, audit) + if err != nil { + t.Fatalf("build: %v", err) + } + return execBox(box, toolCall{Name: tool.Name(), Arguments: args}) +} + +// TestNewGatedTool_GateRejection verifies that the wrapper auto-injects +// CheckGate: if the invocation's SkillName doesn't match the tool's +// SkillNameGate, fn never runs and the audit row is emitted with the +// gate error. This is the core contract that v1 hotfix #4 had to +// retrofit by hand. +func TestNewGatedTool_GateRejection(t *testing.T) { + called := false + tool := NewGatedTool[gatedTestParams]( + "gated_test_tool", + "A test tool gated to my-skill.", + Permission{ + AuthoringRequirement: RequirementAnyone, + OperatesOn: ScopeGlobal, + SafeForShare: true, + SkillNameGate: "my-skill", + }, + func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) { + called = true + return "should not be reached", nil + }, + ) + + rec := &recordingAudit{} + out, err := buildAndExecute(t, tool, + Invocation{SkillName: "other-skill"}, + VisibilityPrivate, rec.hook(), + `{"question":"hi"}`) + + if err == nil { + t.Fatalf("expected gate-rejection error, got out=%q err=nil", out) + } + if !strings.Contains(err.Error(), "restricted to") { + t.Fatalf("expected error containing 'restricted to', got %v", err) + } + if called { + t.Errorf("fn was called despite gate rejection — wrapper failed to inject CheckGate") + } + + calls := rec.snapshot() + if len(calls) != 1 { + t.Fatalf("expected exactly 1 audit call, got %d: %+v", len(calls), calls) + } + if calls[0].Err == nil { + t.Errorf("audit call.Err was nil; expected the gate error") + } + if calls[0].Args != "{}" { + t.Errorf("audit call.Args=%q, want \"{}\" (no args parsed pre-gate)", calls[0].Args) + } +} + +// TestNewGatedTool_HappyPath verifies the wrapper passes args to fn, +// returns fn's result, and emits a successful audit row with the +// re-marshaled args. +func TestNewGatedTool_HappyPath(t *testing.T) { + var seen gatedTestParams + var seenInv Invocation + + tool := NewGatedTool[gatedTestParams]( + "gated_happy_tool", + "A test tool with no gate.", + Permission{ + AuthoringRequirement: RequirementAnyone, + OperatesOn: ScopeGlobal, + SafeForShare: true, + }, + func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) { + seen = args + seenInv = inv + return "answered: " + args.Question, nil + }, + ) + + rec := &recordingAudit{} + out, err := buildAndExecute(t, tool, + Invocation{SkillName: "any-skill", CallerID: "user-7"}, + VisibilityPrivate, rec.hook(), + `{"question":"what is the time?","detail":"verbose"}`) + + if err != nil { + t.Fatalf("execute: %v", err) + } + if out != "answered: what is the time?" { + t.Errorf("unexpected output: %q", out) + } + if seen.Question != "what is the time?" || seen.Detail != "verbose" { + t.Errorf("fn received %+v, want question/detail populated", seen) + } + if seenInv.CallerID != "user-7" { + t.Errorf("fn saw CallerID=%q, want user-7", seenInv.CallerID) + } + + calls := rec.snapshot() + if len(calls) != 1 { + t.Fatalf("expected exactly 1 audit call, got %d", len(calls)) + } + if calls[0].Err != nil { + t.Errorf("audit call.Err=%v, want nil", calls[0].Err) + } + if calls[0].Result != "answered: what is the time?" { + t.Errorf("audit call.Result=%q, want match output", calls[0].Result) + } + // The wrapper re-marshals the args — verify the JSON is well-formed + // and contains the expected fields. + var argsBack gatedTestParams + if err := json.Unmarshal([]byte(calls[0].Args), &argsBack); err != nil { + t.Fatalf("audit args not valid JSON: %q (%v)", calls[0].Args, err) + } + if argsBack.Question != "what is the time?" || argsBack.Detail != "verbose" { + t.Errorf("audit args round-trip mismatch: %+v", argsBack) + } +} + +// TestNewGatedTool_FnError verifies the wrapper surfaces fn's error +// AND captures the partial result + error in the audit row. +func TestNewGatedTool_FnError(t *testing.T) { + tool := NewGatedTool[gatedTestParams]( + "gated_fn_err_tool", + "A test tool whose handler always errors.", + Permission{ + AuthoringRequirement: RequirementAnyone, + OperatesOn: ScopeGlobal, + SafeForShare: true, + }, + func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) { + return "partial output", errors.New("boom") + }, + ) + + rec := &recordingAudit{} + out, err := buildAndExecute(t, tool, + Invocation{SkillName: "any-skill"}, + VisibilityPrivate, rec.hook(), + `{"question":"x"}`) + + // llm.Define's Execute returns ("", err) when the handler returns a + // non-nil error — out is dropped on the LLM side. But the wrapper's + // audit row should still capture both partial result + error. + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected boom error, got out=%q err=%v", out, err) + } + + calls := rec.snapshot() + if len(calls) != 1 { + t.Fatalf("expected exactly 1 audit call, got %d", len(calls)) + } + if calls[0].Err == nil || !strings.Contains(calls[0].Err.Error(), "boom") { + t.Errorf("audit call.Err=%v, want boom", calls[0].Err) + } + if calls[0].Result != "partial output" { + t.Errorf("audit call.Result=%q, want 'partial output' (partial captured)", calls[0].Result) + } +} + +// TestNewGatedTool_ArgsParseHandledByLLM_NoAuditEmitted documents the +// behaviour at the wrapper boundary: when the LLM sends malformed JSON +// args, llm.Define's Execute fails BEFORE the wrapper's inner closure +// runs. The wrapper does NOT emit an audit row in that case — it never +// got the chance. This is intentional: arg-parse failure is a +// tool-call wiring problem, not a tool-handler problem; the audit log +// reflects what the handler did, and on parse failure no handler ran. +// +// The test exists so future readers see this invariant documented in +// code and don't re-introduce a "log everything" path that breaks the +// wrapper's contract with the audit storage layer. +func TestNewGatedTool_ArgsParseHandledByLLM_NoAuditEmitted(t *testing.T) { + tool := NewGatedTool[gatedTestParams]( + "gated_parse_err_tool", + "A test tool that should never receive bad JSON.", + Permission{ + AuthoringRequirement: RequirementAnyone, + OperatesOn: ScopeGlobal, + SafeForShare: true, + }, + func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) { + t.Fatalf("fn ran despite malformed JSON — should never happen") + return "", nil + }, + ) + + rec := &recordingAudit{} + _, err := buildAndExecute(t, tool, + Invocation{SkillName: "any-skill"}, + VisibilityPrivate, rec.hook(), + `{"question":not-quoted}`) // intentionally malformed + + if err == nil { + t.Fatalf("expected JSON parse error, got nil") + } + if calls := rec.snapshot(); len(calls) != 0 { + t.Errorf("audit emitted %d calls on parse error; expected 0 (parse-fail is pre-handler)", len(calls)) + } +} + +// TestIsGatedTool_DetectsWrapped confirms that NewGatedTool's return +// value satisfies the gatedToolMarker interface so the meta-test can +// distinguish wrapped from unwrapped tools. +func TestIsGatedTool_DetectsWrapped(t *testing.T) { + tool := NewGatedTool[gatedTestParams]( + "gated_marker_tool", "marker test", + Permission{AuthoringRequirement: RequirementAnyone}, + func(ctx context.Context, inv Invocation, args gatedTestParams) (string, error) { + return "", nil + }, + ) + if !IsGatedTool(tool) { + t.Fatalf("IsGatedTool returned false for a NewGatedTool result") + } +} + +// TestIsGatedTool_DetectsNonWrapped is the negative half of the +// detection test: a hand-rolled Tool that does NOT go through +// NewGatedTool must fail IsGatedTool. This guards the meta-test +// against trivially passing for everything. +func TestIsGatedTool_DetectsNonWrapped(t *testing.T) { + stub := manualToolStub{} + if IsGatedTool(stub) { + t.Fatalf("IsGatedTool returned true for a non-wrapped Tool — detection broken") + } +} + +// manualToolStub satisfies skilltools.Tool by hand without going +// through NewGatedTool. Used only to prove IsGatedTool rejects +// non-wrapped implementations. +type manualToolStub struct{} + +func (manualToolStub) Name() string { return "manual_stub" } +func (manualToolStub) Description() string { return "manual stub" } +func (manualToolStub) Permission() Permission { return Permission{} } +func (manualToolStub) BuildLLM(Invocation) llm.Tool { + type p struct{} + return llm.DefineTool("manual_stub", "manual stub", + func(ctx context.Context, _ p) (any, error) { return "", nil }) +} + +// TestNewGatedToolWithAudit_RedactsAuditResult covers the variant used +// by paste_create: the LLM receives a sensitive string (e.g. URL with +// fragment-encoded key) but the audit row records only a redacted +// summary. Confirms LLMResult ↔ AuditResult separation works. +func TestNewGatedToolWithAudit_RedactsAuditResult(t *testing.T) { + tool := NewGatedToolWithAudit[gatedTestParams]( + "audited_tool", + "A tool whose audit result is redacted from its LLM result.", + Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true}, + func(ctx context.Context, inv Invocation, args gatedTestParams) (AuditedResult, error) { + return AuditedResult{ + LLMResult: "secret-fragment-12345", + AuditArgs: "redacted", + AuditResult: "[redacted]", + }, nil + }, + ) + if !IsGatedTool(tool) { + t.Fatalf("audited variant must satisfy IsGatedTool") + } + + rec := &recordingAudit{} + out, err := buildAndExecute(t, tool, + Invocation{SkillName: "any"}, + VisibilityPrivate, rec.hook(), + `{"question":"x"}`) + if err != nil { + t.Fatalf("execute: %v", err) + } + if out != "secret-fragment-12345" { + t.Errorf("LLM saw %q, want secret-fragment-12345", out) + } + calls := rec.snapshot() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call, got %d", len(calls)) + } + if calls[0].Args != "redacted" { + t.Errorf("audit args=%q, want redacted", calls[0].Args) + } + if calls[0].Result != "[redacted]" { + t.Errorf("audit result=%q, want [redacted]", calls[0].Result) + } + if strings.Contains(calls[0].Result, "secret-fragment-12345") { + t.Fatalf("audit leaked LLM result into Result field: %q", calls[0].Result) + } +} + +// TestNewGatedToolWithAudit_GateRejection mirrors the gate-rejection +// test for the default wrapper to anchor the same contract for the +// audited variant. +func TestNewGatedToolWithAudit_GateRejection(t *testing.T) { + tool := NewGatedToolWithAudit[gatedTestParams]( + "audited_gated_tool", "gated tool", + Permission{ + AuthoringRequirement: RequirementAnyone, + SkillNameGate: "my-skill", + }, + func(ctx context.Context, inv Invocation, args gatedTestParams) (AuditedResult, error) { + t.Fatalf("fn should not run on gate rejection") + return AuditedResult{}, nil + }, + ) + rec := &recordingAudit{} + _, err := buildAndExecute(t, tool, + Invocation{SkillName: "other"}, + VisibilityPrivate, rec.hook(), + `{}`) + if err == nil || !strings.Contains(err.Error(), "restricted to") { + t.Fatalf("expected gate rejection, got %v", err) + } + calls := rec.snapshot() + if len(calls) != 1 || calls[0].Err == nil { + t.Fatalf("expected gate-rejection audit row, got %+v", calls) + } +} + +// TestNewGatedToolWithAudit_FallbackArgs verifies that an empty +// AuditArgs falls back to the JSON-marshaled typed args (matching the +// default wrapper's behaviour). +func TestNewGatedToolWithAudit_FallbackArgs(t *testing.T) { + tool := NewGatedToolWithAudit[gatedTestParams]( + "audited_fallback_tool", "fallback args test", + Permission{AuthoringRequirement: RequirementAnyone}, + func(ctx context.Context, inv Invocation, args gatedTestParams) (AuditedResult, error) { + return AuditedResult{ + LLMResult: "ok", + AuditResult: "ok", + // AuditArgs intentionally empty + }, nil + }, + ) + rec := &recordingAudit{} + _, err := buildAndExecute(t, tool, + Invocation{SkillName: "x"}, + VisibilityPrivate, rec.hook(), + `{"question":"hi"}`) + if err != nil { + t.Fatalf("execute: %v", err) + } + calls := rec.snapshot() + if len(calls) != 1 { + t.Fatalf("expected 1 audit call, got %d", len(calls)) + } + if !strings.Contains(calls[0].Args, "hi") { + t.Errorf("expected fallback to JSON args containing 'hi', got %q", calls[0].Args) + } +} diff --git a/tool/helpers.go b/tool/helpers.go new file mode 100644 index 0000000..1d8f76e --- /dev/null +++ b/tool/helpers.go @@ -0,0 +1,44 @@ +package tool + +import "fmt" + +// CheckGate returns an error if the invocation context's SkillName does +// not match the tool's gate. Tools should call this at the top of their +// handler when their Permission has a non-empty SkillNameGate. +// +// Why: the gate is enforced per-call (not per-build) because the same +// Tool may be referenced by multiple skills, only one of which is +// gate-eligible. Build cannot know in advance which skill will call it +// — that's per-Invocation. +// +// What: returns nil if no gate, or the names match. Returns an error +// suitable for surfacing to the LLM as the tool's failure result. +func CheckGate(inv Invocation) error { + if inv.gate == "" { + return nil + } + if inv.currentSkill == inv.gate { + return nil + } + return fmt.Errorf("tool %q is restricted to the %q skill", inv.toolName, inv.gate) +} + +// EmitAudit forwards a tool's call+result to the audit hook, if one is +// wired. Tools should call this once per Execute, after the underlying +// work has completed (regardless of error). Pass the original args +// JSON, the result string, and any error. +// +// Why: keeping the audit emission inside the tool ensures the captured +// args are exactly what the tool ran with (after coercion / defaults), +// not the raw LLM JSON which can drift. +func EmitAudit(inv Invocation, args, result string, err error) { + if inv.audit == nil { + return + } + inv.audit(AuditCall{ + Tool: inv.toolName, + Args: args, + Result: result, + Err: err, + }) +} diff --git a/tool/hmac.go b/tool/hmac.go new file mode 100644 index 0000000..ac5937e --- /dev/null +++ b/tool/hmac.go @@ -0,0 +1,121 @@ +// Package skilltools — hmac.go: HMAC-SHA256 signature verification +// for the v7 inbound webhook subsystem. +// +// Why a small util in pkg/skilltools (vs inline in skillsui): the +// signature format is part of the skill platform's public contract — +// callers (GitHub, monitoring, Stripe, etc) compute it client-side, +// and other parts of mort may eventually verify the same shape (e.g. +// outbound retry verification). A shared util means we test the +// verifier once and the format stays consistent. +// +// Format: +// +// X-Mort-Signature: sha256= +// X-Mort-Timestamp: +// +// The timestamp is included so a stolen payload+signature pair can't +// be replayed indefinitely. Default skew window is 5 min via the +// caller-supplied maxSkew. The body is verified verbatim — callers +// must NOT canonicalise (the LLM-supplied JSON shape is usually +// unstable on round-trip). +package tool + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "errors" + "strconv" + "strings" + "time" +) + +// HMAC-related sentinel errors. Callers compare with errors.Is so +// handler code can surface the right HTTP status (401 vs 400). +var ( + // ErrHMACBadFormat is returned when the signature header is not + // the expected "sha256=" form. + ErrHMACBadFormat = errors.New("hmac: bad signature format") + + // ErrHMACBadSignature is returned when the computed HMAC does + // not match the supplied signature (constant-time compare). + ErrHMACBadSignature = errors.New("hmac: signature mismatch") + + // ErrHMACBadTimestamp is returned when the timestamp header is + // missing, malformed, or outside the maxSkew window. + ErrHMACBadTimestamp = errors.New("hmac: bad or stale timestamp") + + // ErrHMACEmptySecret is returned when verification is requested + // but the secret is empty — a programmer error (caller should + // have rejected the request earlier). + ErrHMACEmptySecret = errors.New("hmac: empty secret") +) + +// SignBody returns the canonical signature value for the given body +// + secret. Used by the test-payload sender on the management page. +// +// Why exported: the management page's "send test payload" button needs +// to compute the signature server-side before POSTing; reusing the +// same function ensures the verifier and signer stay in lock-step. +func SignBody(body []byte, secret string) string { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(body) + return "sha256=" + hex.EncodeToString(mac.Sum(nil)) +} + +// VerifyHMAC checks the signature + timestamp against the body using +// the supplied secret. Returns nil on success or one of the sentinels. +// +// Why hmac.Equal (constant-time): a naive == leaks signature length +// information through timing — VerifyHMAC must be safe against +// length-extension and timing oracle attacks. +// +// Why the timestamp is part of the verification (not the body): the +// signature does NOT cover the timestamp itself (callers may rotate +// timestamps without re-signing). The timestamp is a separate +// freshness check; if you wanted timestamp-bound replay protection +// you'd include it in the signed payload — but that complicates the +// signing API for callers and the per-skill rate limiter is the +// real defence against rapid replay. +func VerifyHMAC(body []byte, signature, timestamp, secret string, maxSkew time.Duration) error { + if secret == "" { + return ErrHMACEmptySecret + } + // Timestamp first — cheap reject before the HMAC compute. + if timestamp != "" { + ts, err := strconv.ParseInt(strings.TrimSpace(timestamp), 10, 64) + if err != nil { + return ErrHMACBadTimestamp + } + if maxSkew > 0 { + now := time.Now().Unix() + if abs(now-ts) > int64(maxSkew.Seconds()) { + return ErrHMACBadTimestamp + } + } + } + // Signature format: "sha256=" + const prefix = "sha256=" + if !strings.HasPrefix(signature, prefix) { + return ErrHMACBadFormat + } + provided, err := hex.DecodeString(signature[len(prefix):]) + if err != nil { + return ErrHMACBadFormat + } + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(body) + expected := mac.Sum(nil) + if !hmac.Equal(provided, expected) { + return ErrHMACBadSignature + } + return nil +} + +// abs returns |x| for int64. Avoids importing math for one call. +func abs(x int64) int64 { + if x < 0 { + return -x + } + return x +} diff --git a/tool/output_pattern_test.go b/tool/output_pattern_test.go new file mode 100644 index 0000000..21a4158 --- /dev/null +++ b/tool/output_pattern_test.go @@ -0,0 +1,167 @@ +package tool + +import ( + "reflect" + "strings" + "testing" + + llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// TestOutputPatternMetaTest enforces the V10 byte-vs-reference +// principle: any tool's typed return shape MUST NOT contain a raw +// []byte field that lacks a documented cap. Inline byte fields blow +// the agent's context window — the right pattern is to return a +// file_id reference. +// +// What this catches: +// - A future tool author returning {"data": []byte("...")} inline. +// - A reflective walk that sees `[]byte` or `Bytes` named fields +// with no annotation flagging them as size-capped. +// +// What this DOES NOT catch (acceptable trade-off): +// - Base64-encoded byte fields hidden as `string` (e.g. file_get's +// content_base64). The agent author can still misuse those, but +// existing code is grandfathered — the new pattern is to use +// file_get_metadata + file_get_text + send_attachments instead. +// - Tools whose outputs are JSON-marshalled at the LLM boundary; the +// check operates on the GO RETURN TYPES, not the wire JSON. That's +// fine because Go authors can't accidentally introduce []byte at +// marshal time. +// +// The test walks llm.Tool's exposed result type (where available) +// and Permission.Categories so future binary tools must label +// themselves with "binary" + return file_id-shaped envelopes. +// +// Currently this is a forward-looking contract — the existing tools +// emit JSON-string results from the typed gated wrappers, and the result +// type is opaque. We assert here that no STARTER tool registers a +// `[]byte`-shaped public Args (which is the foot-gun for input), and +// document the principle for new authors. +func TestOutputPatternMetaTest_NoRawByteArgs(t *testing.T) { + r := NewRegistry() + // We don't have access to deps here; a tool author wishing to + // enforce can run the same walk on their concrete Registry. The + // test asserts NewRegistry() produces an empty registry that the + // production wiring populates via tools.RegisterDefaults — and we + // re-enforce the principle in pkg/skilltools/tools/default_test.go + // where the live tools are registered. + for _, tool := range r.List() { + assertNoRawByteArgs(t, tool) + } +} + +// assertNoRawByteArgs reflects on the tool's BuildLLM result and walks +// its declared Args struct to fail when a public field is a raw []byte. +// +// Why public-fields-only: private fields can't be set by the LLM, so +// they're not a concern. +func assertNoRawByteArgs(t *testing.T, tool Tool) { + t.Helper() + llmTool := tool.BuildLLM(Invocation{}) + // Use reflection on the tool's call signature. The built llm.Tool + // exposes only a JSON schema derived from a Go type — + // we don't need to deconstruct it here; the existing meta-tests + // in pkg/skilltools/tools/default_test.go already enforce + // IsGatedTool(tool), and the gated wrappers are typed via + // generics. New authors should use NewGatedTool[ArgsStruct] which + // makes raw []byte impossible to declare without compile-time + // awareness. + _ = llmTool +} + +// TestBinaryContentTypeRecognition ensures the content-type +// classifier (used by http_get's V10 binary persistence path) picks +// up the content types that motivated the v10 change. Adding a new +// MIME to the binary list requires updating this test alongside the +// classifier so the meta-test stays load-bearing. +func TestBinaryContentTypeRecognition(t *testing.T) { + tests := []struct { + ct string + want bool + comment string + }{ + {"image/png", true, "image"}, + {"image/jpeg; charset=binary", true, "image with parameter"}, + {"audio/mpeg", true, "audio"}, + {"video/mp4", true, "video"}, + {"application/pdf", true, "pdf"}, + {"application/octet-stream", true, "octet-stream"}, + {"application/zip", true, "zip"}, + {"text/plain", false, "text"}, + {"text/html; charset=utf-8", false, "html"}, + {"application/json", false, "json"}, + {"application/xml", false, "xml"}, + {"", false, "empty"}, + } + for _, tt := range tests { + t.Run(tt.comment, func(t *testing.T) { + got := isBinaryContentTypeForTest(tt.ct) + if got != tt.want { + t.Fatalf("ct=%q got=%v want=%v", tt.ct, got, tt.want) + } + }) + } +} + +// isBinaryContentTypeForTest mirrors tools.isBinaryContentType but is +// duplicated here so the package-level meta-test doesn't import +// tools/. The two MUST stay in sync — the test in pkg/skilltools/tools +// covers the production helper directly via end-to-end http_get tests. +func isBinaryContentTypeForTest(ct string) bool { + ct = strings.ToLower(strings.TrimSpace(ct)) + if i := strings.Index(ct, ";"); i >= 0 { + ct = strings.TrimSpace(ct[:i]) + } + if ct == "" { + return false + } + if strings.HasPrefix(ct, "image/") || + strings.HasPrefix(ct, "audio/") || + strings.HasPrefix(ct, "video/") { + return true + } + switch ct { + case "application/octet-stream", "application/pdf", "application/zip", + "application/x-tar", "application/x-gzip", "application/x-bzip2", + "application/x-7z-compressed", "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.ms-powerpoint", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": + return true + } + return false +} + +// TestArgsStructsHaveNoRawBytes is the static-typed check the v10 +// principle relies on: NewGatedTool[Args] is the public surface, and +// Args structs MUST NOT carry a `[]byte`. We can't enumerate all +// Args types from outside (they're per-tool generics), so the +// production check is in pkg/skilltools/tools/default_test.go which +// reflects on every registered tool's BuildLLM args via the +// schema generator (llm.SchemaFor). +// +// This test asserts the documented principle compiles by referencing +// it: the `bytesForbiddenSentinel` type below intentionally contains +// `[]byte` and the test marks it as a known antipattern. +func TestArgsStructsHaveNoRawBytes(t *testing.T) { + tp := reflect.TypeOf(bytesForbiddenSentinel{}) + if tp.NumField() != 1 || tp.Field(0).Type.Kind() != reflect.Slice { + t.Fatalf("sentinel shape unexpected") + } + // Documenting: this is the SHAPE we forbid. Future authors who + // see a CodeReview comment pointing at this test can read the + // principle here and the doc in CLAUDE.md. (majordomo's SchemaFor + // encodes []byte as a base64 string on the wire, which is exactly + // the inline-bytes foot-gun the v10 principle bans.) + _, _ = llm.SchemaFor[bytesForbiddenSentinel]() +} + +// bytesForbiddenSentinel is the antipattern shape for tool Args. The +// meta-test references this so a developer searching for "[]byte" in +// the codebase finds the explanation immediately. +type bytesForbiddenSentinel struct { + Data []byte `json:"data" description:"DO NOT USE: raw bytes in Args blow the LLM's context window. Use file_id references via file_save / file_get_text / file_get_metadata / send_attachments instead."` +} diff --git a/tool/registry.go b/tool/registry.go new file mode 100644 index 0000000..de05e2a --- /dev/null +++ b/tool/registry.go @@ -0,0 +1,701 @@ +// Package skilltools is the tool registry for the agentic skills platform. +// Tools registered here can be referenced by name from a Skill's Tools +// list and are surfaced to the underlying majordomo agent loop via Build(). +// +// Independent of pkg/logic/chatbot/tool_provider.go: the chatbot's +// ToolProvider supplies tools per-channel during a chatbot turn; skill +// tools are scoped to one skill execution. Bridging happens once, in +// pkg/logic/skills/chatbot_provider.go, which exposes whole agent skills +// as chatbot tools (not individual skill tools). +// +// Permission model is documented in +// docs/superpowers/specs/2026-05-02-agentic-skills-design.md, "Tool +// registry" section. Three orthogonal checks: +// +// 1. Save-time: AuthoringRequirement vs caller's admin status. +// 2. Share-time: SafeForShare for visibility != private. +// 3. Execute-time: SkillNameGate. +package tool + +import ( + "context" + "fmt" + "sync" + "time" + + llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// Visibility is the spec's visibility enum mirrored here as a typed +// string. It's redeclared (vs imported from pkg/logic/skills) to break +// the import cycle that would otherwise form: skills → skilltools → +// skills. The string values match Visibility one-to-one so a +// caller can pass `string(VisibilityPublic)` and it just works. +type Visibility string + +const ( + VisibilityPrivate Visibility = "private" + VisibilityShared Visibility = "shared" + VisibilityPublic Visibility = "public" +) + +// Tool is what a registry entry implements. Concrete tools wrap an +// underlying mort subsystem (e.g. wolfram, weather, paste) and produce +// an llm.Tool on demand for a given Invocation. +// +// Why an interface (vs majordomo's concrete llm.Tool): we need richer +// metadata (Permission, Categories, SkillNameGate) for the platform's +// gating logic before we hand the tool to majordomo. BuildLLM converts +// to llm.Tool for one execution, closing over the Invocation so the +// per-tool handler can read CallerID/ChannelID without further plumbing. +// +// Why BuildLLM-per-call (vs static llm.Tool): per-user tools must close +// over inv.CallerID — the LLM-supplied args are intentionally ignored +// for those. Constructing the llm.Tool inside BuildLLM lets each tool +// craft its own typed Define call while reading the invocation context. +// +// Test: each tool under pkg/skilltools/tools/ has its own *_test.go. +type Tool interface { + Name() string + Description() string + Permission() Permission + // BuildLLM produces the llm.Tool for one invocation. The returned + // tool's name MUST equal Name(); the registry's Build() relies on + // this when wiring multiple tools into a Toolbox. + BuildLLM(inv Invocation) llm.Tool +} + +// Permission summarises the three lifecycle gates plus UI metadata. +type Permission struct { + // AuthoringRequirement governs who may SAVE a skill that uses + // this tool: anyone or admin-only. + AuthoringRequirement Requirement + + // OperatesOn classifies whose data the tool reads: global + // (channel-wide, public sources) or caller (the invoking user's + // own data). + OperatesOn Scope + + // SafeForShare reports whether the tool may appear in a shared or + // public skill. Tools that operate on caller data are typically + // not safe for share — the executing skill becomes a vector for + // reading other users' data. + SafeForShare bool + + // Categories are free-form labels used for UI grouping (read, + // write, network, code, data, social). Code does NOT branch on + // these strings. + Categories []string + + // SkillNameGate, if non-empty, restricts execution to the named + // skill. Used for wizard-only tools in v2; SkillNameGate=="" means + // any skill may use the tool. + SkillNameGate string +} + +// Requirement is who is allowed to author a skill using this tool. +type Requirement string + +const ( + RequirementAnyone Requirement = "anyone" + RequirementAdmin Requirement = "admin" +) + +// Scope classifies the data domain a tool acts on. +type Scope string + +const ( + ScopeGlobal Scope = "global" + ScopeCaller Scope = "caller" +) + +// ContinuationContext describes a V10 reply continuation. When set on +// an Invocation, the skill executor reuses the parent run's KV scope, +// renders a continuation prompt, and bumps ChainDepth for cap +// enforcement. +// +// The executor reads ParentRunID to set the new run's parent_run_id +// column (for call-tree reconstruction); ParentOutput to render the +// "previous output you sent" line in the agent prompt; ReplyText to +// render the "user replied with" line; ReplyMessageID for diagnostic +// logging; and ChainDepth to compare against +// skills.reply.max_chain_depth. +// +// Why ChainDepth (vs walking parent_run_id at execution time): a fresh +// query per turn would add a DB roundtrip on every reply hop. Carrying +// the count in the invocation is cheap and authoritative. +type ContinuationContext struct { + // ParentRunID is the run that produced the message the user + // replied to. The new run inherits its KV scope (run:). + ParentRunID string + + // ParentOutput is the text the parent run delivered to Discord — + // stored on the run row so it survives even if the parent's + // run-scope KV has been auto-purged (24h after parent finished). + ParentOutput string + + // ReplyText is what the user said when they replied (the new + // turn's user input). May be empty if the reply was an attachment- + // only message (handle gracefully — agent should handle empty + // input as a "noop continuation"). + ReplyText string + + // ReplyMessageID is the Discord message ID of the user's reply. + // Used for audit + log breadcrumbs; not currently consumed by the + // agent prompt. + ReplyMessageID string + + // ChainDepth is how many continuation hops have happened in the + // chain rooted at the original invocation. The router should set + // this to (parent's chain depth + 1). The executor rejects when + // it exceeds skills.reply.max_chain_depth. + ChainDepth int +} + +// InputFile is a non-image file the user supplied with a run (audio, +// etc.). The executor stages it into the file store under run scope and +// surfaces its file_id to the agent. Name is a safe base name (no path +// separators) suitable for /workspace/; MimeType is the resolved +// content type; Data is the raw bytes. +type InputFile struct { + Name string + MimeType string + Data []byte +} + +// Invocation is the runtime context passed to Tool.BuildLLM. The executor +// builds it once per skill run and the same struct is closed over by +// every tool's handler, so each tool sees the caller / channel identity. +type Invocation struct { + SkillID string + SkillName string + RunID string + CallerID string + ChannelID string + GuildID string + // CallerIsAdmin is true when the caller is a mort admin (Member.Admin). + // Populated by the executor at run dispatch via Bot.GetMember; defaults + // to false on any lookup failure (member not found, DB error, empty + // CallerID for system-invoked runs). Read by tools that gate behaviour + // on admin status — currently `code_exec` for the v15 admin-only WAN + // network mode. + // + // Why a precomputed bool on Invocation (vs an AdminChecker dep on + // every tool): the admin lookup is read-once-per-run; every tool + // would otherwise have to redo the work. The executor knows the + // caller's admin status by the time it builds Invocation, so it + // stamps the field once and every tool reads it for free. + CallerIsAdmin bool + // SkillInputs is the parsed input map for the enclosing skill — + // available so a tool can reference values the user supplied at + // invocation time. Tools may read this to specialise behaviour but + // MUST NOT use it as a substitute for inv.CallerID-based isolation. + SkillInputs map[string]any + // ParentRunID is set when the skill was invoked via skill_invoke + // from a parent skill run. Empty for top-level invocations + // (Discord, chatbot, scheduler). Used by the loop guard in + // skill_invoke and by the audit log for call-tree reconstruction. + // + // Why threaded through Invocation (vs context.Value): the loop + // guard runs at tool-handler time, where the only context the + // handler sees is inv. Stuffing it into context would force a + // helper for unwrap on every read; an explicit field is easier to + // audit and impossible to forget. + ParentRunID string + + // RootRunID is the audit run id at the ROOT of the dispatch tree + // this run belongs to — for a top-level run, its own RunID; for a + // delegated run (skill_invoke / agent_invoke / agent_spawn / + // palette wrappers), the outermost ancestor's. Stamped by both + // executors from the dispatchguard ancestor chain right after + // guard entry. Backs the shared `root_run:` KV scope that lets + // parallel sibling workers coordinate (see tools/scope_validate.go + // + RootRunKVPartition). + RootRunID string + + // ToolsSubset, when non-empty, narrows an AGENT run's low-level tools + // to the named subset of the agent's configured LowLevelTools. Set by + // agent_invoke's `tools_subset` arg for ephemeral fan-out — spawning a + // focused worker from a template (e.g. a `coder` template with only + // code_exec + read_page). Names outside the agent's tool menu are + // rejected upstream (in the invoke adapter), so by the time the + // executor reads this the intersection is safe. Empty = full palette. + // Skill runs ignore this field. + ToolsSubset []string + + // SystemPromptPrepend, when non-empty, is prepended to an AGENT's + // system prompt for this invocation only — the fan-out "customized + // system prompt" lever (agent_invoke's `prompt_prepend` arg). It + // specializes a template persona to a task without mutating the + // persisted agent row. Skill runs ignore this field. + SystemPromptPrepend string + + // SuppressDelivery, when true, instructs the skill executor to + // SKIP its OutputTarget Delivery (Deliver / DeliverError) entirely. + // The run still produces an output string (returned from Run) and + // still writes to the audit log — only the side-channel delivery + // (Discord channel/DM/thread post) is suppressed. + // + // Why: when the chatbot exposure adapter invokes a skill, the skill's + // output is already going to be consumed by the chatbot as a tool + // result; ALSO posting it to Discord via OutputTarget produces double + // output and (worse) primes the chatbot to call the tool again on + // the next turn after seeing its own output as a "human message", + // kicking off a tool-loop. The chatbot adapter sets this to true on + // every invocation it constructs. + SuppressDelivery bool + + // HandlerOwnsDelivery, when true, tells the executor that the caller + // (typically a Discord command handler) will assemble the final + // user-visible reply itself — folding any deferred attachments + // (rows queued by send_attachments to skill_run_pending_attachments) + // into the same message as the text output. The executor's + // post-run AttachmentDrainer is skipped so the handler can drain + + // classify + chain-overflow + post in one place. + // + // Why an explicit flag (vs reusing SuppressDelivery): SuppressDelivery + // also short-circuits the OutputTarget Delivery layer (channel/dm/ + // thread post), which is the right shape for chatbot exposure but + // the WRONG shape for `.agent run` — the handler still wants the + // audit row to land and the executor's drainer to NOT post a + // separate "here's an image" follow-up message after the handler's + // own text reply. HandlerOwnsDelivery is the narrow "the caller is + // taking over post-run delivery" signal that does NOT change any + // other executor behaviour. + // + // SuppressDelivery and HandlerOwnsDelivery are independent. The + // drainer is skipped when EITHER is set (the chatbot path doesn't + // want stray posts either; agent-run sets HandlerOwnsDelivery + // because it owns delivery; sub-agent dispatches set SuppressDelivery + // because they surface output as a tool result). + HandlerOwnsDelivery bool + + // Priority is the v9 per-invocation priority override for the lane + // scheduler. When non-zero, the executor uses this value when + // constructing the lane Job; zero falls back to the skill's + // Skill.DefaultPriority. Owners are capped by convar + // `skills.priority_max_per_user` (default 5); admins may exceed it. + // + // Why a non-pointer (vs *int): zero means "use the default", which + // matches the convention everywhere else in this struct. Skills + // that need an explicit zero priority can store + // DefaultPriority=0 — the result is identical. + Priority int + + // LaneWaitMaxSeconds is the v9 per-invocation lane backoff cap. When + // >0, the executor calls SubmitWithMaxWait so the run is rejected + // with ErrLaneBusy (surfaced as `lane_busy`) when the estimated + // queue wait would exceed this many seconds. 0 (default) preserves + // the legacy block-forever Submit semantics. + LaneWaitMaxSeconds int + + // LaneOverride forces the run onto the named lane regardless of + // Skill.ExecutionLane. Used by the v9 inbound webhook handler to + // route webhook-triggered runs to the dedicated webhook-default + // lane. Empty preserves the per-skill ExecutionLane. + LaneOverride string + + // Continuation, when non-nil, signals that this Invocation is a + // V10 reply continuation: a Discord user replied to a message the + // originating skill posted, and mort is re-invoking the skill to + // produce the next turn. The executor reads this field to: + // + // - Reuse the parent run's `run:` KV scope (so any + // state the prior turn saved is still readable). + // - Render a continuation block at the top of the agent's user + // prompt that includes the parent output + reply text. + // - Enforce the per-deployment chain-depth cap + // (skills.reply.max_chain_depth, default 20). + // - Stamp parent_run_id on the new run for call-tree + // reconstruction in audit + UI. + // + // Why a pointer struct (vs flat fields): all five fields are + // meaningful only together — splitting them would invite + // half-populated states. nil = "this is a fresh invocation, not a + // continuation". + Continuation *ContinuationContext + + // SourceWebhookSecretMatched is set true by the inbound webhook + // handler AFTER it has validated both the URL secret AND the HMAC + // signature for the named skill. It signals to System.Run that the + // caller is authenticated by a per-skill secret (not by Discord + // identity), so the visibility / owner gate in CanInvoke should be + // bypassed for THIS skill (matching SkillID). All other gates — + // pinned_version, budget caps, lane caps — still apply. + // + // Hotfix-5 Bug 1: pre-fix the webhook handler built an Invocation + // with CallerID=`:` and dispatched through + // System.Run. CanInvoke saw a non-owner non-admin caller against a + // private skill and rejected with HTTP 500 ("caller is not + // permitted to invoke skill"). The cure isn't to weaken + // CanInvoke's general-purpose policy — it's to recognise that a + // matched secret IS the auth gate for the named skill. + // + // Why per-Invocation (vs a separate gate path): the executor uses + // Run as the single canonical dispatch point — adding a second + // "authenticated dispatch" entry would split run-recording, lane + // dispatch, and audit emission into two parallel implementations. + SourceWebhookSecretMatched bool + + // OnEvent, when non-nil, is called by the executor at run + // boundaries and by the agent loop on each tool dispatch. The + // bot's command handler closes over the invoking message and + // reacts an emoji from the skill's StateReactEmoji map. Nil-safe. + // + // Event names: + // "__start__" — right before agent.Run starts + // "__end__" — on successful completion + // "__error__" — on terminal error + // — when a tool dispatches (any registered tool) + // + // The executor passes the resolved emoji as `emoji` so callers + // don't have to look it up themselves; emoji=="" means "no react + // for this event" and callers should skip the react entirely. + // + // Why a callback (vs a state-react map carried in the Invocation): + // the lookup table lives on the Skill, not the Invocation, but the + // caller-supplied side effect (a Discord react) lives on the bot + // command surface. A callback bridges the two without forcing the + // executor to import discord types and without forcing the bot + // command surface to know about the Skill's emoji map shape. + OnEvent func(ctx context.Context, event string, emoji string) + + // OnToolEvent, when non-nil, is called by the executor on each tool + // dispatch with phase "start" (before the tool runs) then "end" or + // "error" (after it completes, with the result text in detail). Distinct + // from OnEvent (which is the emoji state-react hook): this carries the + // tool name + args/result so an out-of-band caller — e.g. the mortise + // chat API streaming SSE tool.start/tool.end frames — can surface live + // tool-progress. Nil-safe; the callback MUST be fast and non-blocking + // (it runs on the agent-loop goroutine). + OnToolEvent func(ctx context.Context, toolName, phase, detail string) + + // OnStep, when non-nil, is called by the executor as the agent loop + // makes progress — currently once per tool call: phase "start" before + // the tool runs, phase "end" after it completes (StepEvent.Step.Status + // is "complete" or "error"). Correlate the two by StepEvent.Step.ID. + // "delta" is reserved for progressive detail and is unused today. + // + // Distinct from OnToolEvent (the raw tool-name/result hook): OnStep + // carries a richer, presentation-ready Step (kind + human present-tense + // summary) so an out-of-band consumer — e.g. the mortise chat API + // streaming SSE step.start/step.end frames — can render structured + // progress without re-deriving it. The executor ALSO accumulates the + // same Steps onto its run Result, so persistence does not depend on + // this callback being set. Nil-safe; the callback MUST be fast and + // non-blocking (it runs on the agent-loop goroutine). + OnStep func(ctx context.Context, ev StepEvent) + + // InvokingMessageID is the Discord message ID of the user's command + // that triggered this run, when it was triggered by a Discord text + // command. Used by delivery to thread the reply (Discord native + // reply with the gray quote bar + jump link). Empty for chatbot + // exposure, scheduled, or webhook invocations — delivery falls + // back to a plain channel post for those. + // + // Why threaded through Invocation (vs a separate field on Skill or + // a magic SkillInputs key): the message ID is per-invocation, not + // per-skill, and the delivery layer is the natural reader. Direct + // field on Invocation matches the existing ChannelID / GuildID + // fields' shape. + InvokingMessageID string + + // Images carries multi-modal image content for the initial user + // message. When non-empty, the executor builds the initial user + // message with llm.UserParts(text + image parts) instead of plain + // llm.UserText. Populated by callers that extract images from Discord + // attachments or URLs in prompt text (pkg/imageutil downloads the + // bytes — majordomo image parts are bytes-only). Nil = text-only. + Images []llm.ImagePart + + // InputFiles carries non-image attachments (audio, etc.) the user + // supplied with the run. Unlike Images, these are NOT inlined into + // the model's context — the LLM can't ingest raw mp3/wav/midi bytes. + // Instead the executor stages each into the skill file store under + // run scope and tells the agent the resulting file_ids (in the + // prompt) so it can hand one to a worker tool (e.g. code_exec + // files_in → /workspace/) for processing. Nil = none. + InputFiles []InputFile + + // ExtraTools are additional llm.Tool instances injected for this + // run only. They are appended to the palette after registry-built + // tools, skill-palette wrappers, and sub-agent wrappers. Use this + // for session-specific tools that cannot be pre-registered in the + // catalog (e.g., scaddy's write_scad which needs per-session + // workspace + renderer state). + // + // Why on Invocation (vs a dedicated Run parameter): the Invocation + // is the per-run context carrier in mort's execution path. Adding + // a separate ExtraTools arg to Executor.Run would fork the + // signature for one use case; a field on the existing carrier + // keeps the surface stable. + ExtraTools []llm.Tool + + // SessionToolFactory, if set, is called with the live AgentSession + // after the executor constructs the agent but before it runs. It + // returns a SessionTools struct carrying the tools to add, an + // optional PostRun hook for post-processing (e.g., rendering final + // artifacts from workspace state), and an optional Cleanup func for + // resource teardown. Types are defined in session_tools.go. + // + // Why a factory (vs ExtraTools): ExtraTools are static — they + // don't have access to the running agent. Tools that need to call + // session.AttachImages (to show rendered previews to the model on + // its next turn) require the live session handle that only exists + // after construction. The factory receives that handle. + SessionToolFactory SessionToolFactory + + // PostRunDelivery, if set, is called by the agent command handler + // (`.agent run`) INSTEAD of the default text + paste-fallback reply + // when the executor's result carries a PostRunResult. The callback + // receives the Discord message to reply to, the agent's text output, + // and the PostRunResult. It returns the message ID of the primary + // reply (for origin recording) and any error. + // + // Why a callback on Invocation (vs a handler method on the agent): + // delivery needs services (paste, filetransfer, Discord session) + // that live outside the agents package. A callback lets the adapter + // (e.g., scaddy) close over the services at factory-build time + // without adding service dependencies to the agents.System struct. + // + // When nil, `handleRun` falls through to the standard text-based + // reply path (formatRunReply + postRunReply). When set, the + // callback owns the ENTIRE reply — `handleRun` does NOT post a + // text reply alongside it. + PostRunDelivery func(ctx context.Context, channelID, replyToMsgID string, output string, prr *PostRunResult) (primaryMsgID string, err error) + + // RunState, when set by the executor, lets a tool read the live + // run's progress + budget snapshot (iteration vs cap, tool calls, + // tokens, cost, elapsed). Nil on paths that do not provide it (e.g. + // the no-tools direct path, or executors that predate the hook). + // The skill_self_status tool reads this. + RunState RunStateAccessor + + // AttachImages, when set by the executor, queues a user-role message + // (optional text + image parts) into the LIVE run so the model sees + // the images on its next step — the same steer-mailbox mechanism the + // SessionToolFactory's AgentSession exposes, but reachable from any + // ordinary tool handler. A tool returns text; images cannot ride a + // string result, so a tool that fetches images the model must SEE + // (e.g. discord_list_recent_messages reading channel history) calls + // this to feed the pixels in. Nil on paths that do not own a steer + // mailbox (skillexec, the no-tools direct path); tools MUST nil-check + // before calling and degrade to text-only when it is nil. + AttachImages func(text string, images ...llm.ImagePart) + + // gate / audit are populated by the registry's Build before + // BuildLLM is called. Tools should call CheckGate(inv) at the top + // of their handler and EmitAudit(inv, ...) when reporting tool + // results. The fields are unexported in the public surface but + // available to tools via the helpers in helpers.go. + gate string + currentSkill string + audit AuditHook + toolName string +} + +// RunState is a live, read-only snapshot of the current run's progress +// and budget. Populated on demand by the executor's per-run accessor +// (see Invocation.RunState). +type RunState struct { + Iteration int + MaxIterations int + ToolCalls int + MaxToolCalls int + InputTokens int64 + OutputTokens int64 + ThinkingTokens int64 + ElapsedSeconds int +} + +// RunStateAccessor returns the live RunState for the enclosing run. The +// executor builds one per run and stamps it on Invocation.RunState +// before the toolbox is built; tools read it via inv.RunState. Nil on +// any path that does not provide it. +type RunStateAccessor interface { + RunState() RunState +} + +// Registry is the read interface to the tool catalog. Concrete impl is +// the package-private *registry struct returned by NewRegistry. +type Registry interface { + Register(t Tool) error + Get(name string) (Tool, bool) + List() []Tool + // Build returns an llm.Toolbox with each named tool prepared for + // execution against the given invocation. Save-time authoring + // checks happen elsewhere (CheckAuthoring in checks.go) — Build + // trusts that the skill was already saved past those gates and + // only re-checks runtime invariants: + // + // 1. Share-safety drift: rejects an unsafe tool when visibility + // != private. + // 2. SkillNameGate enforcement is delegated to the per-tool + // handler via CheckGate, which reads invocation context. + // 3. Audit emission via EmitAudit (also per-tool). + // + // The optional `trusted` variadic argument lets the caller declare + // the skill as trusted infrastructure (a builtin loaded from disk + // by the project's own loader) so the share-safety drift check is + // skipped. Builtins legitimately ship with public visibility AND + // not-safe-for-share tools (e.g. skill-wizard's wizard_* tools), + // and the loader bypasses save-time gates by design — applying the + // share-safety check at invocation would be inconsistent with the + // rest of the trusted-builtin contract. Pass true ONLY for builtins + // (Skill.Source == SourceBuiltin / OwnerID == ""). Variadic so the + // existing call sites (and tests) compile unchanged. + Build(names []string, inv Invocation, vis Visibility, audit AuditHook, trusted ...bool) (*llm.Toolbox, error) +} + +// AuditHook is invoked synchronously around each tool call. Implementations +// typically forward to skillaudit.Writer. May be nil for tests. +type AuditHook func(call AuditCall) + +// AuditCall describes one tool invocation. Result is set on success; +// Err is set on failure. Either may be present together (e.g. the tool +// returned partial output then errored). +type AuditCall struct { + Tool string + Args string + Result string + Err error +} + +// Step is one unit of agent progress surfaced to a consumer of OnStep +// (and accumulated onto the executor's run Result). Today there is one +// Step per tool call; the shape is deliberately open so future kinds +// (a coalesced reasoning beat, a sub-agent delegation) slot in without a +// wire change. +// +// This is a plain DTO — no HTTP/Discord/JSON-tag coupling beyond the +// neutral snake_case tags a transport may reuse. The chat API converts +// it to its own persisted/wire type; Discord/cron consumers read the +// Result field directly. +type Step struct { + // ID is stable per-step and unique within one run; it is the + // correlation key between the "start" and "end" emissions. + ID string `json:"id"` + // Kind is an open vocabulary (search, read, code, image, file, + // memory, delegate, tool, …); consumers map known values to an icon + // and fall back for unknown ones. Never drop a step for an + // unrecognised kind. + Kind string `json:"kind"` + // Title is a short machine-ish label (typically the raw tool name). + Title string `json:"title,omitempty"` + // Summary is the human present-tense one-liner ("Searching the web + // for …"); on end it may be replaced with a result phrase. + Summary string `json:"summary"` + // Status is "running" | "complete" | "error". + Status string `json:"status"` + // Detail is optional, user-safe, size-capped markdown. Never raw tool + // output, credentials, or chain-of-thought. + Detail string `json:"detail,omitempty"` + // StartedAt is when the step began. + StartedAt time.Time `json:"started_at"` + // EndedAt is set on the terminal "end" emission. + EndedAt *time.Time `json:"ended_at,omitempty"` +} + +// StepEvent is one live emission to OnStep. Phase is "start" or "end" +// ("delta" is reserved for progressive detail and unused today). Step +// carries the full current snapshot; Detail holds the delta text when +// Phase == "delta". +type StepEvent struct { + Phase string + Step Step + Detail string +} + +// NewRegistry constructs an empty registry. Call Register for each tool; +// see pkg/skilltools/default_registry.go for the v1 set. +func NewRegistry() Registry { + return ®istry{tools: make(map[string]Tool)} +} + +type registry struct { + mu sync.RWMutex + tools map[string]Tool +} + +func (r *registry) Register(t Tool) error { + if t == nil { + return fmt.Errorf("skilltools: nil tool") + } + name := t.Name() + if name == "" { + return fmt.Errorf("skilltools: tool with empty name") + } + r.mu.Lock() + defer r.mu.Unlock() + if _, dup := r.tools[name]; dup { + return fmt.Errorf("skilltools: duplicate tool name %q", name) + } + r.tools[name] = t + return nil +} + +func (r *registry) Get(name string) (Tool, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + t, ok := r.tools[name] + return t, ok +} + +func (r *registry) List() []Tool { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]Tool, 0, len(r.tools)) + for _, t := range r.tools { + out = append(out, t) + } + return out +} + +// Build prepares an llm.Toolbox for one skill execution. +// +// Why: each tool needs to know the caller / channel / skill name plus +// the audit hook. Stuffing them into Invocation lets each Tool.BuildLLM +// produce a closure that has everything it needs without further +// plumbing. +// +// Defence in depth: rejects an unsafe tool when visibility != private — +// the share-time check should already have prevented this; this catches +// drift (e.g. a tool's SafeForShare flag flipping after a skill saved). +// +// The trusted variadic flag lets a caller bypass the share-safety drift +// check for builtin (trusted-infrastructure) skills. The mortventure / +// skill-wizard builtins legitimately ship with public visibility AND +// not-safe-for-share tools — the loader bypasses save-time gates and +// the share-safety check at invocation would block them inconsistently. +// Pass true ONLY for builtins. +func (r *registry) Build(names []string, inv Invocation, vis Visibility, audit AuditHook, trusted ...bool) (*llm.Toolbox, error) { + isTrusted := len(trusted) > 0 && trusted[0] + box := llm.NewToolbox("skilltools") + for _, name := range names { + t, ok := r.Get(name) + if !ok { + return nil, fmt.Errorf("skilltools: unknown tool %q", name) + } + + if !isTrusted && vis != VisibilityPrivate && !t.Permission().SafeForShare { + return nil, fmt.Errorf("skilltools: tool %q is not safe for share but skill visibility is %s", name, vis) + } + + // Populate the gate/audit fields on the Invocation so the tool + // can call CheckGate / EmitAudit from its handler. + toolInv := inv + toolInv.gate = t.Permission().SkillNameGate + toolInv.currentSkill = inv.SkillName + toolInv.audit = audit + toolInv.toolName = name + + built := t.BuildLLM(toolInv) + if built.Name == "" { + return nil, fmt.Errorf("skilltools: tool %q built llm.Tool with empty name", name) + } + if err := box.Add(built); err != nil { + return nil, fmt.Errorf("skilltools: adding tool %q: %w", name, err) + } + } + return box, nil +} diff --git a/tool/registry_test.go b/tool/registry_test.go new file mode 100644 index 0000000..007d7f8 --- /dev/null +++ b/tool/registry_test.go @@ -0,0 +1,184 @@ +package tool + +import ( + "context" + "strings" + "testing" + + llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// fakeTool is a minimal Tool used to exercise the registry's gating. +type fakeTool struct { + name string + desc string + perm Permission + calledWith *Invocation + returnText string + returnError error +} + +func (f *fakeTool) Name() string { return f.name } +func (f *fakeTool) Description() string { return f.desc } +func (f *fakeTool) Permission() Permission { return f.perm } +func (f *fakeTool) BuildLLM(inv Invocation) llm.Tool { + type emptyParams struct{} + return llm.DefineTool( + f.name, + f.desc, + func(ctx context.Context, _ emptyParams) (any, error) { + if err := CheckGate(inv); err != nil { + EmitAudit(inv, "{}", "", err) + return "", err + } + f.calledWith = &inv + EmitAudit(inv, "{}", f.returnText, f.returnError) + return f.returnText, f.returnError + }, + ) +} + +func TestRegister_DuplicateRejected(t *testing.T) { + r := NewRegistry() + a := &fakeTool{name: "x", perm: Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true}} + b := &fakeTool{name: "x", perm: Permission{AuthoringRequirement: RequirementAnyone, SafeForShare: true}} + if err := r.Register(a); err != nil { + t.Fatal(err) + } + err := r.Register(b) + if err == nil || !strings.Contains(err.Error(), "duplicate") { + t.Fatalf("expected duplicate-name error, got %v", err) + } +} + +func TestRegister_RejectsEmpty(t *testing.T) { + r := NewRegistry() + if err := r.Register(&fakeTool{name: ""}); err == nil { + t.Fatal("expected empty-name rejection") + } + if err := r.Register(nil); err == nil { + t.Fatal("expected nil-tool rejection") + } +} + +func TestBuild_UnknownTool(t *testing.T) { + r := NewRegistry() + _, err := r.Build([]string{"nope"}, Invocation{}, VisibilityPrivate, nil) + if err == nil || !strings.Contains(err.Error(), "unknown tool") { + t.Fatalf("expected unknown-tool error, got %v", err) + } +} + +func TestBuild_SharedRejectsUnsafeTool(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{name: "balance", perm: Permission{SafeForShare: false}}) + _, err := r.Build([]string{"balance"}, Invocation{}, VisibilityShared, nil) + if err == nil || !strings.Contains(err.Error(), "not safe for share") { + t.Fatalf("expected share-safety error, got %v", err) + } +} + +// TestBuild_TrustedBuiltinBypassesShareSafety verifies the +// trusted-flag escape hatch: a builtin (skill-wizard, mortventure) +// legitimately ships with public visibility AND not-safe-for-share +// tools. Build with trusted=true must not reject those. +// +// Why: pre-fix, invocation of skill-wizard (visibility=public, tools +// include wizard_* with SafeForShare=false) was rejected at runtime +// even though the loader had already bypassed save-time gates. The +// trusted flag aligns the invocation-time gate with the loader's +// trust model. +func TestBuild_TrustedBuiltinBypassesShareSafety(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{ + name: "wizard_list", + perm: Permission{SafeForShare: false}, + returnText: "ok", + }) + box, err := r.Build([]string{"wizard_list"}, Invocation{SkillName: "skill-wizard"}, VisibilityPublic, nil, true) + if err != nil { + t.Fatalf("trusted=true should bypass share-safety, got %v", err) + } + if box == nil { + t.Fatal("trusted=true should produce a toolbox, got nil") + } +} + +// TestBuild_NonTrustedSharedStillRejects confirms the bypass is +// strictly opt-in: a non-builtin caller with the same shape (public +// visibility + unsafe tool) still hits the rejection path. +func TestBuild_NonTrustedSharedStillRejects(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{name: "balance", perm: Permission{SafeForShare: false}}) + _, err := r.Build([]string{"balance"}, Invocation{}, VisibilityPublic, nil, false) + if err == nil || !strings.Contains(err.Error(), "not safe for share") { + t.Fatalf("trusted=false (non-builtin) must still reject unsafe tool at public visibility, got %v", err) + } + // Omitted variadic = trusted defaults to false → same rejection. + _, err = r.Build([]string{"balance"}, Invocation{}, VisibilityPublic, nil) + if err == nil || !strings.Contains(err.Error(), "not safe for share") { + t.Fatalf("omitted variadic must default to trusted=false, got %v", err) + } +} + +func TestBuild_PublicAcceptsSafeTool(t *testing.T) { + r := NewRegistry() + _ = r.Register(&fakeTool{name: "search", perm: Permission{SafeForShare: true}, returnText: "hits"}) + box, err := r.Build([]string{"search"}, Invocation{SkillName: "echo"}, VisibilityPublic, nil) + if err != nil { + t.Fatal(err) + } + out, err := execBox(box, toolCall{Name: "search", Arguments: "{}"}) + if err != nil || out != "hits" { + t.Fatalf("unexpected: %q %v", out, err) + } +} + +func TestBuild_GateBlocksMismatchedSkill(t *testing.T) { + r := NewRegistry() + tt := &fakeTool{ + name: "wizard_save", + perm: Permission{SafeForShare: true, SkillNameGate: "skill-wizard"}, + returnText: "saved", + } + _ = r.Register(tt) + box, err := r.Build([]string{"wizard_save"}, Invocation{SkillName: "echo"}, VisibilityPrivate, nil) + if err != nil { + t.Fatalf("build: %v", err) + } + out, err := execBox(box, toolCall{Name: "wizard_save", Arguments: "{}"}) + if err == nil || !strings.Contains(err.Error(), "restricted to") { + t.Fatalf("expected gate rejection, got out=%q err=%v", out, err) + } +} + +func TestBuild_GateAllowsMatchingSkill(t *testing.T) { + r := NewRegistry() + tt := &fakeTool{ + name: "wizard_save", + perm: Permission{SafeForShare: true, SkillNameGate: "skill-wizard"}, + returnText: "saved", + } + _ = r.Register(tt) + box, _ := r.Build([]string{"wizard_save"}, Invocation{SkillName: "skill-wizard"}, VisibilityPrivate, nil) + out, err := execBox(box, toolCall{Name: "wizard_save", Arguments: "{}"}) + if err != nil || out != "saved" { + t.Fatalf("unexpected: %q %v", out, err) + } +} + +func TestBuild_EmitsAudit(t *testing.T) { + r := NewRegistry() + tt := &fakeTool{name: "search", perm: Permission{SafeForShare: true}, returnText: "hits"} + _ = r.Register(tt) + + var calls []AuditCall + hook := func(c AuditCall) { calls = append(calls, c) } + + box, _ := r.Build([]string{"search"}, Invocation{SkillName: "echo"}, VisibilityPrivate, hook) + _, _ = execBox(box, toolCall{Name: "search", Arguments: "{}"}) + + if len(calls) != 1 || calls[0].Tool != "search" || calls[0].Result != "hits" || calls[0].Err != nil { + t.Fatalf("unexpected audit: %+v", calls) + } +} diff --git a/tool/rootrun.go b/tool/rootrun.go new file mode 100644 index 0000000..a05a13b --- /dev/null +++ b/tool/rootrun.go @@ -0,0 +1,18 @@ +package tool + +// RootRunKVPartition is the sentinel skill_id partition under which all +// `root_run:` KV rows are stored. +// +// Why a sentinel: skill KV rows are keyed (skill_id, scope, key), so +// two sibling workers with different IDs (e.g. agent_spawn ephemeral +// workers under one fan-out) could never share state through a scope +// string alone — each would read/write its own partition. Routing every +// root_run scope into one shared partition makes the scope string the +// real boundary: it embeds the root run id, which the validator checks +// against Invocation.RootRunID, so per-tree isolation holds even though +// the partition is global. +// +// Declared in the root skilltools package (not tools/) because both the +// tool handlers (pkg/skilltools/tools) and the storage sweeper +// (pkg/logic/skills) need it without importing each other. +const RootRunKVPartition = "__root_run__" diff --git a/tool/run_state_test.go b/tool/run_state_test.go new file mode 100644 index 0000000..9dc0ddf --- /dev/null +++ b/tool/run_state_test.go @@ -0,0 +1,18 @@ +package tool + +import "testing" + +type fakeAccessor struct{ s RunState } + +func (f fakeAccessor) RunState() RunState { return f.s } + +func TestInvocationRunState_NilSafe(t *testing.T) { + var inv Invocation + if inv.RunState != nil { + t.Fatal("RunState should default nil") + } + inv.RunState = fakeAccessor{s: RunState{Iteration: 3, MaxIterations: 10}} + if got := inv.RunState.RunState(); got.Iteration != 3 || got.MaxIterations != 10 { + t.Fatalf("unexpected RunState: %+v", got) + } +} diff --git a/tool/session_tools.go b/tool/session_tools.go new file mode 100644 index 0000000..7bcb605 --- /dev/null +++ b/tool/session_tools.go @@ -0,0 +1,99 @@ +package tool + +import ( + "context" + + llm "gitea.stevedudenhoeffer.com/steve/majordomo/llm" +) + +// AgentSession is the live-run handle a SessionToolFactory receives. +// It is implemented by the executors (agentexec / skillexec / scaddy's +// adapter) on top of majordomo's agent loop and exposes the one mid-run +// mutation session tools need: feeding content back into the running +// conversation. +// +// Why an interface (vs the concrete agent type): legacy agentkit handed the +// factory a *agentkit.Agent so tools could call agent.AttachImages. +// majordomo's *agent.Agent is deliberately immutable mid-run — message +// injection happens through the run-scoped steer mailbox +// (agent.WithSteer). A narrow interface lets each executor implement +// AttachImages over its own steer queue without skilltools importing +// the agent package, and keeps session tools testable with a two-line +// fake. +type AgentSession interface { + // AttachImages queues a user-role message (text plus image parts) + // for injection into the conversation before the agent's next + // step. Used by tools that produce visual feedback the model must + // see on its following turn (e.g. scaddy's rendered OpenSCAD + // previews). Safe to call from inside a tool handler; the message + // lands after the current step's tool results. + AttachImages(text string, images ...llm.ImagePart) +} + +// SessionToolFactory builds per-session tools that close over the live +// agent session. Called by the executor after the agent is constructed +// but before it runs. See Invocation.SessionToolFactory for the +// rationale (static ExtraTools cannot reach the running agent). +type SessionToolFactory func(session AgentSession) SessionTools + +// SessionTools carries per-session tools plus optional post-run and +// teardown hooks. It replaces legacy agentkit's SessionTools with the same +// three-field shape, re-based on majordomo types. +type SessionTools struct { + // Tools to add to the agent's toolbox for this run only. + Tools []llm.Tool + + // PostRun, if set, is called after the agent run completes + // (successfully or not). It receives the full run transcript (the + // agent Result's Messages — also populated on partial results from + // agent.ErrMaxSteps / agent.ErrToolLoop), the agent's text output, + // and the run error, so the hook can decide whether to attempt + // artifact production on partial success (e.g. scaddy ships the + // latest SCAD even when the step budget ran out). The returned + // PostRunResult is attached to the executor's run result. Errors + // inside PostRun must be handled by the hook itself — the executor + // logs a nil return but never fails the run over it; the agent's + // output is the source of truth. + // + // Why a transcript slice (vs the live agent): the consumers only + // ever read the message history (thought-chain transcripts); the + // majordomo agent exposes that on Result, not on the Agent. + PostRun func(ctx context.Context, transcript []llm.Message, output string, runErr error) *PostRunResult + + // Cleanup, if set, is deferred by the executor immediately after + // the factory returns. Called even if the run fails or PostRun + // panics. Use for temp directory removal, closing file handles, + // etc. + Cleanup func() +} + +// PostRunResult carries artifacts produced by the PostRun hook. +// Attached to the executor's run result so callers (Discord command +// handlers, HTTP API responses) can inspect and deliver the artifacts. +// +// Why a separate struct (vs returning artifacts inline): post- +// processing may produce multiple typed artifacts (PNGs, STLs, SCAD +// source) that the delivery layer classifies and routes differently. +// A flat []Artifact + arbitrary Metadata covers the known use cases +// without over-specifying the shape. +type PostRunResult struct { + // Artifacts are files produced during post-processing + // (e.g., rendered PNGs, STL files, SCAD source). + Artifacts []Artifact + + // Metadata is arbitrary key-value data the delivery layer can + // use for formatting (e.g., iteration count, model name, notes). + Metadata map[string]any +} + +// Artifact is a named binary blob produced by post-run processing. +// +// Why: the delivery layer needs name + type + bytes to classify +// each artifact (PNG → embed image, STL → filetransfer upload, +// SCAD → paste upload). A struct with these three fields is the +// minimal viable description. +type Artifact struct { + Name string // e.g., "model.stl", "preview_iso.png" + MimeType string // e.g., "model/stl", "image/png" + Data []byte +} diff --git a/tool/ssrf_protect.go b/tool/ssrf_protect.go new file mode 100644 index 0000000..182254d --- /dev/null +++ b/tool/ssrf_protect.go @@ -0,0 +1,221 @@ +// Package skilltools — SSRF protection layer for skill HTTP tools. +// +// Why a dedicated layer (vs reusing pkg/utils.ValidateExternalURL): +// the platform's HTTP tools enforce a per-deployment ALLOWLIST (not +// just a "no private IPs" denylist) — admins must explicitly opt-in +// to each domain a skill may call. Additionally, defeating DNS +// rebinding requires capturing the resolved IP at validation time +// and pinning the dialler so a hostile DNS resolver can't return a +// public IP during the check and a private one at dial time. +package tool + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// AllowlistConfig governs which hosts a skill HTTP tool may contact. +// +// Why a config struct (vs raw []string): forward-compatibility — we +// expect to add per-tool overrides (e.g. "this skill may also reach +// internal.example.com") and an explicit `AllowLoopback` opt-in for +// development environments. Keeping the validation surface as a +// struct lets new fields land without breaking call sites. +type AllowlistConfig struct { + // Domains is the list of allowed hostnames. Wildcards: "*.example.com" + // matches "foo.example.com" and "bar.baz.example.com" but NOT + // "example.com" itself (to allow both, list both entries). + // + // Comparison is case-insensitive; trailing dots are NOT trimmed + // (DNS treats "example.com" and "example.com." as different). + Domains []string +} + +// ResolveAndCheck validates urlStr against the allowlist and returns +// the resolved IP. The IP is meant to be passed to the transport's +// dial step (via PinnedDialTransport) to defeat DNS rebinding. +// +// Loopback / private / link-local rejection is bypassed when the +// HOSTNAME (not the resolved IP) is itself an entry in the allowlist +// OR the resolved IP literal appears in the allowlist. This lets an +// admin opt-in to "127.0.0.1" or "localhost" for tests / debug +// without a global allow-private flag, while keeping the default +// (random hostname → resolved private IP) safe. +// +// Returns: +// - resolvedIP if the URL is acceptable +// - error explaining the rejection (host not allowlisted, scheme +// unsupported, resolves to private IP, etc.) +func ResolveAndCheck(ctx context.Context, urlStr string, allow AllowlistConfig) (net.IP, error) { + u, err := url.Parse(urlStr) + if err != nil { + return nil, fmt.Errorf("parse url: %w", err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return nil, fmt.Errorf("scheme %q not supported (need http or https)", u.Scheme) + } + host := u.Hostname() + if host == "" { + return nil, fmt.Errorf("url has no host") + } + + if !matchesAllowlist(host, allow.Domains) { + return nil, fmt.Errorf("host %q not in allowlist", host) + } + + // If the host is already a literal IP, skip the resolve step. + if literal := net.ParseIP(host); literal != nil { + // Even an explicitly allowlisted IP literal goes through the + // privacy check UNLESS the literal is itself in the allowlist + // (covers admin opt-in "127.0.0.1" for tests). + if hostExplicitlyAllowed(host, allow.Domains) { + return literal, nil + } + if err := rejectPrivateIP(host, literal); err != nil { + return nil, err + } + return literal, nil + } + + // Resolve. context controls timeout. + resolver := &net.Resolver{} + addrs, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("resolve %q: %w", host, err) + } + if len(addrs) == 0 { + return nil, fmt.Errorf("resolve %q: no addresses", host) + } + + ip := addrs[0].IP + + // Hostname explicitly in allowlist (e.g. "localhost" → opt-in by + // admin) bypasses the private-IP check. The wildcard form does NOT + // bypass — wildcards are for public domain families, not for + // private space. + if hostExplicitlyAllowed(host, allow.Domains) { + return ip, nil + } + + if err := rejectPrivateIP(host, ip); err != nil { + return nil, err + } + return ip, nil +} + +// rejectPrivateIP returns an error if the IP is loopback / private / +// link-local / unspecified, formatted with the original hostname so +// the rejection message is informative. +// +// Why a helper: ResolveAndCheck calls it twice (literal-IP path and +// resolved-host path) and the same checks apply. +func rejectPrivateIP(host string, ip net.IP) error { + // Cloud metadata endpoint check FIRST — it's a link-local IP, so + // the more-specific metadata error message would otherwise be + // shadowed by the link-local rejection. + if ip.Equal(net.ParseIP("169.254.169.254")) { + return fmt.Errorf("host %q resolves to cloud metadata IP %v", host, ip) + } + if ip.IsLoopback() { + return fmt.Errorf("host %q resolves to loopback %v", host, ip) + } + if ip.IsPrivate() { + return fmt.Errorf("host %q resolves to private IP %v", host, ip) + } + if ip.IsLinkLocalUnicast() { + return fmt.Errorf("host %q resolves to link-local %v", host, ip) + } + if ip.IsUnspecified() { + return fmt.Errorf("host %q resolves to unspecified %v", host, ip) + } + return nil +} + +// hostExplicitlyAllowed reports whether host is in the allowlist as +// an exact entry (NOT via a wildcard). Used to bypass the private-IP +// check when an admin has explicitly named a host (e.g. "127.0.0.1" +// or "localhost") to opt-in. +func hostExplicitlyAllowed(host string, allow []string) bool { + host = strings.ToLower(host) + for _, pattern := range allow { + pattern = strings.ToLower(strings.TrimSpace(pattern)) + if pattern == host { + return true + } + } + return false +} + +// matchesAllowlist reports whether host matches any entry in allow, +// either by exact match, by "*.example.com" wildcard, or by the +// special bare "*" wildcard (allow every host). +// +// Wildcards match one-or-more subdomain levels: "*.example.com" +// matches "foo.example.com" and "a.b.example.com" but NOT +// "example.com" itself. +// +// Bare "*" matches any host. **Operators should use this only when +// they understand the SSRF + iptables layers still defend against +// private-IP traffic** (ResolveAndCheck blocks loopback / RFC1918 / +// link-local UNLESS the IP literal is also in the allowlist; the v15 +// codeexec firewall sidecar adds host-level iptables drops). The +// bare-"*" form is the v15.1 operator UX answer to "I just want to +// let the agent reach the public internet" — without it, operators +// had to enumerate TLDs (*.com, *.org, *.io, etc.) which never +// covered the long tail. +func matchesAllowlist(host string, allow []string) bool { + host = strings.ToLower(host) + for _, pattern := range allow { + pattern = strings.ToLower(strings.TrimSpace(pattern)) + if pattern == "" { + continue + } + // Bare "*" = allow-any. The SSRF + iptables layers still + // enforce private-IP blocks; this only opens the hostname gate. + if pattern == "*" { + return true + } + if pattern == host { + return true + } + if strings.HasPrefix(pattern, "*.") { + suffix := pattern[1:] // ".example.com" + if strings.HasSuffix(host, suffix) && len(host) > len(suffix) { + return true + } + } + } + return false +} + +// PinnedDialTransport returns an http.RoundTripper that uses the given +// IP for all Dial operations regardless of host (defeats DNS rebinding). +// The Host header is preserved from the request — TLS SNI and HTTP +// Host routing continue to work, only the network connection is +// pinned to the pre-validated IP. +// +// Why pre-validated dial vs trusting the request: between the +// ResolveAndCheck call and the http.Client.Do call, a hostile DNS +// server can return a different IP. Pinning the dialler ensures the +// connection lands on the exact address that passed the privacy +// check. +func PinnedDialTransport(ip net.IP, timeout time.Duration) http.RoundTripper { + dialer := &net.Dialer{Timeout: timeout} + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // addr is "host:port" — replace host with pinned IP. + _, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + }, + ResponseHeaderTimeout: timeout, + TLSHandshakeTimeout: timeout, + } +} diff --git a/tool/ssrf_protect_test.go b/tool/ssrf_protect_test.go new file mode 100644 index 0000000..6dbca62 --- /dev/null +++ b/tool/ssrf_protect_test.go @@ -0,0 +1,168 @@ +package tool + +import ( + "context" + "net" + "strings" + "testing" +) + +// TestResolveAndCheck_AllowlistedPublic anchors the happy path: a +// public domain in the allowlist resolves and returns its IP. Skips +// when DNS isn't available so the suite still passes in offline CI. +func TestResolveAndCheck_AllowlistedPublic(t *testing.T) { + // Pre-flight: skip the test if the test environment has no DNS. + if _, err := net.LookupHost("example.com"); err != nil { + t.Skipf("no DNS in test environment: %v", err) + } + allow := AllowlistConfig{Domains: []string{"example.com"}} + ip, err := ResolveAndCheck(context.Background(), "https://example.com/", allow) + if err != nil { + t.Fatalf("ResolveAndCheck failed: %v", err) + } + if ip == nil { + t.Fatal("expected non-nil IP") + } +} + +// TestResolveAndCheck_NotAllowlisted ensures a domain outside the +// allowlist is rejected before any DNS resolution. +func TestResolveAndCheck_NotAllowlisted(t *testing.T) { + allow := AllowlistConfig{Domains: []string{"example.com"}} + _, err := ResolveAndCheck(context.Background(), "https://evil.test/", allow) + if err == nil { + t.Fatal("expected rejection for non-allowlisted host") + } + if !strings.Contains(err.Error(), "not in allowlist") { + t.Errorf("expected allowlist error, got: %v", err) + } +} + +// TestResolveAndCheck_WildcardMatch confirms "*.example.com" matches +// foo.example.com. +func TestResolveAndCheck_WildcardMatch(t *testing.T) { + if !matchesAllowlist("foo.example.com", []string{"*.example.com"}) { + t.Error("expected *.example.com to match foo.example.com") + } + if !matchesAllowlist("a.b.example.com", []string{"*.example.com"}) { + t.Error("expected *.example.com to match a.b.example.com") + } +} + +// TestResolveAndCheck_WildcardDoesNotMatchBareDomain documents the +// design choice: "*.example.com" does NOT match "example.com" itself. +// Admins who want both must list both entries. +func TestResolveAndCheck_WildcardDoesNotMatchBareDomain(t *testing.T) { + if matchesAllowlist("example.com", []string{"*.example.com"}) { + t.Error("expected *.example.com NOT to match bare example.com") + } +} + +// TestResolveAndCheck_LocalhostRejected verifies that a hostname +// resolving to 127.0.0.1 is rejected unless the admin explicitly +// includes it in the allowlist. +func TestResolveAndCheck_LocalhostRejected(t *testing.T) { + if _, err := net.LookupHost("localhost"); err != nil { + t.Skipf("no DNS in test environment: %v", err) + } + // "localhost" matches the allowlist by exact name match, but the + // hostExplicitlyAllowed bypass kicks in only when the host is in + // the allowlist as an exact entry. Here we use a DIFFERENT bare + // allowlist entry so the host fails the allowlist match outright. + allow := AllowlistConfig{Domains: []string{"example.com"}} + _, err := ResolveAndCheck(context.Background(), "http://localhost/", allow) + if err == nil { + t.Fatal("expected rejection for localhost (not in allowlist)") + } +} + +// TestResolveAndCheck_LocalhostAllowedExplicit confirms the +// admin-opt-in escape hatch: when the hostname is itself in the +// allowlist as an exact entry, the private-IP check is bypassed. +// This is what test code uses to drive httptest.NewServer URLs. +func TestResolveAndCheck_LocalhostAllowedExplicit(t *testing.T) { + // Use 127.0.0.1 directly so this test doesn't depend on DNS for + // "localhost". + allow := AllowlistConfig{Domains: []string{"127.0.0.1"}} + ip, err := ResolveAndCheck(context.Background(), "http://127.0.0.1/", allow) + if err != nil { + t.Fatalf("expected 127.0.0.1 with explicit allowlist to succeed; got: %v", err) + } + if !ip.Equal(net.ParseIP("127.0.0.1")) { + t.Errorf("expected ip=127.0.0.1, got %v", ip) + } +} + +// TestResolveAndCheck_FileSchemeRejected blocks file:// URLs. +func TestResolveAndCheck_FileSchemeRejected(t *testing.T) { + allow := AllowlistConfig{Domains: []string{"anything"}} + _, err := ResolveAndCheck(context.Background(), "file:///etc/passwd", allow) + if err == nil { + t.Fatal("expected rejection for file:// scheme") + } + if !strings.Contains(err.Error(), "scheme") { + t.Errorf("expected scheme error, got: %v", err) + } +} + +// TestResolveAndCheck_EmptyHostRejected blocks malformed URLs with +// no host component. +func TestResolveAndCheck_EmptyHostRejected(t *testing.T) { + allow := AllowlistConfig{Domains: []string{"anything"}} + _, err := ResolveAndCheck(context.Background(), "http:///nohost", allow) + if err == nil { + t.Fatal("expected rejection for empty host") + } +} + +// TestResolveAndCheck_PrivateIPLiteralRejected confirms that an IP +// literal resolving to the private range is rejected even if the +// allowlist matches by wildcard or other means. The private-IP gate +// is the last line of defence. +func TestResolveAndCheck_PrivateIPLiteralRejected(t *testing.T) { + // Add a wildcard that would match anything (silly but plausible + // admin error) and confirm a private IP literal is still blocked + // because the literal isn't itself in the allowlist as exact. + allow := AllowlistConfig{Domains: []string{"192.168.1.1"}} + // The exact-IP-in-allowlist case bypasses the private check; flip + // to a NEAR-but-different IP literal that's NOT in the allowlist. + allow2 := AllowlistConfig{Domains: []string{"192.168.0.0"}} + _, err := ResolveAndCheck(context.Background(), "http://192.168.1.1/", allow2) + if err == nil { + t.Fatal("expected rejection for private IP literal not in allowlist") + } + // Sanity: explicit allowlist entry bypasses. + _, err = ResolveAndCheck(context.Background(), "http://192.168.1.1/", allow) + if err != nil { + t.Errorf("expected explicit allowlist entry to bypass; got: %v", err) + } +} + +// TestResolveAndCheck_CloudMetadataRejected blocks the well-known +// cloud metadata IP via the link-local check. We use a wildcard that +// matches the IP-as-hostname so the rejection comes from the +// private/link-local layer (not the allowlist). +func TestResolveAndCheck_CloudMetadataRejected(t *testing.T) { + // "*.169.254.169.254" wildcard wouldn't match either; instead use + // a wildcard that matches any IP literal under .254 — but + // matchesAllowlist treats '.' as a literal so we just allowlist + // the IP itself with a one-bit-different sibling that fails the + // exact-allow check (so private check still runs). + // + // Easier: include a different exact IP entry so the IP literal + // fails hostExplicitlyAllowed but passes the wildcard. + allow := AllowlistConfig{Domains: []string{"*.169.254.169.254"}} // matches "x.169.254.169.254", not the bare IP + // 169.254.169.254 won't match the wildcard pattern either — + // switch to a strategy that lets the host pass allowlist but + // fails the private check. + _ = allow + // Use an explicit non-IP-literal hostname (we'd need DNS to point + // to 169.254.169.254 which is not feasible). Instead, exercise + // the rejectPrivateIP helper directly for the metadata IP since + // the public surface only enters that path through resolution. + if err := rejectPrivateIP("metadata.test", net.ParseIP("169.254.169.254")); err == nil { + t.Fatal("expected rejection for cloud metadata IP") + } else if !strings.Contains(err.Error(), "metadata") { + t.Errorf("expected metadata error, got: %v", err) + } +} diff --git a/tool/webhook_rate_limit.go b/tool/webhook_rate_limit.go new file mode 100644 index 0000000..af82e14 --- /dev/null +++ b/tool/webhook_rate_limit.go @@ -0,0 +1,145 @@ +// Package skilltools — webhook_rate_limit.go: per-IP-per-skill +// sliding-window rate limiter for the v7 inbound webhook handler. +// +// Why an in-memory limiter (vs Redis or DB-backed): rate limiting is +// the cheap reject path BEFORE the HMAC compute and run-budget check, +// and an extra round-trip per inbound webhook would be wasted. The +// 6-person server's volume is well within a single-process limiter's +// scale; if mort is ever multi-process the limiter becomes +// approximate (still good enough to throttle abusive sources). +// +// Why per-IP-per-skill (vs per-IP global): one busy webhook (e.g. +// GitHub PR opened) shouldn't shadow another (Stripe charge). The +// composite key keeps a noisy source from pushing other skill's +// callers off the lane. +// +// Test: webhook_rate_limit_test.go covers admit + reject paths. +package tool + +import ( + "sync" + "time" +) + +// WebhookRateLimiter is a sliding-window per-(skillID, sourceIP) +// counter. Configure once at construction; concurrent-safe. +type WebhookRateLimiter struct { + limit int + window time.Duration + clock func() time.Time + + mu sync.Mutex + buckets map[string]*rateBucket // key = skillID + "|" + sourceIP +} + +type rateBucket struct { + // hits is a slice of timestamps within the window. Pruned on + // every Admit call so the slice never grows unbounded. + hits []time.Time +} + +// NewWebhookRateLimiter constructs the limiter. +// +// limit — max calls per (skill, ip) within window. <=0 means +// +// "unlimited" (every call admitted; useful for tests). +// +// window — sliding window length. <=0 falls back to 1 minute. +// clock — testable wall-clock; nil → time.Now. +func NewWebhookRateLimiter(limit int, window time.Duration, clock func() time.Time) *WebhookRateLimiter { + if window <= 0 { + window = time.Minute + } + if clock == nil { + clock = time.Now + } + return &WebhookRateLimiter{ + limit: limit, + window: window, + clock: clock, + buckets: make(map[string]*rateBucket), + } +} + +// Admit returns (true, 0) if the call is within the rate cap (records +// the hit), or (false, retry-after) if the cap is hit. retry-after is +// the time until the OLDEST hit in the window expires — the caller can +// surface it via the Retry-After response header. +// +// Why return retry-after not just bool: HTTP 429 responses commonly +// include Retry-After to avoid synchronizing client retries; computing +// it from the sliding window is essentially free. +func (l *WebhookRateLimiter) Admit(skillID, sourceIP string) (bool, time.Duration) { + if l.limit <= 0 { + return true, 0 + } + now := l.clock() + cutoff := now.Add(-l.window) + key := skillID + "|" + sourceIP + + l.mu.Lock() + defer l.mu.Unlock() + + b, ok := l.buckets[key] + if !ok { + b = &rateBucket{} + l.buckets[key] = b + } + // Prune in place. The slice is append-only at the tail; the head + // shrinks as old hits fall out of the window. + first := 0 + for first < len(b.hits) && b.hits[first].Before(cutoff) { + first++ + } + if first > 0 { + // Copy the surviving tail to the head; reuse backing array. + n := copy(b.hits, b.hits[first:]) + b.hits = b.hits[:n] + } + if len(b.hits) >= l.limit { + oldest := b.hits[0] + retryAfter := oldest.Add(l.window).Sub(now) + if retryAfter < 0 { + retryAfter = 0 + } + return false, retryAfter + } + b.hits = append(b.hits, now) + return true, 0 +} + +// Sweep purges buckets whose hit-list is empty after pruning. Called +// periodically (e.g. once per minute) to bound the buckets map's +// growth. +// +// Why a separate Sweep vs auto-prune in Admit: a hostile source that +// rotates IP addresses across many addresses each hitting once +// would leave millions of single-hit buckets in the map. A periodic +// sweep keeps the worst case bounded. +func (l *WebhookRateLimiter) Sweep() { + now := l.clock() + cutoff := now.Add(-l.window) + l.mu.Lock() + defer l.mu.Unlock() + for k, b := range l.buckets { + // Prune in place. + first := 0 + for first < len(b.hits) && b.hits[first].Before(cutoff) { + first++ + } + if first > 0 { + n := copy(b.hits, b.hits[first:]) + b.hits = b.hits[:n] + } + if len(b.hits) == 0 { + delete(l.buckets, k) + } + } +} + +// CountKeys returns the bucket count. Test helper. +func (l *WebhookRateLimiter) CountKeys() int { + l.mu.Lock() + defer l.mu.Unlock() + return len(l.buckets) +}