mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
update client sdk tests
This commit is contained in:
parent
a945ab53d0
commit
ee542a7373
1 changed files with 9 additions and 9 deletions
|
@ -10,16 +10,16 @@ from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_stack_client.lib.agents.agent import Agent
|
from llama_stack_client.lib.agents.agent import Agent
|
||||||
from llama_stack_client.lib.agents.custom_tool import CustomTool
|
from llama_stack_client.lib.agents.client_tool import ClientTool
|
||||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||||
from llama_stack_client.types import ToolResponseMessage
|
from llama_stack_client.types import ToolResponseMessage
|
||||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||||
from llama_stack_client.types.custom_tool_def import Parameter
|
|
||||||
from llama_stack_client.types.memory_insert_params import Document
|
from llama_stack_client.types.memory_insert_params import Document
|
||||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||||
|
from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter
|
||||||
|
|
||||||
|
|
||||||
class TestCustomTool(CustomTool):
|
class TestClientTool(ClientTool):
|
||||||
"""Tool to give boiling point of a liquid
|
"""Tool to give boiling point of a liquid
|
||||||
Returns the correct value for polyjuice in Celcius and Fahrenheit
|
Returns the correct value for polyjuice in Celcius and Fahrenheit
|
||||||
and returns -1 for other liquids
|
and returns -1 for other liquids
|
||||||
|
@ -52,15 +52,15 @@ class TestCustomTool(CustomTool):
|
||||||
def get_description(self) -> str:
|
def get_description(self) -> str:
|
||||||
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
|
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
|
||||||
|
|
||||||
def get_params_definition(self) -> Dict[str, Parameter]:
|
def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]:
|
||||||
return {
|
return {
|
||||||
"liquid_name": Parameter(
|
"liquid_name": UserDefinedToolDefParameter(
|
||||||
name="liquid_name",
|
name="liquid_name",
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
description="The name of the liquid",
|
description="The name of the liquid",
|
||||||
required=True,
|
required=True,
|
||||||
),
|
),
|
||||||
"celcius": Parameter(
|
"celcius": UserDefinedToolDefParameter(
|
||||||
name="celcius",
|
name="celcius",
|
||||||
parameter_type="boolean",
|
parameter_type="boolean",
|
||||||
description="Whether to return the boiling point in Celcius",
|
description="Whether to return the boiling point in Celcius",
|
||||||
|
@ -205,16 +205,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||||
|
|
||||||
|
|
||||||
def test_custom_tool(llama_stack_client, agent_config):
|
def test_custom_tool(llama_stack_client, agent_config):
|
||||||
custom_tool = TestCustomTool()
|
client_tool = TestClientTool()
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
"tool_names": ["brave_search"],
|
"tool_names": ["brave_search"],
|
||||||
"client_tools": [custom_tool.get_tool_definition()],
|
"client_tools": [client_tool.get_tool_definition()],
|
||||||
"tool_prompt_format": "python_list",
|
"tool_prompt_format": "python_list",
|
||||||
}
|
}
|
||||||
|
|
||||||
agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,))
|
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
|
||||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
response = agent.create_turn(
|
response = agent.create_turn(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue