From 6abaf4574d84114d330d77e24b441211370ffdea Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 16 Dec 2024 12:04:05 -0800 Subject: [PATCH] test custom tool --- .../agents/get_boiling_point_tool.py | 69 ------------------ tests/client-sdk/agents/test_agents.py | 72 +++++++++++++++++-- 2 files changed, 66 insertions(+), 75 deletions(-) delete mode 100644 tests/client-sdk/agents/get_boiling_point_tool.py diff --git a/tests/client-sdk/agents/get_boiling_point_tool.py b/tests/client-sdk/agents/get_boiling_point_tool.py deleted file mode 100644 index 212e90bfe..000000000 --- a/tests/client-sdk/agents/get_boiling_point_tool.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -import json -from typing import Dict, List - -from llama_stack_client.lib.agents.custom_tool import CustomTool -from llama_stack_client.types import CompletionMessage, ToolResponseMessage -from llama_stack_client.types.tool_param_definition_param import ( - ToolParamDefinitionParam, -) - - -class GetBoilingPointTool(CustomTool): - """Tool to give boiling point of a liquid - Returns the correct value for water in Celcius and Fahrenheit - and returns -1 for other liquids - - """ - - def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: - assert len(messages) == 1, "Expected single message" - - message = messages[0] - - tool_call = message.tool_calls[0] - - try: - response = self.run_impl(**tool_call.arguments) - response_str = json.dumps(response, ensure_ascii=False) - except Exception as e: - response_str = f"Error when running tool: {e}" - - message = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=response_str, - role="ipython", - ) - return [message] - - def get_name(self) -> str: - return "get_boiling_point" - - def get_description(self) -> str: - return "Get the boiling point of a imaginary liquids (eg. polyjuice)" - - def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: - return { - "liquid_name": ToolParamDefinitionParam( - param_type="string", description="The name of the liquid", required=True - ), - "celcius": ToolParamDefinitionParam( - param_type="boolean", - description="Whether to return the boiling point in Celcius", - required=False, - ), - } - - def run_impl(self, liquid_name: str, celcius: bool = True) -> int: - if liquid_name.lower() == "polyjuice": - if celcius: - return -100 - else: - return -212 - else: - return -1 diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 02953fb0c..a0e8c973f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -4,15 +4,77 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json +from typing import Dict, List from uuid import uuid4 from llama_stack.providers.tests.env import get_env_or_fail from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types.agent_create_params import AgentConfig -from .get_boiling_point_tool import GetBoilingPointTool +from llama_stack_client.lib.agents.custom_tool import CustomTool +from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client.types import CompletionMessage, ToolResponseMessage +from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.tool_param_definition_param import ( + ToolParamDefinitionParam, +) + + +class TestCustomTool(CustomTool): + """Tool to give boiling point of a liquid + Returns the correct value for water in Celcius and Fahrenheit + and returns -1 for other liquids + + """ + + def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: + assert len(messages) == 1, "Expected single message" + + message = messages[0] + + tool_call = message.tool_calls[0] + + try: + response = self.run_impl(**tool_call.arguments) + response_str = json.dumps(response, ensure_ascii=False) + except Exception as e: + response_str = f"Error when running tool: {e}" + + message = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=response_str, + role="ipython", + ) + return [message] + + def get_name(self) -> str: + return "get_boiling_point" + + def get_description(self) -> str: + return "Get the boiling point of a imaginary liquids (eg. polyjuice)" + + def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: + return { + "liquid_name": ToolParamDefinitionParam( + param_type="string", description="The name of the liquid", required=True + ), + "celcius": ToolParamDefinitionParam( + param_type="boolean", + description="Whether to return the boiling point in Celcius", + required=False, + ), + } + + def run_impl(self, liquid_name: str, celcius: bool = True) -> int: + if liquid_name.lower() == "polyjuice": + if celcius: + return -100 + else: + return -212 + else: + return -1 def get_agent_config_with_available_models_shields(llama_stack_client): @@ -167,9 +229,7 @@ def test_custom_tool(llama_stack_client): ] agent_config["tool_prompt_format"] = "python_list" - agent = Agent( - llama_stack_client, agent_config, custom_tools=(GetBoilingPointTool(),) - ) + agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),)) session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn(