mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
update client sdk tests
This commit is contained in:
parent
a945ab53d0
commit
ee542a7373
1 changed files with 9 additions and 9 deletions
|
@ -10,16 +10,16 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.custom_tool import CustomTool
|
||||
from llama_stack_client.lib.agents.client_tool import ClientTool
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types import ToolResponseMessage
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.types.custom_tool_def import Parameter
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||
from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter
|
||||
|
||||
|
||||
class TestCustomTool(CustomTool):
|
||||
class TestClientTool(ClientTool):
|
||||
"""Tool to give boiling point of a liquid
|
||||
Returns the correct value for polyjuice in Celcius and Fahrenheit
|
||||
and returns -1 for other liquids
|
||||
|
@ -52,15 +52,15 @@ class TestCustomTool(CustomTool):
|
|||
def get_description(self) -> str:
|
||||
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
|
||||
|
||||
def get_params_definition(self) -> Dict[str, Parameter]:
|
||||
def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]:
|
||||
return {
|
||||
"liquid_name": Parameter(
|
||||
"liquid_name": UserDefinedToolDefParameter(
|
||||
name="liquid_name",
|
||||
parameter_type="string",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": Parameter(
|
||||
"celcius": UserDefinedToolDefParameter(
|
||||
name="celcius",
|
||||
parameter_type="boolean",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
|
@ -205,16 +205,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
|||
|
||||
|
||||
def test_custom_tool(llama_stack_client, agent_config):
|
||||
custom_tool = TestCustomTool()
|
||||
client_tool = TestClientTool()
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"tool_names": ["brave_search"],
|
||||
"client_tools": [custom_tool.get_tool_definition()],
|
||||
"client_tools": [client_tool.get_tool_definition()],
|
||||
"tool_prompt_format": "python_list",
|
||||
}
|
||||
|
||||
agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,))
|
||||
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
response = agent.create_turn(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue