Add a simple agents test case

This commit is contained in:
Ashwin Bharambe 2024-10-09 21:52:49 -07:00
parent 2d94ca71a9
commit ef4b74c935
6 changed files with 128 additions and 22 deletions

View file

@@ -21,6 +21,7 @@ async def get_provider_impl(
deps[Api.inference], deps[Api.inference],
deps[Api.memory], deps[Api.memory],
deps[Api.safety], deps[Api.safety],
deps[Api.memory_banks],
) )
await impl.initialize() await impl.initialize()
return impl return impl

View file

@@ -137,3 +137,25 @@ class MetaReferenceAgentsImpl(Agents):
agent = await self.get_agent(request.agent_id) agent = await self.get_agent(request.agent_id)
async for event in agent.create_and_execute_turn(request): async for event in agent.create_and_execute_turn(request):
yield event yield event
async def get_agents_turn(self, agent_id: str, turn_id: str) -> Turn:
    """Retrieve a single persisted turn for an agent.

    Not yet supported by the meta-reference implementation.
    """
    raise NotImplementedError()
async def get_agents_step(
    self, agent_id: str, turn_id: str, step_id: str
) -> AgentStepResponse:
    """Retrieve a single step within a persisted turn.

    Not yet supported by the meta-reference implementation.
    """
    raise NotImplementedError()
async def get_agents_session(
    self,
    agent_id: str,
    session_id: str,
    # Optional filter: when provided, presumably restricts the returned
    # session to the listed turns — TODO confirm once implemented.
    turn_ids: Optional[List[str]] = None,
) -> Session:
    """Retrieve a persisted session for an agent.

    Not yet supported by the meta-reference implementation.
    """
    raise NotImplementedError()
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
    """Delete a single session belonging to an agent.

    Not yet supported by the meta-reference implementation.
    """
    raise NotImplementedError()
async def delete_agents(self, agent_id: str) -> None:
    """Delete an agent and all of its associated state.

    Not yet supported by the meta-reference implementation.
    """
    raise NotImplementedError()

View file

@@ -0,0 +1,34 @@
# Provider configuration for the agents test run.
# NOTE(review): nesting reconstructed — the extracted text had lost all YAML
# indentation, which made the document invalid. Structure inferred from the
# provider_id / provider_type / config grouping.
providers:
  inference:
    - provider_id: together
      provider_type: remote::together
      config: {}
    - provider_id: tgi
      provider_type: remote::tgi
      config:
        url: http://127.0.0.1:7001
    # - provider_id: meta-reference
    #   provider_type: meta-reference
    #   config:
    #     model: Llama-Guard-3-1B
    # - provider_id: remote
    #   provider_type: remote
    #   config:
    #     host: localhost
    #     port: 7010
  safety:
    - provider_id: together
      provider_type: remote::together
      config: {}
  memory:
    - provider_id: faiss
      provider_type: meta-reference
      config: {}
  agents:
    - provider_id: meta-reference
      provider_type: meta-reference
      config:
        persistence_store:
          namespace: null
          type: sqlite
          # NOTE(review): user-specific absolute path — parameterize before CI use.
          db_path: /Users/ashwin/.llama/runtime/kvstore.db

View file

@@ -7,17 +7,13 @@
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.agents import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test: # How to run this test:
# #
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky # 1. Ensure you have a conda environment with the right dependencies installed.
# since it depends on the provider you are testing. On top of that you need # This includes `pytest` and `pytest-asyncio`.
# `pytest` and `pytest-asyncio` installed.
# #
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. # 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
# #
@@ -38,22 +34,76 @@ async def agents_settings():
) )
return { return {
"impl": impls[Api.safety], "impl": impls[Api.agents],
"memory_impl": impls[Api.memory], "memory_impl": impls[Api.memory],
"inference_impl": impls[Api.inference], "common_params": {
"safety_impl": impls[Api.safety], "model": "Llama3.1-8B-Instruct",
"instructions": "You are a helpful assistant.",
},
} }
@pytest.fixture @pytest.fixture
def sample_tool_definition(): def sample_messages():
return ToolDefinition( return [
tool_name="get_weather", UserMessage(content="What's the weather like today?"),
description="Get the current weather", ]
parameters={
"location": ToolParamDefinition(
param_type="string", @pytest.mark.asyncio
description="The city and state, e.g. San Francisco, CA", async def test_create_agent_turn(agents_settings, sample_messages):
), agents_impl = agents_settings["impl"]
},
# First, create an agent
agent_config = AgentConfig(
model=agents_settings["common_params"]["model"],
instructions=agents_settings["common_params"]["instructions"],
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
max_infer_iters=5,
) )
create_response = await agents_impl.create_agent(agent_config)
agent_id = create_response.agent_id
# Create a session
session_create_response = await agents_impl.create_agent_session(
agent_id, "Test Session"
)
session_id = session_create_response.session_id
# Create and execute a turn
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=sample_messages,
stream=True,
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
# Check for expected event types
event_types = [chunk.event.payload.event_type for chunk in turn_response]
assert AgentTurnResponseEventType.turn_start.value in event_types
assert AgentTurnResponseEventType.step_start.value in event_types
assert AgentTurnResponseEventType.step_complete.value in event_types
assert AgentTurnResponseEventType.turn_complete.value in event_types
# Check the final turn complete event
final_event = turn_response[-1].event.payload
assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
assert isinstance(final_event.turn, Turn)
assert final_event.turn.session_id == session_id
assert final_event.turn.input_messages == sample_messages
assert isinstance(final_event.turn.output_message, CompletionMessage)
assert len(final_event.turn.output_message.content) > 0

View file

@@ -21,5 +21,4 @@ providers:
# this is a place to provide such data. # this is a place to provide such data.
provider_data: provider_data:
"test-together": "test-together":
together_api_key: together_api_key: 0xdeadbeefputrealapikeyhere
0xdeadbeefputrealapikeyhere

View file

@@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
scope="session", scope="session",
params=[ params=[
{"model": Llama_8B}, {"model": Llama_8B},
# {"model": Llama_3B}, {"model": Llama_3B},
], ],
ids=lambda d: d["model"], ids=lambda d: d["model"],
) )