mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 00:19:40 +00:00
rename UserDefinedToolDef to ToolDef
This commit is contained in:
parent
db0b2a60c1
commit
e3775eb6f6
8 changed files with 180 additions and 322 deletions
|
|
@ -17,7 +17,7 @@ from llama_stack_client.types.agent_create_params import AgentConfig
|
|||
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||
from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter
|
||||
from llama_stack_client.types.tool_def_param import Parameter
|
||||
|
||||
|
||||
class TestClientTool(ClientTool):
|
||||
|
|
@ -53,15 +53,15 @@ class TestClientTool(ClientTool):
|
|||
def get_description(self) -> str:
|
||||
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
|
||||
|
||||
def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]:
|
||||
def get_params_definition(self) -> Dict[str, Parameter]:
|
||||
return {
|
||||
"liquid_name": UserDefinedToolDefParameter(
|
||||
"liquid_name": Parameter(
|
||||
name="liquid_name",
|
||||
parameter_type="string",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": UserDefinedToolDefParameter(
|
||||
"celcius": Parameter(
|
||||
name="celcius",
|
||||
parameter_type="boolean",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
|
|
@ -149,11 +149,11 @@ def test_agent_simple(llama_stack_client, agent_config):
|
|||
assert "I can't" in logs_str
|
||||
|
||||
|
||||
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
||||
def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"tools": [
|
||||
"brave_search",
|
||||
"builtin::web_search",
|
||||
],
|
||||
}
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
|
|
@ -182,7 +182,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
|||
agent_config = {
|
||||
**agent_config,
|
||||
"tools": [
|
||||
"code_interpreter",
|
||||
"builtin::code_interpreter",
|
||||
],
|
||||
}
|
||||
agent = Agent(llama_stack_client, agent_config)
|
||||
|
|
@ -209,9 +209,9 @@ def test_code_execution(llama_stack_client):
|
|||
model="meta-llama/Llama-3.1-70B-Instruct",
|
||||
instructions="You are a helpful assistant",
|
||||
tools=[
|
||||
"code_interpreter",
|
||||
"builtin::code_interpreter",
|
||||
],
|
||||
tool_choice="auto",
|
||||
tool_choice="required",
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
enable_session_persistence=False,
|
||||
|
|
@ -242,7 +242,7 @@ def test_code_execution(llama_stack_client):
|
|||
)
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
print(logs_str)
|
||||
assert "Tool:code_interpreter" in logs_str
|
||||
|
||||
|
||||
def test_custom_tool(llama_stack_client, agent_config):
|
||||
|
|
@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
agent_config = {
|
||||
**agent_config,
|
||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"tools": ["brave_search"],
|
||||
"tools": ["builtin::web_search"],
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
"tool_prompt_format": "python_list",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue