Unit test

This commit is contained in:
Tami Takamiya 2025-08-01 08:08:59 -04:00
parent dbaf38acc5
commit 309ae82a8a
No known key found for this signature in database
GPG key ID: B23C22B8CB1E70C7

View file

@ -16,9 +16,11 @@ from llama_stack.apis.agents import (
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import Inference
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.tools import ToolGroups, ToolRuntime, ListToolsResponse, Tool, ToolParameter
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@ -75,11 +77,11 @@ def sample_agent_config():
},
input_shields=["string"],
output_shields=["string"],
toolgroups=["string"],
toolgroups=["mcp::my_mcp_server"],
client_tools=[
{
"name": "string",
"description": "string",
"name": "client_tool",
"description": "Client Tool",
"parameters": [
{
"name": "string",
@ -226,3 +228,83 @@ async def test_delete_agent(agents_impl, sample_agent_config):
# Verify the agent was deleted
with pytest.raises(ValueError):
await agents_impl.get_agent(agent_id)
async def test__initialize_tools(agents_impl, sample_agent_config):
# Mock tool_groups_api.list_tools()
agents_impl.tool_groups_api.list_tools.return_value = ListToolsResponse(
data=[
Tool(
identifier="story_maker",
provider_id="model-context-protocol",
type=ResourceType.tool,
toolgroup_id="mcp::my_mcp_server",
description="Make a story",
parameters=[
ToolParameter(
name="story_title",
parameter_type="string",
description="Title of the story",
required=True,
title="Story Title",
),
ToolParameter(
name="input_words",
parameter_type="array",
description="Input words",
required=False,
items={"type": "string"},
title="Input Words",
default=[],
),
],
)
]
)
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
# Get an instance of ChatAgent
chat_agent = await agents_impl._get_agent_impl(agent_id)
assert chat_agent is not None
assert isinstance(chat_agent, ChatAgent)
# Initialize tool definitions
await chat_agent._initialize_tools()
assert len(chat_agent.tool_defs) == 2
# Verify the first tool, which is a client tool
first_tool = chat_agent.tool_defs[0]
assert first_tool.tool_name == "client_tool"
assert first_tool.description == "Client Tool"
# Verify the second tool, which is an MCP tool that has an array-type property
second_tool = chat_agent.tool_defs[1]
assert second_tool.tool_name == "story_maker"
assert second_tool.description == "Make a story"
parameters = second_tool.parameters
assert len(parameters) == 2
# Verify a string property
story_title = parameters.get("story_title")
assert story_title is not None
assert story_title.param_type == "string"
assert story_title.description == "Title of the story"
assert story_title.required == True
assert story_title.items is None
assert story_title.title == "Story Title"
assert story_title.default is None
# Verify an array property
input_words = parameters.get("input_words")
assert input_words is not None
assert input_words.param_type == "array"
assert input_words.description == "Input words"
assert input_words.required == False
assert input_words.items is not None
assert len(input_words.items) == 1
assert input_words.items.get("type") == "string"
assert input_words.title == "Input Words"
assert input_words.default == []