working end to end client sdk tests with custom tools

This commit is contained in:
Dinesh Yeduguru 2024-12-23 18:27:55 -08:00
parent 1a66ddc1b5
commit 4dd2f4c363
5 changed files with 304 additions and 149 deletions

View file

@ -9,16 +9,13 @@ from typing import Dict, List
from uuid import uuid4
import pytest
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.tool_param_definition_param import (
ToolParamDefinitionParam,
)
from llama_stack_client.types.custom_tool_def import Parameter
from llama_stack_client.types.shared.completion_message import CompletionMessage
class TestCustomTool(CustomTool):
@ -54,13 +51,17 @@ class TestCustomTool(CustomTool):
def get_description(self) -> str:
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
def get_params_definition(self) -> Dict[str, Parameter]:
return {
"liquid_name": ToolParamDefinitionParam(
param_type="string", description="The name of the liquid", required=True
"liquid_name": Parameter(
name="liquid_name",
parameter_type="string",
description="The name of the liquid",
required=True,
),
"celcius": ToolParamDefinitionParam(
param_type="boolean",
"celcius": Parameter(
name="celcius",
parameter_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
@ -203,37 +204,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
def test_custom_tool(llama_stack_client, agent_config):
custom_tool = TestCustomTool()
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
"tools": [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
},
{
"function_name": "get_boiling_point",
"description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
"parameters": {
"liquid_name": {
"param_type": "str",
"description": "The name of the liquid",
"required": True,
},
"celcius": {
"param_type": "boolean",
"description": "Whether to return the boiling point in Celcius",
"required": False,
},
},
"type": "function_call",
},
],
"available_tools": ["brave_search"],
"custom_tools": [custom_tool.get_tool_definition()],
"tool_prompt_format": "python_list",
}
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(