mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:19:49 +00:00
working end to end client sdk tests with custom tools
This commit is contained in:
parent
1a66ddc1b5
commit
4dd2f4c363
5 changed files with 304 additions and 149 deletions
|
|
@ -9,16 +9,13 @@ from typing import Dict, List
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.custom_tool import CustomTool
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
|
||||
from llama_stack_client.types import ToolResponseMessage
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.types.tool_param_definition_param import (
|
||||
ToolParamDefinitionParam,
|
||||
)
|
||||
from llama_stack_client.types.custom_tool_def import Parameter
|
||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||
|
||||
|
||||
class TestCustomTool(CustomTool):
|
||||
|
|
@ -54,13 +51,17 @@ class TestCustomTool(CustomTool):
|
|||
def get_description(self) -> str:
|
||||
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
|
||||
|
||||
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
|
||||
def get_params_definition(self) -> Dict[str, Parameter]:
|
||||
return {
|
||||
"liquid_name": ToolParamDefinitionParam(
|
||||
param_type="string", description="The name of the liquid", required=True
|
||||
"liquid_name": Parameter(
|
||||
name="liquid_name",
|
||||
parameter_type="string",
|
||||
description="The name of the liquid",
|
||||
required=True,
|
||||
),
|
||||
"celcius": ToolParamDefinitionParam(
|
||||
param_type="boolean",
|
||||
"celcius": Parameter(
|
||||
name="celcius",
|
||||
parameter_type="boolean",
|
||||
description="Whether to return the boiling point in Celcius",
|
||||
required=False,
|
||||
),
|
||||
|
|
@ -203,37 +204,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
|||
|
||||
|
||||
def test_custom_tool(llama_stack_client, agent_config):
|
||||
custom_tool = TestCustomTool()
|
||||
agent_config = {
|
||||
**agent_config,
|
||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"tools": [
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "brave",
|
||||
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
|
||||
},
|
||||
{
|
||||
"function_name": "get_boiling_point",
|
||||
"description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
|
||||
"parameters": {
|
||||
"liquid_name": {
|
||||
"param_type": "str",
|
||||
"description": "The name of the liquid",
|
||||
"required": True,
|
||||
},
|
||||
"celcius": {
|
||||
"param_type": "boolean",
|
||||
"description": "Whether to return the boiling point in Celcius",
|
||||
"required": False,
|
||||
},
|
||||
},
|
||||
"type": "function_call",
|
||||
},
|
||||
],
|
||||
"available_tools": ["brave_search"],
|
||||
"custom_tools": [custom_tool.get_tool_definition()],
|
||||
"tool_prompt_format": "python_list",
|
||||
}
|
||||
|
||||
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
|
||||
agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,))
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
response = agent.create_turn(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue