diff --git a/v2/errors.go b/v2/errors.go index 5082608..52db82c 100644 --- a/v2/errors.go +++ b/v2/errors.go @@ -14,4 +14,7 @@ var ( // ErrStreamClosed is returned when trying to read from a closed stream. ErrStreamClosed = errors.New("stream closed") + + // ErrNoStructuredOutput is returned when the model did not return a structured output tool call. + ErrNoStructuredOutput = errors.New("model did not return structured output") ) diff --git a/v2/generate.go b/v2/generate.go new file mode 100644 index 0000000..c6af254 --- /dev/null +++ b/v2/generate.go @@ -0,0 +1,54 @@ +package llm + +import ( + "context" + "encoding/json" + "fmt" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/internal/schema" +) + +const structuredOutputToolName = "structured_output" + +// Generate sends a single user prompt to the model and parses the response into T. +// T must be a struct. The model is forced to return structured output matching T's schema +// by using a hidden tool call internally. +func Generate[T any](ctx context.Context, model *Model, prompt string, opts ...RequestOption) (T, error) { + return GenerateWith[T](ctx, model, []Message{UserMessage(prompt)}, opts...) +} + +// GenerateWith sends the given messages to the model and parses the response into T. +// T must be a struct. The model is forced to return structured output matching T's schema +// by using a hidden tool call internally. +func GenerateWith[T any](ctx context.Context, model *Model, messages []Message, opts ...RequestOption) (T, error) { + var zero T + + s := schema.FromStruct(zero) + + tool := Tool{ + Name: structuredOutputToolName, + Description: "Return your response as structured data using this function. You MUST call this function with your response.", + Schema: s, + } + + // Append WithTools as the last option so it overrides any user-provided tools. + opts = append(opts, WithTools(NewToolBox(tool))) + + resp, err := model.Complete(ctx, messages, opts...) + if err != nil { + return zero, err + } + + // Find the structured_output tool call in the response. + for _, tc := range resp.ToolCalls { + if tc.Name == structuredOutputToolName { + var result T + if err := json.Unmarshal([]byte(tc.Arguments), &result); err != nil { + return zero, fmt.Errorf("failed to parse structured output: %w", err) + } + return result, nil + } + } + + return zero, ErrNoStructuredOutput +} diff --git a/v2/generate_test.go b/v2/generate_test.go new file mode 100644 index 0000000..7c4de1d --- /dev/null +++ b/v2/generate_test.go @@ -0,0 +1,241 @@ +package llm + +import ( + "context" + "errors" + "testing" + + "gitea.stevedudenhoeffer.com/steve/go-llm/v2/provider" +) + +type testPerson struct { + Name string `json:"name" description:"The person's name"` + Age int `json:"age" description:"The person's age"` +} + +func TestGenerate(t *testing.T) { + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + { + ID: "call_1", + Name: "structured_output", + Arguments: `{"name":"Alice","age":30}`, + }, + }, + }) + model := newMockModel(mp) + + result, err := Generate[testPerson](context.Background(), model, "Tell me about Alice") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Name != "Alice" { + t.Errorf("expected name 'Alice', got %q", result.Name) + } + if result.Age != 30 { + t.Errorf("expected age 30, got %d", result.Age) + } + + // Verify the tool was sent in the request + req := mp.lastRequest() + if len(req.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(req.Tools)) + } + if req.Tools[0].Name != "structured_output" { + t.Errorf("expected tool name 'structured_output', got %q", req.Tools[0].Name) + } +} + +func TestGenerateWith(t *testing.T) { + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + { + ID: "call_1", + Name: "structured_output", + Arguments: `{"name":"Bob","age":25}`, + }, + }, + }) + model := newMockModel(mp) + + messages := []Message{ + SystemMessage("You are helpful."), + UserMessage("Tell me about Bob"), + } + + result, err := GenerateWith[testPerson](context.Background(), model, messages) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Name != "Bob" { + t.Errorf("expected name 'Bob', got %q", result.Name) + } + if result.Age != 25 { + t.Errorf("expected age 25, got %d", result.Age) + } + + // Verify messages were passed through + req := mp.lastRequest() + if len(req.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(req.Messages)) + } + if req.Messages[0].Role != "system" { + t.Errorf("expected first message role 'system', got %q", req.Messages[0].Role) + } +} + +func TestGenerate_NoToolCall(t *testing.T) { + mp := newMockProvider(provider.Response{ + Text: "I can't use tools right now.", + }) + model := newMockModel(mp) + + _, err := Generate[testPerson](context.Background(), model, "Tell me about someone") + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, ErrNoStructuredOutput) { + t.Errorf("expected ErrNoStructuredOutput, got %v", err) + } +} + +func TestGenerate_InvalidJSON(t *testing.T) { + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + { + ID: "call_1", + Name: "structured_output", + Arguments: `{not valid json}`, + }, + }, + }) + model := newMockModel(mp) + + _, err := Generate[testPerson](context.Background(), model, "Tell me about someone") + if err == nil { + t.Fatal("expected error, got nil") + } + if errors.Is(err, ErrNoStructuredOutput) { + t.Error("expected parse error, not ErrNoStructuredOutput") + } +} + +type testAddress struct { + Street string `json:"street" description:"Street address"` + City string `json:"city" description:"City name"` +} + +type testPersonWithAddress struct { + Name string `json:"name" description:"The person's name"` + Age int `json:"age" description:"The person's age"` + Address testAddress `json:"address" description:"The person's address"` +} + +func TestGenerate_NestedStruct(t *testing.T) { + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + { + ID: "call_1", + Name: "structured_output", + Arguments: `{"name":"Carol","age":40,"address":{"street":"123 Main St","city":"Springfield"}}`, + }, + }, + }) + model := newMockModel(mp) + + result, err := Generate[testPersonWithAddress](context.Background(), model, "Tell me about Carol") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Name != "Carol" { + t.Errorf("expected name 'Carol', got %q", result.Name) + } + if result.Address.Street != "123 Main St" { + t.Errorf("expected street '123 Main St', got %q", result.Address.Street) + } + if result.Address.City != "Springfield" { + t.Errorf("expected city 'Springfield', got %q", result.Address.City) + } +} + +func TestGenerate_WithOptions(t *testing.T) { + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + { + ID: "call_1", + Name: "structured_output", + Arguments: `{"name":"Dave","age":35}`, + }, + }, + }) + model := newMockModel(mp) + + _, err := Generate[testPerson](context.Background(), model, "Tell me about Dave", + WithTemperature(0.5), + WithMaxTokens(200), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + req := mp.lastRequest() + if req.Temperature == nil || *req.Temperature != 0.5 { + t.Errorf("expected temperature 0.5, got %v", req.Temperature) + } + if req.MaxTokens == nil || *req.MaxTokens != 200 { + t.Errorf("expected maxTokens 200, got %v", req.MaxTokens) + } +} + +func TestGenerate_WithMiddleware(t *testing.T) { + var middlewareCalled bool + mw := func(next CompletionFunc) CompletionFunc { + return func(ctx context.Context, model string, messages []Message, cfg *requestConfig) (Response, error) { + middlewareCalled = true + return next(ctx, model, messages, cfg) + } + } + + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + { + ID: "call_1", + Name: "structured_output", + Arguments: `{"name":"Eve","age":28}`, + }, + }, + }) + model := newMockModel(mp).WithMiddleware(mw) + + result, err := Generate[testPerson](context.Background(), model, "Tell me about Eve") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !middlewareCalled { + t.Error("middleware was not called") + } + if result.Name != "Eve" { + t.Errorf("expected name 'Eve', got %q", result.Name) + } +} + +func TestGenerate_WrongToolName(t *testing.T) { + mp := newMockProvider(provider.Response{ + ToolCalls: []provider.ToolCall{ + { + ID: "call_1", + Name: "some_other_tool", + Arguments: `{"name":"Frank","age":50}`, + }, + }, + }) + model := newMockModel(mp) + + _, err := Generate[testPerson](context.Background(), model, "Tell me about Frank") + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, ErrNoStructuredOutput) { + t.Errorf("expected ErrNoStructuredOutput, got %v", err) + } +}