update client sdk tests

This commit is contained in:
Dinesh Yeduguru 2024-12-30 16:57:17 -08:00
parent a945ab53d0
commit ee542a7373

View file

@ -10,16 +10,16 @@ from uuid import uuid4
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.lib.agents.client_tool import ClientTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.custom_tool_def import Parameter
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter
class TestCustomTool(CustomTool):
class TestClientTool(ClientTool):
"""Tool to give boiling point of a liquid
Returns the correct value for polyjuice in Celcius and Fahrenheit
and returns -1 for other liquids
@ -52,15 +52,15 @@ class TestCustomTool(CustomTool):
def get_description(self) -> str:
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, Parameter]:
def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]:
return {
"liquid_name": Parameter(
"liquid_name": UserDefinedToolDefParameter(
name="liquid_name",
parameter_type="string",
description="The name of the liquid",
required=True,
),
"celcius": Parameter(
"celcius": UserDefinedToolDefParameter(
name="celcius",
parameter_type="boolean",
description="Whether to return the boiling point in Celcius",
@ -205,16 +205,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
def test_custom_tool(llama_stack_client, agent_config):
custom_tool = TestCustomTool()
client_tool = TestClientTool()
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
"tool_names": ["brave_search"],
"client_tools": [custom_tool.get_tool_definition()],
"client_tools": [client_tool.get_tool_definition()],
"tool_prompt_format": "python_list",
}
agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,))
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(