mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
add unit tests for chat agent
This commit is contained in:
parent
db2ec110a1
commit
854fef7478
4 changed files with 262 additions and 207 deletions
|
@ -13,7 +13,7 @@ import secrets
|
||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -76,7 +76,6 @@ def make_random_string(length: int = 8):
|
||||||
|
|
||||||
|
|
||||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
MEMORY_TOOL_GROUP_ID = "builtin::memory"
|
|
||||||
MEMORY_QUERY_TOOL = "query_memory"
|
MEMORY_QUERY_TOOL = "query_memory"
|
||||||
WEB_SEARCH_TOOL = "web_search"
|
WEB_SEARCH_TOOL = "web_search"
|
||||||
|
|
||||||
|
@ -382,6 +381,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
session_id, documents, input_messages, tool_defs
|
session_id, documents, input_messages, tool_defs
|
||||||
)
|
)
|
||||||
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
|
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
|
||||||
|
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
|
||||||
|
if memory_tool_group is None:
|
||||||
|
raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
|
||||||
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
@ -394,7 +396,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
query_args = {
|
query_args = {
|
||||||
"messages": [msg.content for msg in input_messages],
|
"messages": [msg.content for msg in input_messages],
|
||||||
**toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}),
|
**toolgroup_args.get(memory_tool_group, {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
|
@ -484,14 +486,20 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
|
|
||||||
with tracing.span("inference") as span:
|
with tracing.span("inference") as span:
|
||||||
|
|
||||||
|
def is_memory_group(tool):
|
||||||
|
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
|
||||||
|
has_memory_tool = MEMORY_QUERY_TOOL in tool_defs
|
||||||
|
return (
|
||||||
|
has_memory_tool
|
||||||
|
and tool_to_group.get(tool.tool_name, None) != memory_tool_group
|
||||||
|
)
|
||||||
|
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=[
|
tools=[
|
||||||
tool
|
tool for tool in tool_defs.values() if not is_memory_group(tool)
|
||||||
for tool in tool_defs.values()
|
|
||||||
if tool_to_group.get(tool.tool_name, None)
|
|
||||||
!= MEMORY_TOOL_GROUP_ID
|
|
||||||
],
|
],
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
@ -698,8 +706,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
|
|
||||||
async def _get_tool_defs(
|
async def _get_tool_defs(
|
||||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]]
|
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||||
) -> Dict[str, ToolDefinition]:
|
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
|
||||||
# Determine which tools to include
|
# Determine which tools to include
|
||||||
agent_config_toolgroups = set(
|
agent_config_toolgroups = set(
|
||||||
(
|
(
|
||||||
|
|
|
@ -4,21 +4,25 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import tempfile
|
||||||
from typing import AsyncIterator, List, Optional, Union
|
from typing import AsyncIterator, List, Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
|
StepType,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseEvent,
|
ChatCompletionResponseEvent,
|
||||||
ChatCompletionResponseStreamChunk,
|
ChatCompletionResponseStreamChunk,
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
@ -27,13 +31,24 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.memory import MemoryBank
|
from llama_stack.apis.memory import MemoryBank
|
||||||
|
from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
|
||||||
from llama_stack.apis.safety import RunShieldResponse
|
from llama_stack.apis.safety import RunShieldResponse
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
from ..agents import (
|
Tool,
|
||||||
AGENT_INSTANCES_BY_ID,
|
ToolDef,
|
||||||
MetaReferenceAgentsImpl,
|
ToolGroup,
|
||||||
MetaReferenceInferenceConfig,
|
ToolHost,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
|
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||||
|
MEMORY_QUERY_TOOL,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.inline.agents.meta_reference.agents import (
|
||||||
|
MetaReferenceAgentsImpl,
|
||||||
|
MetaReferenceAgentsImplConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
class MockInferenceAPI:
|
class MockInferenceAPI:
|
||||||
|
@ -48,10 +63,10 @@ class MockInferenceAPI:
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncIterator[
|
) -> Union[
|
||||||
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
|
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||||
]:
|
]:
|
||||||
if stream:
|
async def stream_response():
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type="start",
|
event_type="start",
|
||||||
|
@ -65,19 +80,7 @@ class MockInferenceAPI:
|
||||||
delta="AI is a fascinating field...",
|
delta="AI is a fascinating field...",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# yield ChatCompletionResponseStreamChunk(
|
|
||||||
# event=ChatCompletionResponseEvent(
|
|
||||||
# event_type="progress",
|
|
||||||
# delta=ToolCallDelta(
|
|
||||||
# content=ToolCall(
|
|
||||||
# call_id="123",
|
|
||||||
# tool_name=BuiltinTool.brave_search.value,
|
|
||||||
# arguments={"query": "AI history"},
|
|
||||||
# ),
|
|
||||||
# parse_status="success",
|
|
||||||
# ),
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
yield ChatCompletionResponseStreamChunk(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event=ChatCompletionResponseEvent(
|
event=ChatCompletionResponseEvent(
|
||||||
event_type="complete",
|
event_type="complete",
|
||||||
|
@ -85,12 +88,17 @@ class MockInferenceAPI:
|
||||||
stop_reason="end_of_turn",
|
stop_reason="end_of_turn",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return stream_response()
|
||||||
else:
|
else:
|
||||||
yield ChatCompletionResponse(
|
return ChatCompletionResponse(
|
||||||
completion_message=CompletionMessage(
|
completion_message=CompletionMessage(
|
||||||
role="assistant", content="Mock response", stop_reason="end_of_turn"
|
role="assistant",
|
||||||
|
content="Mock response",
|
||||||
|
stop_reason="end_of_turn",
|
||||||
),
|
),
|
||||||
logprobs=[0.1, 0.2, 0.3] if logprobs else None,
|
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -165,6 +173,99 @@ class MockMemoryAPI:
|
||||||
self.documents[bank_id].pop(doc_id, None)
|
self.documents[bank_id].pop(doc_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
class MockToolGroupsAPI:
|
||||||
|
async def register_tool_group(
|
||||||
|
self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||||
|
return ToolGroup(
|
||||||
|
identifier=toolgroup_id,
|
||||||
|
provider_resource_id=toolgroup_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_tool_groups(self) -> List[ToolGroup]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
|
||||||
|
if tool_group_id == MEMORY_TOOLGROUP:
|
||||||
|
return [
|
||||||
|
Tool(
|
||||||
|
identifier=MEMORY_QUERY_TOOL,
|
||||||
|
provider_resource_id=MEMORY_QUERY_TOOL,
|
||||||
|
toolgroup_id=MEMORY_TOOLGROUP,
|
||||||
|
tool_host=ToolHost.client,
|
||||||
|
description="Mock tool",
|
||||||
|
provider_id="mock_provider",
|
||||||
|
parameters=[],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
|
||||||
|
return [
|
||||||
|
Tool(
|
||||||
|
identifier="code_interpreter",
|
||||||
|
provider_resource_id="code_interpreter",
|
||||||
|
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
|
||||||
|
built_in_type=BuiltinTool.code_interpreter,
|
||||||
|
tool_host=ToolHost.client,
|
||||||
|
description="Mock tool",
|
||||||
|
provider_id="mock_provider",
|
||||||
|
parameters=[],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_tool(self, tool_name: str) -> Tool:
|
||||||
|
return Tool(
|
||||||
|
identifier=tool_name,
|
||||||
|
provider_resource_id=tool_name,
|
||||||
|
toolgroup_id="mock_group",
|
||||||
|
tool_host=ToolHost.client,
|
||||||
|
description="Mock tool",
|
||||||
|
provider_id="mock_provider",
|
||||||
|
parameters=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def unregister_tool_group(self, tool_group_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MockToolRuntimeAPI:
|
||||||
|
async def list_runtime_tools(
|
||||||
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
) -> List[ToolDef]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
|
||||||
|
return ToolInvocationResult(content={"result": "Mock tool result"})
|
||||||
|
|
||||||
|
|
||||||
|
class MockMemoryBanksAPI:
|
||||||
|
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def register_memory_bank(
|
||||||
|
self,
|
||||||
|
memory_bank_id: str,
|
||||||
|
params: BankParams,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
provider_memory_bank_id: Optional[str] = None,
|
||||||
|
) -> MemoryBank:
|
||||||
|
return VectorMemoryBank(
|
||||||
|
identifier=memory_bank_id,
|
||||||
|
provider_resource_id=provider_memory_bank_id or memory_bank_id,
|
||||||
|
embedding_model="mock_model",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_inference_api():
|
def mock_inference_api():
|
||||||
return MockInferenceAPI()
|
return MockInferenceAPI()
|
||||||
|
@ -181,64 +282,107 @@ def mock_memory_api():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
|
def mock_tool_groups_api():
|
||||||
|
return MockToolGroupsAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tool_runtime_api():
|
||||||
|
return MockToolRuntimeAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_memory_banks_api():
|
||||||
|
return MockMemoryBanksAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def get_agents_impl(
|
||||||
|
mock_inference_api,
|
||||||
|
mock_safety_api,
|
||||||
|
mock_memory_api,
|
||||||
|
mock_memory_banks_api,
|
||||||
|
mock_tool_runtime_api,
|
||||||
|
mock_tool_groups_api,
|
||||||
|
):
|
||||||
|
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||||
impl = MetaReferenceAgentsImpl(
|
impl = MetaReferenceAgentsImpl(
|
||||||
config=MetaReferenceInferenceConfig(),
|
config=MetaReferenceAgentsImplConfig(
|
||||||
|
persistence_store=SqliteKVStoreConfig(
|
||||||
|
db_name=sqlite_file.name,
|
||||||
|
),
|
||||||
|
),
|
||||||
inference_api=mock_inference_api,
|
inference_api=mock_inference_api,
|
||||||
safety_api=mock_safety_api,
|
safety_api=mock_safety_api,
|
||||||
memory_api=mock_memory_api,
|
memory_api=mock_memory_api,
|
||||||
|
memory_banks_api=mock_memory_banks_api,
|
||||||
|
tool_runtime_api=mock_tool_runtime_api,
|
||||||
|
tool_groups_api=mock_tool_groups_api,
|
||||||
)
|
)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def get_chat_agent(get_agents_impl):
|
||||||
|
impl = await get_agents_impl
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
model="test_model",
|
model="test_model",
|
||||||
instructions="You are a helpful assistant.",
|
instructions="You are a helpful assistant.",
|
||||||
sampling_params=SamplingParams(),
|
toolgroups=[],
|
||||||
tools=[
|
|
||||||
# SearchToolDefinition(
|
|
||||||
# name="brave_search",
|
|
||||||
# api_key="test_key",
|
|
||||||
# ),
|
|
||||||
],
|
|
||||||
tool_choice=ToolChoice.auto,
|
tool_choice=ToolChoice.auto,
|
||||||
enable_session_persistence=False,
|
enable_session_persistence=False,
|
||||||
input_shields=[],
|
input_shields=["test_shield"],
|
||||||
output_shields=[],
|
|
||||||
)
|
)
|
||||||
response = await impl.create_agent(agent_config)
|
response = await impl.create_agent(agent_config)
|
||||||
agent = AGENT_INSTANCES_BY_ID[response.agent_id]
|
return await impl.get_agent(response.agent_id)
|
||||||
return agent
|
|
||||||
|
|
||||||
|
MEMORY_TOOLGROUP = "builtin::memory"
|
||||||
|
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def get_chat_agent_with_tools(get_agents_impl, request):
|
||||||
|
impl = await get_agents_impl
|
||||||
|
toolgroups = request.param
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
model="test_model",
|
||||||
|
instructions="You are a helpful assistant.",
|
||||||
|
toolgroups=toolgroups,
|
||||||
|
tool_choice=ToolChoice.auto,
|
||||||
|
enable_session_persistence=False,
|
||||||
|
input_shields=["test_shield"],
|
||||||
|
)
|
||||||
|
response = await impl.create_agent(agent_config)
|
||||||
|
return await impl.get_agent(response.agent_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_agent_create_session(chat_agent):
|
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
|
||||||
session = chat_agent.create_session("Test Session")
|
chat_agent = await get_chat_agent
|
||||||
assert session.session_name == "Test Session"
|
session_id = await chat_agent.create_session("Test Session")
|
||||||
assert session.turns == []
|
|
||||||
assert session.session_id in chat_agent.sessions
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_agent_create_and_execute_turn(chat_agent):
|
|
||||||
session = chat_agent.create_session("Test Session")
|
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id="random",
|
agent_id=chat_agent.agent_id,
|
||||||
session_id=session.session_id,
|
session_id=session_id,
|
||||||
messages=[UserMessage(content="Hello")],
|
messages=[UserMessage(content="Hello")],
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
responses = []
|
responses = []
|
||||||
async for response in chat_agent.create_and_execute_turn(request):
|
async for response in chat_agent.create_and_execute_turn(request):
|
||||||
responses.append(response)
|
responses.append(response)
|
||||||
|
|
||||||
print(responses)
|
|
||||||
assert len(responses) > 0
|
assert len(responses) > 0
|
||||||
assert len(responses) == 4 # TurnStart, StepStart, StepComplete, TurnComplete
|
assert (
|
||||||
|
len(responses) == 7
|
||||||
|
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
|
||||||
assert responses[0].event.payload.turn_id is not None
|
assert responses[0].event.payload.turn_id is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_multiple_shields_wrapper(chat_agent):
|
async def test_run_multiple_shields_wrapper(get_chat_agent):
|
||||||
|
chat_agent = await get_chat_agent
|
||||||
messages = [UserMessage(content="Test message")]
|
messages = [UserMessage(content="Test message")]
|
||||||
shields = ["test_shield"]
|
shields = ["test_shield"]
|
||||||
|
|
||||||
|
@ -254,69 +398,83 @@ async def test_run_multiple_shields_wrapper(chat_agent):
|
||||||
|
|
||||||
assert len(responses) == 2 # StepStart, StepComplete
|
assert len(responses) == 2 # StepStart, StepComplete
|
||||||
assert responses[0].event.payload.step_type.value == "shield_call"
|
assert responses[0].event.payload.step_type.value == "shield_call"
|
||||||
assert not responses[1].event.payload.step_details.response.is_violation
|
assert not responses[1].event.payload.step_details.violation
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
|
async def test_chat_agent_complex_turn(get_chat_agent):
|
||||||
async def test_chat_agent_complex_turn(chat_agent):
|
chat_agent = await get_chat_agent
|
||||||
# Setup
|
session_id = await chat_agent.create_session("Test Session")
|
||||||
session = chat_agent.create_session("Test Session")
|
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id="random",
|
agent_id=chat_agent.agent_id,
|
||||||
session_id=session.session_id,
|
session_id=session_id,
|
||||||
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
|
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the turn
|
|
||||||
responses = []
|
responses = []
|
||||||
async for response in chat_agent.create_and_execute_turn(request):
|
async for response in chat_agent.create_and_execute_turn(request):
|
||||||
responses.append(response)
|
responses.append(response)
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert len(responses) > 0
|
assert len(responses) > 0
|
||||||
|
|
||||||
# Check for the presence of different step types
|
|
||||||
step_types = [
|
step_types = [
|
||||||
response.event.payload.step_type
|
response.event.payload.step_type
|
||||||
for response in responses
|
for response in responses
|
||||||
if hasattr(response.event.payload, "step_type")
|
if hasattr(response.event.payload, "step_type")
|
||||||
]
|
]
|
||||||
|
|
||||||
assert "shield_call" in step_types, "Shield call step is missing"
|
assert StepType.shield_call in step_types, "Shield call step is missing"
|
||||||
assert "inference" in step_types, "Inference step is missing"
|
assert StepType.inference in step_types, "Inference step is missing"
|
||||||
assert "tool_execution" in step_types, "Tool execution step is missing"
|
|
||||||
|
|
||||||
# Check for the presence of start and complete events
|
|
||||||
event_types = [
|
event_types = [
|
||||||
response.event.payload.event_type
|
response.event.payload.event_type
|
||||||
for response in responses
|
for response in responses
|
||||||
if hasattr(response.event.payload, "event_type")
|
if hasattr(response.event.payload, "event_type")
|
||||||
]
|
]
|
||||||
assert "start" in event_types, "Start event is missing"
|
assert "turn_start" in event_types, "Start event is missing"
|
||||||
assert "complete" in event_types, "Complete event is missing"
|
assert "turn_complete" in event_types, "Complete event is missing"
|
||||||
|
|
||||||
# Check for the presence of tool call
|
|
||||||
tool_calls = [
|
|
||||||
response.event.payload.tool_call
|
|
||||||
for response in responses
|
|
||||||
if hasattr(response.event.payload, "tool_call")
|
|
||||||
]
|
|
||||||
assert any(
|
|
||||||
tool_call
|
|
||||||
for tool_call in tool_calls
|
|
||||||
if tool_call and tool_call.content.get("name") == "memory"
|
|
||||||
), "Memory tool call is missing"
|
|
||||||
|
|
||||||
# Check for the final turn complete event
|
|
||||||
assert any(
|
assert any(
|
||||||
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
||||||
for response in responses
|
for response in responses
|
||||||
), "Turn complete event is missing"
|
), "Turn complete event is missing"
|
||||||
|
turn_complete_payload = next(
|
||||||
|
response.event.payload
|
||||||
|
for response in responses
|
||||||
|
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
||||||
|
)
|
||||||
|
turn = turn_complete_payload.turn
|
||||||
|
assert turn.input_messages == request.messages, "Input messages do not match"
|
||||||
|
|
||||||
# Verify the turn was added to the session
|
|
||||||
assert len(session.turns) == 1, "Turn was not added to the session"
|
@pytest.mark.asyncio
|
||||||
assert (
|
@pytest.mark.parametrize(
|
||||||
session.turns[0].input_messages == request.messages
|
"toolgroups, expected_memory, expected_code_interpreter",
|
||||||
), "Input messages do not match"
|
[
|
||||||
|
([], False, False), # no tools
|
||||||
|
([MEMORY_TOOLGROUP], True, False), # memory only
|
||||||
|
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
|
||||||
|
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_chat_agent_tools(
|
||||||
|
get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
|
||||||
|
):
|
||||||
|
impl = await get_agents_impl
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
model="test_model",
|
||||||
|
instructions="You are a helpful assistant.",
|
||||||
|
toolgroups=toolgroups,
|
||||||
|
tool_choice=ToolChoice.auto,
|
||||||
|
enable_session_persistence=False,
|
||||||
|
input_shields=["test_shield"],
|
||||||
|
)
|
||||||
|
response = await impl.create_agent(agent_config)
|
||||||
|
chat_agent = await impl.get_agent(response.agent_id)
|
||||||
|
|
||||||
|
tool_defs, _ = await chat_agent._get_tool_defs()
|
||||||
|
if expected_memory:
|
||||||
|
assert MEMORY_QUERY_TOOL in tool_defs
|
||||||
|
if expected_code_interpreter:
|
||||||
|
assert BuiltinTool.code_interpreter in tool_defs
|
||||||
|
|
|
@ -1,42 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
|
||||||
from llama_stack.apis.safety import Safety
|
|
||||||
|
|
||||||
from ..safety import ShieldRunnerMixin
|
|
||||||
from .builtin import BaseTool
|
|
||||||
|
|
||||||
|
|
||||||
class SafeTool(BaseTool, ShieldRunnerMixin):
|
|
||||||
"""A tool that makes other tools safety enabled"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tool: BaseTool,
|
|
||||||
safety_api: Safety,
|
|
||||||
input_shields: List[str] = None,
|
|
||||||
output_shields: List[str] = None,
|
|
||||||
):
|
|
||||||
self._tool = tool
|
|
||||||
ShieldRunnerMixin.__init__(
|
|
||||||
self, safety_api, input_shields=input_shields, output_shields=output_shields
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return self._tool.get_name()
|
|
||||||
|
|
||||||
async def run(self, messages: List[Message]) -> List[Message]:
|
|
||||||
if self.input_shields:
|
|
||||||
await self.run_multiple_shields(messages, self.input_shields)
|
|
||||||
# run the underlying tool
|
|
||||||
res = await self._tool.run(messages)
|
|
||||||
if self.output_shields:
|
|
||||||
await self.run_multiple_shields(messages, self.output_shields)
|
|
||||||
|
|
||||||
return res
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
@ -83,74 +82,6 @@ def query_attachment_messages():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def create_agent_turn_with_toolgroup(
|
|
||||||
agents_stack: Dict[str, object],
|
|
||||||
search_query_messages: List[object],
|
|
||||||
common_params: Dict[str, str],
|
|
||||||
toolgroup_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Create an agent turn with a toolgroup.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agents_stack (Dict[str, object]): The agents stack.
|
|
||||||
search_query_messages (List[object]): The search query messages.
|
|
||||||
common_params (Dict[str, str]): The common parameters.
|
|
||||||
toolgroup_name (str): The name of the toolgroup.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Create an agent with the toolgroup
|
|
||||||
agent_config = AgentConfig(
|
|
||||||
**{
|
|
||||||
**common_params,
|
|
||||||
"toolgroups": [toolgroup_name],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_id, session_id = await create_agent_session(
|
|
||||||
agents_stack.impls[Api.agents], agent_config
|
|
||||||
)
|
|
||||||
turn_request = dict(
|
|
||||||
agent_id=agent_id,
|
|
||||||
session_id=session_id,
|
|
||||||
messages=search_query_messages,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
turn_response = [
|
|
||||||
chunk
|
|
||||||
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
|
|
||||||
**turn_request
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(turn_response) > 0
|
|
||||||
assert all(
|
|
||||||
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
|
||||||
)
|
|
||||||
|
|
||||||
check_event_types(turn_response)
|
|
||||||
|
|
||||||
# Check for tool execution events
|
|
||||||
tool_execution_events = [
|
|
||||||
chunk
|
|
||||||
for chunk in turn_response
|
|
||||||
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
|
|
||||||
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
|
|
||||||
]
|
|
||||||
assert len(tool_execution_events) > 0, "No tool execution events found"
|
|
||||||
|
|
||||||
# Check the tool execution details
|
|
||||||
tool_execution = tool_execution_events[0].event.payload.step_details
|
|
||||||
assert isinstance(tool_execution, ToolExecutionStep)
|
|
||||||
assert len(tool_execution.tool_calls) > 0
|
|
||||||
actual_tool_name = tool_execution.tool_calls[0].tool_name
|
|
||||||
assert actual_tool_name.value == tool_name
|
|
||||||
assert len(tool_execution.tool_responses) > 0
|
|
||||||
|
|
||||||
check_turn_complete_event(turn_response, session_id, search_query_messages)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgents:
|
class TestAgents:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_turns_with_safety(
|
async def test_agent_turns_with_safety(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue