From ee542a7373dd3f5523879c137a5f9e0f54ff26db Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 30 Dec 2024 16:57:17 -0800 Subject: [PATCH] update client sdk tests --- tests/client-sdk/agents/test_agents.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 68ff3089b..1630ef34b 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -10,16 +10,16 @@ from uuid import uuid4 import pytest from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.custom_tool import CustomTool +from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types import ToolResponseMessage from llama_stack_client.types.agent_create_params import AgentConfig -from llama_stack_client.types.custom_tool_def import Parameter from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage +from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter -class TestCustomTool(CustomTool): +class TestClientTool(ClientTool): """Tool to give boiling point of a liquid Returns the correct value for polyjuice in Celcius and Fahrenheit and returns -1 for other liquids @@ -52,15 +52,15 @@ class TestCustomTool(CustomTool): def get_description(self) -> str: return "Get the boiling point of imaginary liquids (eg. polyjuice)" - def get_params_definition(self) -> Dict[str, Parameter]: + def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]: return { - "liquid_name": Parameter( + "liquid_name": UserDefinedToolDefParameter( name="liquid_name", parameter_type="string", description="The name of the liquid", required=True, ), - "celcius": Parameter( + "celcius": UserDefinedToolDefParameter( name="celcius", parameter_type="boolean", description="Whether to return the boiling point in Celcius", @@ -205,16 +205,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): def test_custom_tool(llama_stack_client, agent_config): - custom_tool = TestCustomTool() + client_tool = TestClientTool() agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", "tool_names": ["brave_search"], - "client_tools": [custom_tool.get_tool_definition()], + "client_tools": [client_tool.get_tool_definition()], "tool_prompt_format": "python_list", } - agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,)) + agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn(