minor fixes

Dinesh Yeduguru 2025-01-07 16:37:39 -08:00
parent 6632d7e410
commit c3865faf37
6 changed files with 19 additions and 9 deletions


@@ -4012,7 +4012,7 @@
           ]
         }
       },
-      "tools": {
+      "toolgroups": {
         "type": "array",
         "items": {
           "$ref": "#/components/schemas/AgentTool"


@@ -624,7 +624,7 @@ components:
         type: string
       stream:
         type: boolean
-      tools:
+      toolgroups:
         items:
           $ref: '#/components/schemas/AgentTool'
         type: array
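
Both spec files record the same rename: the turn-creation request body now carries a "toolgroups" array (still referencing AgentTool items) where it previously used "tools". A rough sketch of a request payload under the updated spec; apart from the renamed key, the ids, message, and toolgroup name are illustrative assumptions, not taken from this commit:

# Sketch only: field values and the "builtin::websearch" name are assumptions;
# the renamed "toolgroups" key (formerly "tools") is what this commit changes.
payload = {
    "agent_id": "agent-123",
    "session_id": "session-456",
    "messages": [{"role": "user", "content": "Find recent Llama Stack releases"}],
    "stream": False,
    "toolgroups": ["builtin::websearch"],
}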


@@ -317,7 +317,7 @@ class Agents(Protocol):
         ],
         stream: Optional[bool] = False,
         documents: Optional[List[Document]] = None,
-        tools: Optional[List[AgentToolGroup]] = None,
+        toolgroups: Optional[List[AgentToolGroup]] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

     @webmethod(route="/agents/turn/get")
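
The protocol keeps the same type, Optional[List[AgentToolGroup]], but exposes the parameter as toolgroups. A minimal call sketch against the renamed signature; the API handle, ids, message content, and passing toolgroups as plain strings are assumptions:

from llama_stack.apis.inference import UserMessage  # module path assumed

# Minimal sketch, assuming an `agents_api` object that implements the Agents
# protocol; ids and the toolgroup name are placeholders.
turn = await agents_api.create_agent_turn(
    agent_id=agent_id,
    session_id=session_id,
    messages=[UserMessage(content="What's the weather in San Francisco?")],
    stream=False,
    toolgroups=["builtin::websearch"],  # formerly passed as tools=...
)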


@@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
                 ToolResponseMessage,
             ]
         ],
-        tools: Optional[List[AgentToolGroup]] = None,
+        toolgroups: Optional[List[AgentToolGroup]] = None,
        documents: Optional[List[Document]] = None,
        stream: Optional[bool] = False,
    ) -> AsyncGenerator:
@@ -156,7 +156,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
             messages=messages,
             stream=True,
-            tools=tools,
+            toolgroups=toolgroups,
             documents=documents,
         )
         if stream:
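
Providers implementing Agents have to pick up both halves of the rename: the signature parameter and the keyword forwarded when the turn request is built. A rough skeleton of that threading; the agent_id parameter and the request type name are assumed from surrounding context, not shown in this hunk:

# Rough skeleton; assumes the surrounding file's imports and that the turn
# request object (called AgentTurnCreateRequest here) accepts the new field.
async def create_agent_turn(self, agent_id, session_id, messages, toolgroups=None, documents=None, stream=False):
    request = AgentTurnCreateRequest(   # request type assumed from context
        agent_id=agent_id,
        session_id=session_id,
        messages=messages,
        stream=True,
        toolgroups=toolgroups,          # renamed from tools=tools
        documents=documents,
    )
    if stream:
        ...  # hand the request to the streaming turn executor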


@@ -8,6 +8,7 @@ import os
 from typing import Dict, List
 import pytest
+from llama_models.llama3.api.datatypes import BuiltinTool
 from llama_stack.apis.agents import (
     AgentConfig,
@@ -335,7 +336,7 @@ class TestAgents:
         assert isinstance(tool_execution, ToolExecutionStep)
         assert len(tool_execution.tool_calls) > 0
         actual_tool_name = tool_execution.tool_calls[0].tool_name
-        assert actual_tool_name == "web_search"
+        assert actual_tool_name == BuiltinTool.brave_search
         assert len(tool_execution.tool_responses) > 0
         check_turn_complete_event(turn_response, session_id, search_query_messages)
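
The new import explains the assertion change: for built-in tools the step's tool_name is a BuiltinTool enum member, so the test compares against BuiltinTool.brave_search rather than the "web_search" string. A tiny illustration of the difference (the value assigned to actual_tool_name mirrors what the updated assertion expects):

from llama_models.llama3.api.datatypes import BuiltinTool

# tool_name for built-in tools is the enum member, not a display string,
# so the old comparison against "web_search" would no longer match.
actual_tool_name = BuiltinTool.brave_search
assert actual_tool_name == BuiltinTool.brave_search
assert actual_tool_name != "web_search"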


@@ -9,7 +9,7 @@ from typing import Dict, List
 from uuid import uuid4
 import pytest
-from llama_stack_client.lib.agents.agent import Agent, maybe_register_memory_tool
+from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.client_tool import ClientTool
 from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types import ToolResponseMessage
@@ -293,9 +293,18 @@ def test_rag_agent(llama_stack_client, agent_config):
         for i, url in enumerate(urls)
     ]
-    tool_name, memory_bank_id = maybe_register_memory_tool(llama_stack_client)
-    agent_config["tools"].append(tool_name)
+    agent_config["tools"].append("builtin::memory")
     agent = Agent(llama_stack_client, agent_config)
+    memory_bank_id = "test-memory-bank"
+    llama_stack_client.memory_banks.register(
+        memory_bank_id=memory_bank_id,
+        params={
+            "memory_bank_type": "vector",
+            "embedding_model": "all-MiniLM-L6-v2",
+            "chunk_size_in_tokens": 512,
+            "overlap_size_in_tokens": 64,
+        },
+    )
     llama_stack_client.memory.insert(
         bank_id=memory_bank_id,
         documents=documents,
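
With the maybe_register_memory_tool helper gone, the test wires things up explicitly: append the builtin::memory toolgroup to the agent config, register a vector memory bank by hand, and insert the documents into it. A hedged sketch of how the rest of such a test would then drive the agent; the session name, user query, and logging pattern are assumptions based on the surrounding test suite, not part of this hunk:

# Continuation sketch; assumes the imports already present in this file
# (uuid4, Agent, EventLogger) and the llama-stack-client Agent helpers.
session_id = agent.create_session(f"rag-session-{uuid4()}")
response = agent.create_turn(
    messages=[{"role": "user", "content": "What do the inserted documents cover?"}],
    session_id=session_id,
)
# Collect the streamed event log for inspection.
logs = [str(log) for log in EventLogger().log(response) if log is not None]
assert len(logs) > 0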