package llm

import (
	"context"
	"encoding/json"
	"fmt"
	"testing"
)

// calcParams is the argument struct for the "calc" tool exercised by TestDefine.
// NOTE(review): the description/enum struct tags are presumably what Define reads
// when it builds the tool's JSON schema — confirm against Define.
type calcParams struct {
	A  float64 `json:"a" description:"First number"`
	B  float64 `json:"b" description:"Second number"`
	Op string  `json:"op" description:"Operation" enum:"add,sub,mul,div"`
}

// TestDefine checks that Define records the tool's name, description and an
// object-typed schema, and that Execute decodes JSON arguments into calcParams,
// runs the handler for every supported op, and reports handler errors.
func TestDefine(t *testing.T) {
	tool := Define[calcParams]("calc", "Calculator",
		func(ctx context.Context, p calcParams) (string, error) {
			var result float64
			switch p.Op {
			case "add":
				result = p.A + p.B
			case "sub":
				result = p.A - p.B
			case "mul":
				result = p.A * p.B
			case "div":
				// Fail explicitly: dividing by zero yields ±Inf/NaN, which
				// json.Marshal would reject with a far less descriptive error.
				if p.B == 0 {
					return "", fmt.Errorf("calc: division by zero")
				}
				result = p.A / p.B
			default:
				// Reject unknown ops rather than silently returning 0.
				return "", fmt.Errorf("calc: unknown op %q", p.Op)
			}
			return p.jsonMarshal(result)
		},
	)

	if tool.Name != "calc" {
		t.Errorf("expected name 'calc', got %q", tool.Name)
	}
	if tool.Description != "Calculator" {
		t.Errorf("expected description 'Calculator', got %q", tool.Description)
	}
	if tool.Schema["type"] != "object" {
		t.Errorf("expected schema type=object, got %v", tool.Schema["type"])
	}

	// Exercise every op plus the error paths. Operands are chosen so every
	// expected result is exactly representable and marshals without a fraction.
	tests := []struct {
		name    string
		args    string
		want    string
		wantErr bool
	}{
		{name: "add", args: `{"a": 10, "b": 3, "op": "add"}`, want: "13"},
		{name: "sub", args: `{"a": 12, "b": 4, "op": "sub"}`, want: "8"},
		{name: "mul", args: `{"a": 12, "b": 4, "op": "mul"}`, want: "48"},
		{name: "div", args: `{"a": 12, "b": 4, "op": "div"}`, want: "3"},
		{name: "div by zero", args: `{"a": 1, "b": 0, "op": "div"}`, wantErr: true},
		{name: "unknown op", args: `{"a": 1, "b": 2, "op": "pow"}`, wantErr: true},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := tool.Execute(context.Background(), tc.args)
			if tc.wantErr {
				if err == nil {
					t.Errorf("expected error, got result %q", got)
				}
				return
			}
			if err != nil {
				t.Fatalf("execute failed: %v", err)
			}
			if got != tc.want {
				t.Errorf("expected %q, got %q", tc.want, got)
			}
		})
	}
}

// TestDefineSimple checks that a parameterless tool built with DefineSimple
// returns its handler's result when executed with an empty argument string.
func TestDefineSimple(t *testing.T) {
	tool := DefineSimple("hello", "Say hello",
		func(ctx context.Context) (string, error) {
			return "Hello, world!", nil
		},
	)

	result, err := tool.Execute(context.Background(), "")
	if err != nil {
		t.Fatalf("execute failed: %v", err)
	}
	if result != "Hello, world!" {
		t.Errorf("expected 'Hello, world!', got %q", result)
	}
}

// TestToolBox checks that a ToolBox exposes all registered tools, dispatches
// a ToolCall to the tool with the matching name, and errors on unknown names.
func TestToolBox(t *testing.T) {
	tool1 := DefineSimple("tool1", "Tool 1", func(ctx context.Context) (string, error) {
		return "result1", nil
	})
	tool2 := DefineSimple("tool2", "Tool 2", func(ctx context.Context) (string, error) {
		return "result2", nil
	})

	tb := NewToolBox(tool1, tool2)

	tools := tb.AllTools()
	if len(tools) != 2 {
		t.Errorf("expected 2 tools, got %d", len(tools))
	}

	// Dispatch to each tool by name so a lookup that always returns the first
	// registered tool would be caught.
	for _, tc := range []struct{ name, want string }{
		{"tool1", "result1"},
		{"tool2", "result2"},
	} {
		result, err := tb.Execute(context.Background(), ToolCall{ID: "1", Name: tc.name})
		if err != nil {
			t.Fatalf("execute %s failed: %v", tc.name, err)
		}
		if result != tc.want {
			t.Errorf("expected %q, got %q", tc.want, result)
		}
	}

	// Test not found
	_, err := tb.Execute(context.Background(), ToolCall{ID: "x", Name: "nonexistent"})
	if err == nil {
		t.Error("expected error for nonexistent tool")
	}
}

// TestToolBoxExecuteAll checks that ExecuteAll returns one RoleTool message per
// call, in call order, carrying the call's ID and the tool's result text.
func TestToolBoxExecuteAll(t *testing.T) {
	tb := NewToolBox(
		DefineSimple("t1", "T1", func(ctx context.Context) (string, error) { return "r1", nil }),
		DefineSimple("t2", "T2", func(ctx context.Context) (string, error) { return "r2", nil }),
	)

	calls := []ToolCall{
		{ID: "c1", Name: "t1"},
		{ID: "c2", Name: "t2"},
	}

	msgs, err := tb.ExecuteAll(context.Background(), calls)
	if err != nil {
		t.Fatalf("execute all failed: %v", err)
	}

	want := []struct{ id, text string }{
		{"c1", "r1"},
		{"c2", "r2"},
	}
	if len(msgs) != len(want) {
		t.Fatalf("expected %d messages, got %d", len(want), len(msgs))
	}
	// Check every message, not just the first.
	for i, w := range want {
		m := msgs[i]
		if m.Role != RoleTool {
			t.Errorf("msg %d: expected role=tool, got %v", i, m.Role)
		}
		if m.ToolCallID != w.id {
			t.Errorf("msg %d: expected toolCallID=%s, got %v", i, w.id, m.ToolCallID)
		}
		if m.Content.Text != w.text {
			t.Errorf("msg %d: expected content=%s, got %v", i, w.text, m.Content.Text)
		}
	}
}

// jsonMarshal encodes result as JSON text for the calc tool's handler in
// TestDefine. It does not read the receiver. It returns json.Marshal's error,
// e.g. for ±Inf or NaN.
func (p calcParams) jsonMarshal(result float64) (string, error) {
	b, err := json.Marshal(result)
	return string(b), err
}