mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-05 10:13:05 +00:00
minor fixes
This commit is contained in:
parent 6632d7e410
commit c3865faf37

6 changed files with 19 additions and 9 deletions
@@ -4012,7 +4012,7 @@
           ]
         }
       },
-      "tools": {
+      "toolgroups": {
         "type": "array",
         "items": {
           "$ref": "#/components/schemas/AgentTool"

@@ -624,7 +624,7 @@ components:
           type: string
         stream:
           type: boolean
-        tools:
+        toolgroups:
           items:
             $ref: '#/components/schemas/AgentTool'
           type: array

@@ -317,7 +317,7 @@ class Agents(Protocol):
         ],
         stream: Optional[bool] = False,
         documents: Optional[List[Document]] = None,
-        tools: Optional[List[AgentToolGroup]] = None,
+        toolgroups: Optional[List[AgentToolGroup]] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

     @webmethod(route="/agents/turn/get")

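For orientation, a minimal caller-side sketch of the renamed keyword argument on Agents.create_agent_turn. Only the toolgroups parameter, its AgentToolGroup typing, and the "builtin::memory" toolgroup name (used in the test change further down) come from this commit; the agent/session IDs, the UserMessage import path, and the message content are assumptions for illustration.

# Hypothetical call site against an object implementing the Agents protocol;
# only the `toolgroups` keyword (formerly `tools`) comes from this commit.
from llama_stack.apis.inference import UserMessage  # assumed import path

async def run_turn(agents, agent_id: str, session_id: str) -> None:
    response = await agents.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="Summarize the indexed documents.")],
        stream=True,
        toolgroups=["builtin::memory"],  # AgentToolGroup entries; plain string names assumed valid
    )
    # With stream=True the return value is an AsyncIterator of
    # AgentTurnResponseStreamChunk objects.
    async for chunk in response:
        print(chunk)
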
@@ -147,7 +147,7 @@ class MetaReferenceAgentsImpl(Agents):
                 ToolResponseMessage,
             ]
         ],
-        tools: Optional[List[AgentToolGroup]] = None,
+        toolgroups: Optional[List[AgentToolGroup]] = None,
         documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:

@@ -156,7 +156,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
             messages=messages,
             stream=True,
-            tools=tools,
+            toolgroups=toolgroups,
             documents=documents,
         )
         if stream:

@@ -8,6 +8,7 @@ import os
 from typing import Dict, List

 import pytest
+from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.agents import (
     AgentConfig,

@@ -335,7 +336,7 @@ class TestAgents:
         assert isinstance(tool_execution, ToolExecutionStep)
         assert len(tool_execution.tool_calls) > 0
         actual_tool_name = tool_execution.tool_calls[0].tool_name
-        assert actual_tool_name == "web_search"
+        assert actual_tool_name == BuiltinTool.brave_search
         assert len(tool_execution.tool_responses) > 0

         check_turn_complete_event(turn_response, session_id, search_query_messages)

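The assertion change above reflects that built-in tool invocations surface a BuiltinTool enum member rather than a free-form string, so a comparison against "web_search" can never match. A small sketch of the distinction; the enum and its value come from llama_models, not from this diff, and the concrete value is an assumption.

from llama_models.llama3.api.datatypes import BuiltinTool

# Built-in tools are identified by enum members, not plain strings.
assert BuiltinTool.brave_search != "web_search"
assert BuiltinTool.brave_search.value == "brave_search"  # assumed enum value
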
@@ -9,7 +9,7 @@ from typing import Dict, List
 from uuid import uuid4

 import pytest
-from llama_stack_client.lib.agents.agent import Agent, maybe_register_memory_tool
+from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.client_tool import ClientTool
 from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types import ToolResponseMessage
@@ -293,9 +293,18 @@ def test_rag_agent(llama_stack_client, agent_config):
         for i, url in enumerate(urls)
     ]

-    tool_name, memory_bank_id = maybe_register_memory_tool(llama_stack_client)
-    agent_config["tools"].append(tool_name)
+    agent_config["tools"].append("builtin::memory")
     agent = Agent(llama_stack_client, agent_config)
+    memory_bank_id = "test-memory-bank"
+    llama_stack_client.memory_banks.register(
+        memory_bank_id=memory_bank_id,
+        params={
+            "memory_bank_type": "vector",
+            "embedding_model": "all-MiniLM-L6-v2",
+            "chunk_size_in_tokens": 512,
+            "overlap_size_in_tokens": 64,
+        },
+    )
     llama_stack_client.memory.insert(
         bank_id=memory_bank_id,
         documents=documents,

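With the memory bank registered explicitly and the builtin::memory toolgroup appended to agent_config["tools"], the agent can then be driven as in the rest of this test file. A hedged sketch of the follow-on usage, assuming the client-sdk Agent and EventLogger APIs already imported above; the create_session/create_turn method names, message shape, and log.print() call are illustrative and not part of this diff.

# Illustrative continuation of the test above; method names assumed from the
# llama_stack_client Agent/EventLogger APIs imported in this file.
session_id = agent.create_session(f"rag-session-{uuid4()}")
response = agent.create_turn(
    messages=[{"role": "user", "content": "What do the inserted documents say about tool calling?"}],
    session_id=session_id,
)
for log in EventLogger().log(response):
    if log is not None:
        log.print()
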