From 517bc9ebea49e405ac6e98c7c9cbfff8e47b61ed Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 23 Dec 2024 12:12:09 -0800 Subject: [PATCH] add back custom tool tests --- tests/client-sdk/agents/test_agents.py | 114 +++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 4e335d8d3..7939259d1 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -4,12 +4,76 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json +from typing import Dict, List from uuid import uuid4 import pytest + +from llama_stack.providers.tests.env import get_env_or_fail from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.custom_tool import CustomTool from llama_stack_client.lib.agents.event_logger import EventLogger +from llama_stack_client.types import CompletionMessage, ToolResponseMessage from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.tool_param_definition_param import ( + ToolParamDefinitionParam, +) + + +class TestCustomTool(CustomTool): + """Tool to give boiling point of a liquid + Returns the correct value for water in Celcius and Fahrenheit + and returns -1 for other liquids + """ + + def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: + assert len(messages) == 1, "Expected single message" + + message = messages[0] + + tool_call = message.tool_calls[0] + + try: + response = self.run_impl(**tool_call.arguments) + response_str = json.dumps(response, ensure_ascii=False) + except Exception as e: + response_str = f"Error when running tool: {e}" + + message = ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=response_str, + role="ipython", + ) + return [message] + + def get_name(self) -> str: + return "get_boiling_point" + + def get_description(self) -> str: + return "Get the boiling point of a imaginary liquids (eg. polyjuice)" + + def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: + return { + "liquid_name": ToolParamDefinitionParam( + param_type="string", description="The name of the liquid", required=True + ), + "celcius": ToolParamDefinitionParam( + param_type="boolean", + description="Whether to return the boiling point in Celcius", + required=False, + ), + } + + def run_impl(self, liquid_name: str, celcius: bool = True) -> int: + if liquid_name.lower() == "polyjuice": + if celcius: + return -100 + else: + return -212 + else: + return -1 @pytest.fixture(scope="session") @@ -136,3 +200,53 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): assert "541" in logs_str assert "Tool:code_interpreter Response" in logs_str + + +def test_custom_tool(llama_stack_client, agent_config): + agent_config = { + **agent_config, + "model": "meta-llama/Llama-3.2-3B-Instruct", + "tools": [ + { + "type": "brave_search", + "engine": "brave", + "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"), + }, + { + "function_name": "get_boiling_point", + "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", + "parameters": { + "liquid_name": { + "param_type": "str", + "description": "The name of the liquid", + "required": True, + }, + "celcius": { + "param_type": "boolean", + "description": "Whether to return the boiling point in Celcius", + "required": False, + }, + }, + "type": "function_call", + }, + ], + "tool_prompt_format": "python_list", + } + + agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),)) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "What is the boiling point of polyjuice?", + }, + ], + session_id=session_id, + ) + + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + assert "-100" in logs_str + assert "CustomTool" in logs_str