Add structured output support with Generate[T] and GenerateWith[T]
Generic functions that use the "hidden tool" technique to force models to return structured JSON matching a Go struct's schema, replacing the verbose "tool as structured output" pattern. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -14,4 +14,7 @@ var (
|
|||||||
|
|
||||||
// ErrStreamClosed is returned when trying to read from a closed stream.
|
// ErrStreamClosed is returned when trying to read from a closed stream.
|
||||||
ErrStreamClosed = errors.New("stream closed")
|
ErrStreamClosed = errors.New("stream closed")
|
||||||
|
|
||||||
|
// ErrNoStructuredOutput is returned when the model did not return a structured output tool call.
|
||||||
|
ErrNoStructuredOutput = errors.New("model did not return structured output")
|
||||||
)
|
)
|
||||||
|
|||||||
54
v2/generate.go
Normal file
54
v2/generate.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
const structuredOutputToolName = "structured_output"
|
||||||
|
|
||||||
|
// Generate sends a single user prompt to the model and parses the response into T.
|
||||||
|
// T must be a struct. The model is forced to return structured output matching T's schema
|
||||||
|
// by using a hidden tool call internally.
|
||||||
|
func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, error) {
|
||||||
|
return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateWith sends the given messages to the model and parses the response into T.
|
||||||
|
// T must be a struct. The model is forced to return structured output matching T's schema
|
||||||
|
// by using a hidden tool call internally.
|
||||||
|
func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, error) {
|
||||||
|
var zero T
|
||||||
|
|
||||||
|
s := schema.FromStruct(zero)
|
||||||
|
|
||||||
|
tool := Tool{
|
||||||
|
Name: structuredOutputToolName,
|
||||||
|
Description: "Return your response as structured data using this function. You MUST call this function with your response.",
|
||||||
|
Schema: s,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append WithTools as the last option so it overrides any user-provided tools.
|
||||||
|
opts = append(opts, WithTools(NewToolBox(tool)))
|
||||||
|
|
||||||
|
resp, err := model.Complete(ctx, messages, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the structured_output tool call in the response.
|
||||||
|
for _, tc := range resp.ToolCalls {
|
||||||
|
if tc.Name == structuredOutputToolName {
|
||||||
|
var result T
|
||||||
|
if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil {
|
||||||
|
return zero, fmt.Errorf("failed to parse structured output: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return zero, ErrNoStructuredOutput
|
||||||
|
}
|
||||||
241
v2/generate_test.go
Normal file
241
v2/generate_test.go
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testPerson struct {
|
||||||
|
Name string `json:"name" description:"The person's name"`
|
||||||
|
Age int `json:"age" description:"The person's age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Alice","age":30}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != "Alice" {
|
||||||
|
t.Errorf("expected name 'Alice', got %q", result.Name)
|
||||||
|
}
|
||||||
|
if result.Age != 30 {
|
||||||
|
t.Errorf("expected age 30, got %d", result.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the tool was sent in the request
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if len(req.Tools) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool, got %d", len(req.Tools))
|
||||||
|
}
|
||||||
|
if req.Tools[0].Name != "structured_output" {
|
||||||
|
t.Errorf("expected tool name 'structured_output', got %q", req.Tools[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateWith(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Bob","age":25}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
messages := []Message{
|
||||||
|
SystemMessage("You are helpful."),
|
||||||
|
UserMessage("Tell me about Bob"),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := GenerateWith[testPerson](context.Background(), model, messages)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != "Bob" {
|
||||||
|
t.Errorf("expected name 'Bob', got %q", result.Name)
|
||||||
|
}
|
||||||
|
if result.Age != 25 {
|
||||||
|
t.Errorf("expected age 25, got %d", result.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify messages were passed through
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if len(req.Messages) != 2 {
|
||||||
|
t.Fatalf("expected 2 messages, got %d", len(req.Messages))
|
||||||
|
}
|
||||||
|
if req.Messages[0].Role != "system" {
|
||||||
|
t.Errorf("expected first message role 'system', got %q", req.Messages[0].Role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_NoToolCall(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
Text: "I can't use tools right now.",
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrNoStructuredOutput) {
|
||||||
|
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_InvalidJSON(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{not valid json}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := Generate[testPerson](context.Background(), model, "Tell me about someone")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrNoStructuredOutput) {
|
||||||
|
t.Error("expected parse error, not ErrNoStructuredOutput")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testAddress struct {
|
||||||
|
Street string `json:"street" description:"Street address"`
|
||||||
|
City string `json:"city" description:"City name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type testPersonWithAddress struct {
|
||||||
|
Name string `json:"name" description:"The person's name"`
|
||||||
|
Age int `json:"age" description:"The person's age"`
|
||||||
|
Address testAddress `json:"address" description:"The person's address"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_NestedStruct(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Carol","age":40,"address":{"street":"123 Main St","city":"Springfield"}}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != "Carol" {
|
||||||
|
t.Errorf("expected name 'Carol', got %q", result.Name)
|
||||||
|
}
|
||||||
|
if result.Address.Street != "123 Main St" {
|
||||||
|
t.Errorf("expected street '123 Main St', got %q", result.Address.Street)
|
||||||
|
}
|
||||||
|
if result.Address.City != "Springfield" {
|
||||||
|
t.Errorf("expected city 'Springfield', got %q", result.Address.City)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_WithOptions(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Dave","age":35}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := Generate[testPerson](context.Background(), model, "Tell me about Dave",
|
||||||
|
WithTemperature(0.5),
|
||||||
|
WithMaxTokens(200),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := mp.lastRequest()
|
||||||
|
if req.Temperature == nil || *req.Temperature != 0.5 {
|
||||||
|
t.Errorf("expected temperature 0.5, got %v", req.Temperature)
|
||||||
|
}
|
||||||
|
if req.MaxTokens == nil || *req.MaxTokens != 200 {
|
||||||
|
t.Errorf("expected maxTokens 200, got %v", req.MaxTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_WithMiddleware(t *testing.T) {
|
||||||
|
var middlewareCalled bool
|
||||||
|
mw := func(next CompletionFunc) CompletionFunc {
|
||||||
|
return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) {
|
||||||
|
middlewareCalled = true
|
||||||
|
return next(ctx, model, messages, cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "structured_output",
|
||||||
|
Arguments: `{"name":"Eve","age":28}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp).WithMiddleware(mw)
|
||||||
|
|
||||||
|
result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !middlewareCalled {
|
||||||
|
t.Error("middleware was not called")
|
||||||
|
}
|
||||||
|
if result.Name != "Eve" {
|
||||||
|
t.Errorf("expected name 'Eve', got %q", result.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate_WrongToolName(t *testing.T) {
|
||||||
|
mp := newMockProvider(provider.Response{
|
||||||
|
ToolCalls: []provider.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "some_other_tool",
|
||||||
|
Arguments: `{"name":"Frank","age":50}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
model := newMockModel(mp)
|
||||||
|
|
||||||
|
_, err := Generate[testPerson](context.Background(), model, "Tell me about Frank")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrNoStructuredOutput) {
|
||||||
|
t.Errorf("expected ErrNoStructuredOutput, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user