forked from phoenix-oss/llama-stack-mirror
Add a test runner and 2 very simple tests for agents
This commit is contained in:
parent
543222ac39
commit
abb43936ab
9 changed files with 243 additions and 5 deletions
2
.flake8
2
.flake8
|
@ -14,6 +14,8 @@ ignore =
|
|||
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
|
||||
# to line this up with executable bit
|
||||
EXE001,
|
||||
# random naming hints don't need
|
||||
N802,
|
||||
# these ignores are from flake8-bugbear; please fix!
|
||||
B007,B008,B950
|
||||
optional-ascii-coding = True
|
||||
|
|
|
@ -3,5 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .control_plane import * # noqa: F401 F403
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
|
@ -619,6 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
return True
|
||||
|
||||
print(f"{enabled_tools=}")
|
||||
return AgentTool.memory.value in enabled_tools
|
||||
|
||||
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
|
||||
|
|
|
@ -43,12 +43,10 @@ class ShieldRunnerMixin:
|
|||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
res = await self.safety_api.run_shields(
|
||||
results = await self.safety_api.run_shields(
|
||||
messages=messages,
|
||||
shields=shields,
|
||||
)
|
||||
|
||||
results = res.responses
|
||||
for shield, r in zip(shields, results):
|
||||
if r.is_violation:
|
||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||
|
|
|
@ -0,0 +1,216 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
|
||||
from ..agent_instance import ChatAgent
|
||||
|
||||
|
||||
class MockInferenceAPI:
    """Stand-in for the inference API that emits canned completions."""

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = None,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncIterator[
        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
    ]:
        """Yield a fixed "Mock response", streamed or whole per *stream*."""
        if not stream:
            # Non-streaming: one complete response object.
            yield ChatCompletionResponse(
                completion_message=CompletionMessage(
                    role="assistant", content="Mock response", stop_reason="end_of_turn"
                ),
                logprobs=[0.1, 0.2, 0.3] if logprobs else None,
            )
            return

        # Streaming: start -> progress -> complete, same payloads every call.
        canned_events = (
            ChatCompletionResponseEvent(
                event_type="start",
                delta="",
            ),
            ChatCompletionResponseEvent(
                event_type="progress",
                delta="Mock response",
            ),
            ChatCompletionResponseEvent(
                event_type="complete",
                delta="",
                stop_reason="end_of_turn",
            ),
        )
        for event in canned_events:
            yield ChatCompletionResponseStreamChunk(event=event)
|
||||
|
||||
|
||||
class MockSafetyAPI:
    """Stand-in safety API: every shield run reports no violation."""

    async def run_shields(
        self, messages: List[Message], shields: List[MagicMock]
    ) -> List[ShieldResponse]:
        """Return a single non-violating shield response, ignoring inputs."""
        benign = ShieldResponse(shield_type="mock_shield", is_violation=False)
        return [benign]
|
||||
|
||||
|
||||
class MockMemoryAPI:
    """In-memory fake of the memory API, backed by plain dicts."""

    def __init__(self):
        # bank_id -> bank object
        self.memory_banks = {}
        # bank_id -> {document_id -> document}
        self.documents = {}

    def _bank_docs(self, bank_id):
        """Return the document dict for *bank_id*, raising if unknown."""
        if bank_id not in self.documents:
            raise ValueError(f"Bank {bank_id} not found")
        return self.documents[bank_id]

    async def create_memory_bank(self, name, config, url=None):
        """Register a new bank with a sequential id and return it."""
        bank_id = f"bank_{len(self.memory_banks)}"
        bank = MemoryBank(bank_id, name, config, url)
        self.memory_banks[bank_id] = bank
        self.documents[bank_id] = {}
        return bank

    async def list_memory_banks(self):
        """All registered banks, in insertion order."""
        return list(self.memory_banks.values())

    async def get_memory_bank(self, bank_id):
        """The bank for *bank_id*, or None when unknown."""
        return self.memory_banks.get(bank_id)

    async def drop_memory_bank(self, bank_id):
        """Remove a bank and its documents; always echo back *bank_id*."""
        if bank_id in self.memory_banks:
            del self.memory_banks[bank_id]
            del self.documents[bank_id]
        return bank_id

    async def insert_documents(self, bank_id, documents, ttl_seconds=None):
        """Store *documents* keyed by document_id (ttl is ignored)."""
        docs = self._bank_docs(bank_id)
        for doc in documents:
            docs[doc.document_id] = doc

    async def update_documents(self, bank_id, documents):
        """Replace already-present documents; unknown ids are ignored."""
        docs = self._bank_docs(bank_id)
        for doc in documents:
            if doc.document_id in docs:
                docs[doc.document_id] = doc

    async def query_documents(self, bank_id, query, params=None):
        """Mock query: every stored document matches with score 1.0."""
        docs = self._bank_docs(bank_id)
        # Simple mock implementation: return all documents
        chunks = [
            {"content": doc.content, "token_count": 10, "document_id": doc.document_id}
            for doc in docs.values()
        ]
        return {"chunks": chunks, "scores": [1.0] * len(chunks)}

    async def get_documents(self, bank_id, document_ids):
        """Fetch the subset of *document_ids* that exist in the bank."""
        docs = self._bank_docs(bank_id)
        return [docs[doc_id] for doc_id in document_ids if doc_id in docs]

    async def delete_documents(self, bank_id, document_ids):
        """Drop the given ids from the bank; missing ids are no-ops."""
        docs = self._bank_docs(bank_id)
        for doc_id in document_ids:
            docs.pop(doc_id, None)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_inference_api():
    """Fresh MockInferenceAPI for each test."""
    return MockInferenceAPI()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_safety_api():
    """Fresh MockSafetyAPI for each test."""
    return MockSafetyAPI()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_memory_api():
    """Fresh MockMemoryAPI for each test."""
    return MockMemoryAPI()
|
||||
|
||||
|
||||
@pytest.fixture
def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
    """Build a ChatAgent wired to the mock inference/safety/memory APIs."""
    config = AgentConfig(
        model="test_model",
        instructions="You are a helpful assistant.",
        sampling_params=SamplingParams(),
        tools=[],
        tool_choice=ToolChoice.auto,
        input_shields=[],
        output_shields=[],
    )
    return ChatAgent(
        agent_config=config,
        inference_api=mock_inference_api,
        memory_api=mock_memory_api,
        safety_api=mock_safety_api,
        builtin_tools=[],
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_agent_create_session(chat_agent):
    """A new session is named, starts empty, and is registered on the agent."""
    name = "Test Session"
    session = chat_agent.create_session(name)
    assert session.session_name == name
    assert session.turns == []
    assert session.session_id in chat_agent.sessions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_agent_create_and_execute_turn(chat_agent):
    """Executing a one-message turn yields the expected event sequence.

    Fix: drop the leftover debug ``print(responses)`` and the redundant
    ``len(responses) > 0`` assertion (implied by the exact-count check).
    """
    session = chat_agent.create_session("Test Session")
    request = AgentTurnCreateRequest(
        agent_id="random",
        session_id=session.session_id,
        messages=[UserMessage(content="Hello")],
    )

    responses = []
    async for response in chat_agent.create_and_execute_turn(request):
        responses.append(response)

    assert len(responses) == 4  # TurnStart, StepStart, StepComplete, TurnComplete
    assert responses[0].event.payload.turn_id is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_shields_wrapper(chat_agent):
    """The shield wrapper emits StepStart + StepComplete with no violation."""
    messages = [UserMessage(content="Test message")]
    shields = [ShieldDefinition(shield_type="test_shield")]

    responses = []
    async for chunk in chat_agent.run_shields_wrapper(
        turn_id="test_turn_id",
        messages=messages,
        shields=shields,
        touchpoint="user-input",
    ):
        responses.append(chunk)

    assert len(responses) == 2  # StepStart, StepComplete
    assert responses[0].event.payload.step_type.value == "shield_call"
    assert not responses[1].event.payload.step_details.response.is_violation
|
5
llama_stack/scripts/__init__.py
Normal file
5
llama_stack/scripts/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
16
llama_stack/scripts/run_tests.sh
Normal file
16
llama_stack/scripts/run_tests.sh
Normal file
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Fail fast on errors/unset vars/pipe failures, and echo commands as they run.
# (Hoisted above all other commands so failures in setup are caught too.)
set -euo pipefail
set -x

# Directory containing this script, with symlinks resolved.
THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"

# llama_stack package root, and the sibling llama-models checkout.
# All expansions quoted so paths containing spaces don't word-split.
stack_dir="$(dirname "$THIS_DIR")"
models_dir="$(dirname "$(dirname "$stack_dir")")/llama-models"

PYTHONPATH="$models_dir:$stack_dir" pytest -p no:warnings --asyncio-mode auto --tb=short
|
Loading…
Add table
Add a link
Reference in a new issue