diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py index c06d9ab0e..4d9c22c21 100644 --- a/tests/unit/providers/agent/test_meta_reference_agent.py +++ b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -16,9 +16,11 @@ from llama_stack.apis.agents import ( ) from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.inference import Inference +from llama_stack.apis.resource import ResourceType from llama_stack.apis.safety import Safety -from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.tools import ToolGroups, ToolRuntime, ListToolsResponse, Tool, ToolParameter from llama_stack.apis.vector_io import VectorIO +from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo @@ -75,11 +77,11 @@ def sample_agent_config(): }, input_shields=["string"], output_shields=["string"], - toolgroups=["string"], + toolgroups=["mcp::my_mcp_server"], client_tools=[ { - "name": "string", - "description": "string", + "name": "client_tool", + "description": "Client Tool", "parameters": [ { "name": "string", @@ -226,3 +228,83 @@ async def test_delete_agent(agents_impl, sample_agent_config): # Verify the agent was deleted with pytest.raises(ValueError): await agents_impl.get_agent(agent_id) + + +async def test__initialize_tools(agents_impl, sample_agent_config): + # Mock tool_groups_api.list_tools() + agents_impl.tool_groups_api.list_tools.return_value = ListToolsResponse( + data=[ + Tool( + identifier="story_maker", + provider_id="model-context-protocol", + type=ResourceType.tool, + toolgroup_id="mcp::my_mcp_server", + description="Make a story", + parameters=[ + ToolParameter( + name="story_title", + parameter_type="string", + description="Title of the story", + required=True, + title="Story Title", + ), + ToolParameter( + name="input_words", + parameter_type="array", + description="Input words", + required=False, + items={"type": "string"}, + title="Input Words", + default=[], + ), + ], + ) + ] + ) + + create_response = await agents_impl.create_agent(sample_agent_config) + agent_id = create_response.agent_id + + # Get an instance of ChatAgent + chat_agent = await agents_impl._get_agent_impl(agent_id) + assert chat_agent is not None + assert isinstance(chat_agent, ChatAgent) + + # Initialize tool definitions + await chat_agent._initialize_tools() + assert len(chat_agent.tool_defs) == 2 + + # Verify the first tool, which is a client tool + first_tool = chat_agent.tool_defs[0] + assert first_tool.tool_name == "client_tool" + assert first_tool.description == "Client Tool" + + # Verify the second tool, which is an MCP tool that has an array-type property + second_tool = chat_agent.tool_defs[1] + assert second_tool.tool_name == "story_maker" + assert second_tool.description == "Make a story" + + parameters = second_tool.parameters + assert len(parameters) == 2 + + # Verify a string property + story_title = parameters.get("story_title") + assert story_title is not None + assert story_title.param_type == "string" + assert story_title.description == "Title of the story" + assert story_title.required == True + assert story_title.items is None + assert story_title.title == "Story Title" + assert story_title.default is None + + # Verify an array property + input_words = parameters.get("input_words") + assert input_words is not None + assert input_words.param_type == "array" + assert input_words.description == "Input words" + assert input_words.required == False + assert input_words.items is not None + assert len(input_words.items) == 1 + assert input_words.items.get("type") == "string" + assert input_words.title == "Input Words" + assert input_words.default == []