test custom tool

This commit is contained in:
Xi Yan 2024-12-16 12:04:05 -08:00
parent b1f311982f
commit 6abaf4574d
2 changed files with 66 additions and 75 deletions

View file

@ -1,69 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Dict, List
from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
from llama_stack_client.types.tool_param_definition_param import (
ToolParamDefinitionParam,
)
class GetBoilingPointTool(CustomTool):
"""Tool to give boiling point of a liquid
Returns the correct value for water in Celcius and Fahrenheit
and returns -1 for other liquids
"""
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
try:
response = self.run_impl(**tool_call.arguments)
response_str = json.dumps(response, ensure_ascii=False)
except Exception as e:
response_str = f"Error when running tool: {e}"
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response_str,
role="ipython",
)
return [message]
def get_name(self) -> str:
return "get_boiling_point"
def get_description(self) -> str:
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
return {
"liquid_name": ToolParamDefinitionParam(
param_type="string", description="The name of the liquid", required=True
),
"celcius": ToolParamDefinitionParam(
param_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
}
def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
else:
return -1

View file

@ -4,15 +4,77 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Dict, List
from uuid import uuid4
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from .get_boiling_point_tool import GetBoilingPointTool
from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.tool_param_definition_param import (
ToolParamDefinitionParam,
)
class TestCustomTool(CustomTool):
"""Tool to give boiling point of a liquid
Returns the correct value for water in Celcius and Fahrenheit
and returns -1 for other liquids
"""
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
try:
response = self.run_impl(**tool_call.arguments)
response_str = json.dumps(response, ensure_ascii=False)
except Exception as e:
response_str = f"Error when running tool: {e}"
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response_str,
role="ipython",
)
return [message]
def get_name(self) -> str:
return "get_boiling_point"
def get_description(self) -> str:
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
return {
"liquid_name": ToolParamDefinitionParam(
param_type="string", description="The name of the liquid", required=True
),
"celcius": ToolParamDefinitionParam(
param_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
}
def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
else:
return -1
def get_agent_config_with_available_models_shields(llama_stack_client):
@ -167,9 +229,7 @@ def test_custom_tool(llama_stack_client):
]
agent_config["tool_prompt_format"] = "python_list"
agent = Agent(
llama_stack_client, agent_config, custom_tools=(GetBoilingPointTool(),)
)
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(