minor fixes

This commit is contained in:
Dinesh Yeduguru 2025-01-07 16:37:39 -08:00
parent 6632d7e410
commit c3865faf37
6 changed files with 19 additions and 9 deletions

View file

@ -4012,7 +4012,7 @@
]
}
},
"tools": {
"toolgroups": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AgentTool"

View file

@ -624,7 +624,7 @@ components:
type: string
stream:
type: boolean
tools:
toolgroups:
items:
$ref: '#/components/schemas/AgentTool'
type: array

View file

@ -317,7 +317,7 @@ class Agents(Protocol):
],
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
tools: Optional[List[AgentToolGroup]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")

View file

@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
tools: Optional[List[AgentToolGroup]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
@ -156,7 +156,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
messages=messages,
stream=True,
tools=tools,
toolgroups=toolgroups,
documents=documents,
)
if stream:

View file

@ -8,6 +8,7 @@ import os
from typing import Dict, List
import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
@ -335,7 +336,7 @@ class TestAgents:
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
actual_tool_name = tool_execution.tool_calls[0].tool_name
assert actual_tool_name == "web_search"
assert actual_tool_name == BuiltinTool.brave_search
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)

View file

@ -9,7 +9,7 @@ from typing import Dict, List
from uuid import uuid4
import pytest
from llama_stack_client.lib.agents.agent import Agent, maybe_register_memory_tool
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.client_tool import ClientTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
@ -293,9 +293,18 @@ def test_rag_agent(llama_stack_client, agent_config):
for i, url in enumerate(urls)
]
tool_name, memory_bank_id = maybe_register_memory_tool(llama_stack_client)
agent_config["tools"].append(tool_name)
agent_config["tools"].append("builtin::memory")
agent = Agent(llama_stack_client, agent_config)
memory_bank_id = "test-memory-bank"
llama_stack_client.memory_banks.register(
memory_bank_id=memory_bank_id,
params={
"memory_bank_type": "vector",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,
},
)
llama_stack_client.memory.insert(
bank_id=memory_bank_id,
documents=documents,