mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 16:54:42 +00:00
test custom tool
This commit is contained in:
parent
b1f311982f
commit
6abaf4574d
2 changed files with 66 additions and 75 deletions
|
@ -1,69 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
import json
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from llama_stack_client.lib.agents.custom_tool import CustomTool
|
|
||||||
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
|
|
||||||
from llama_stack_client.types.tool_param_definition_param import (
|
|
||||||
ToolParamDefinitionParam,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GetBoilingPointTool(CustomTool):
|
|
||||||
"""Tool to give boiling point of a liquid
|
|
||||||
Returns the correct value for water in Celcius and Fahrenheit
|
|
||||||
and returns -1 for other liquids
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
|
||||||
assert len(messages) == 1, "Expected single message"
|
|
||||||
|
|
||||||
message = messages[0]
|
|
||||||
|
|
||||||
tool_call = message.tool_calls[0]
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = self.run_impl(**tool_call.arguments)
|
|
||||||
response_str = json.dumps(response, ensure_ascii=False)
|
|
||||||
except Exception as e:
|
|
||||||
response_str = f"Error when running tool: {e}"
|
|
||||||
|
|
||||||
message = ToolResponseMessage(
|
|
||||||
call_id=tool_call.call_id,
|
|
||||||
tool_name=tool_call.tool_name,
|
|
||||||
content=response_str,
|
|
||||||
role="ipython",
|
|
||||||
)
|
|
||||||
return [message]
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "get_boiling_point"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
|
|
||||||
|
|
||||||
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
|
|
||||||
return {
|
|
||||||
"liquid_name": ToolParamDefinitionParam(
|
|
||||||
param_type="string", description="The name of the liquid", required=True
|
|
||||||
),
|
|
||||||
"celcius": ToolParamDefinitionParam(
|
|
||||||
param_type="boolean",
|
|
||||||
description="Whether to return the boiling point in Celcius",
|
|
||||||
required=False,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
|
|
||||||
if liquid_name.lower() == "polyjuice":
|
|
||||||
if celcius:
|
|
||||||
return -100
|
|
||||||
else:
|
|
||||||
return -212
|
|
||||||
else:
|
|
||||||
return -1
|
|
|
@ -4,15 +4,77 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict, List
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from llama_stack.providers.tests.env import get_env_or_fail
|
from llama_stack.providers.tests.env import get_env_or_fail
|
||||||
|
|
||||||
from llama_stack_client.lib.agents.agent import Agent
|
from llama_stack_client.lib.agents.agent import Agent
|
||||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
|
||||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
|
||||||
|
|
||||||
from .get_boiling_point_tool import GetBoilingPointTool
|
from llama_stack_client.lib.agents.custom_tool import CustomTool
|
||||||
|
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||||
|
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
|
||||||
|
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||||
|
from llama_stack_client.types.tool_param_definition_param import (
|
||||||
|
ToolParamDefinitionParam,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomTool(CustomTool):
|
||||||
|
"""Tool to give boiling point of a liquid
|
||||||
|
Returns the correct value for water in Celcius and Fahrenheit
|
||||||
|
and returns -1 for other liquids
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
||||||
|
assert len(messages) == 1, "Expected single message"
|
||||||
|
|
||||||
|
message = messages[0]
|
||||||
|
|
||||||
|
tool_call = message.tool_calls[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.run_impl(**tool_call.arguments)
|
||||||
|
response_str = json.dumps(response, ensure_ascii=False)
|
||||||
|
except Exception as e:
|
||||||
|
response_str = f"Error when running tool: {e}"
|
||||||
|
|
||||||
|
message = ToolResponseMessage(
|
||||||
|
call_id=tool_call.call_id,
|
||||||
|
tool_name=tool_call.tool_name,
|
||||||
|
content=response_str,
|
||||||
|
role="ipython",
|
||||||
|
)
|
||||||
|
return [message]
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "get_boiling_point"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
|
||||||
|
|
||||||
|
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
|
||||||
|
return {
|
||||||
|
"liquid_name": ToolParamDefinitionParam(
|
||||||
|
param_type="string", description="The name of the liquid", required=True
|
||||||
|
),
|
||||||
|
"celcius": ToolParamDefinitionParam(
|
||||||
|
param_type="boolean",
|
||||||
|
description="Whether to return the boiling point in Celcius",
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
|
||||||
|
if liquid_name.lower() == "polyjuice":
|
||||||
|
if celcius:
|
||||||
|
return -100
|
||||||
|
else:
|
||||||
|
return -212
|
||||||
|
else:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
def get_agent_config_with_available_models_shields(llama_stack_client):
|
def get_agent_config_with_available_models_shields(llama_stack_client):
|
||||||
|
@ -167,9 +229,7 @@ def test_custom_tool(llama_stack_client):
|
||||||
]
|
]
|
||||||
agent_config["tool_prompt_format"] = "python_list"
|
agent_config["tool_prompt_format"] = "python_list"
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
|
||||||
llama_stack_client, agent_config, custom_tools=(GetBoilingPointTool(),)
|
|
||||||
)
|
|
||||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
response = agent.create_turn(
|
response = agent.create_turn(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue