mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Add a simple agents test case
This commit is contained in:
parent
2d94ca71a9
commit
ef4b74c935
6 changed files with 128 additions and 22 deletions
|
@ -21,6 +21,7 @@ async def get_provider_impl(
|
||||||
deps[Api.inference],
|
deps[Api.inference],
|
||||||
deps[Api.memory],
|
deps[Api.memory],
|
||||||
deps[Api.safety],
|
deps[Api.safety],
|
||||||
|
deps[Api.memory_banks],
|
||||||
)
|
)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -137,3 +137,25 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
agent = await self.get_agent(request.agent_id)
|
agent = await self.get_agent(request.agent_id)
|
||||||
async for event in agent.create_and_execute_turn(request):
|
async for event in agent.create_and_execute_turn(request):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
async def get_agents_turn(self, agent_id: str, turn_id: str) -> Turn:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def get_agents_step(
|
||||||
|
self, agent_id: str, turn_id: str, step_id: str
|
||||||
|
) -> AgentStepResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def get_agents_session(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
session_id: str,
|
||||||
|
turn_ids: Optional[List[str]] = None,
|
||||||
|
) -> Session:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def delete_agents(self, agent_id: str) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: together
|
||||||
|
provider_type: remote::together
|
||||||
|
config: {}
|
||||||
|
- provider_id: tgi
|
||||||
|
provider_type: remote::tgi
|
||||||
|
config:
|
||||||
|
url: http://127.0.0.1:7001
|
||||||
|
# - provider_id: meta-reference
|
||||||
|
# provider_type: meta-reference
|
||||||
|
# config:
|
||||||
|
# model: Llama-Guard-3-1B
|
||||||
|
# - provider_id: remote
|
||||||
|
# provider_type: remote
|
||||||
|
# config:
|
||||||
|
# host: localhost
|
||||||
|
# port: 7010
|
||||||
|
safety:
|
||||||
|
- provider_id: together
|
||||||
|
provider_type: remote::together
|
||||||
|
config: {}
|
||||||
|
memory:
|
||||||
|
- provider_id: faiss
|
||||||
|
provider_type: meta-reference
|
||||||
|
config: {}
|
||||||
|
agents:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
namespace: null
|
||||||
|
type: sqlite
|
||||||
|
db_path: /Users/ashwin/.llama/runtime/kvstore.db
|
|
@ -7,17 +7,13 @@
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||||
|
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
#
|
#
|
||||||
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
|
# 1. Ensure you have a conda environment with the right dependencies installed.
|
||||||
# since it depends on the provider you are testing. On top of that you need
|
# This includes `pytest` and `pytest-asyncio`.
|
||||||
# `pytest` and `pytest-asyncio` installed.
|
|
||||||
#
|
#
|
||||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||||
#
|
#
|
||||||
|
@ -38,22 +34,76 @@ async def agents_settings():
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"impl": impls[Api.safety],
|
"impl": impls[Api.agents],
|
||||||
"memory_impl": impls[Api.memory],
|
"memory_impl": impls[Api.memory],
|
||||||
"inference_impl": impls[Api.inference],
|
"common_params": {
|
||||||
"safety_impl": impls[Api.safety],
|
"model": "Llama3.1-8B-Instruct",
|
||||||
|
"instructions": "You are a helpful assistant.",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_tool_definition():
|
def sample_messages():
|
||||||
return ToolDefinition(
|
return [
|
||||||
tool_name="get_weather",
|
UserMessage(content="What's the weather like today?"),
|
||||||
description="Get the current weather",
|
]
|
||||||
parameters={
|
|
||||||
"location": ToolParamDefinition(
|
|
||||||
param_type="string",
|
@pytest.mark.asyncio
|
||||||
description="The city and state, e.g. San Francisco, CA",
|
async def test_create_agent_turn(agents_settings, sample_messages):
|
||||||
),
|
agents_impl = agents_settings["impl"]
|
||||||
},
|
|
||||||
|
# First, create an agent
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
model=agents_settings["common_params"]["model"],
|
||||||
|
instructions=agents_settings["common_params"]["instructions"],
|
||||||
|
enable_session_persistence=True,
|
||||||
|
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||||
|
input_shields=[],
|
||||||
|
output_shields=[],
|
||||||
|
tools=[],
|
||||||
|
max_infer_iters=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
create_response = await agents_impl.create_agent(agent_config)
|
||||||
|
agent_id = create_response.agent_id
|
||||||
|
|
||||||
|
# Create a session
|
||||||
|
session_create_response = await agents_impl.create_agent_session(
|
||||||
|
agent_id, "Test Session"
|
||||||
|
)
|
||||||
|
session_id = session_create_response.session_id
|
||||||
|
|
||||||
|
# Create and execute a turn
|
||||||
|
turn_request = dict(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=sample_messages,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_response = [
|
||||||
|
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(turn_response) > 0
|
||||||
|
assert all(
|
||||||
|
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for expected event types
|
||||||
|
event_types = [chunk.event.payload.event_type for chunk in turn_response]
|
||||||
|
assert AgentTurnResponseEventType.turn_start.value in event_types
|
||||||
|
assert AgentTurnResponseEventType.step_start.value in event_types
|
||||||
|
assert AgentTurnResponseEventType.step_complete.value in event_types
|
||||||
|
assert AgentTurnResponseEventType.turn_complete.value in event_types
|
||||||
|
|
||||||
|
# Check the final turn complete event
|
||||||
|
final_event = turn_response[-1].event.payload
|
||||||
|
assert isinstance(final_event, AgentTurnResponseTurnCompletePayload)
|
||||||
|
assert isinstance(final_event.turn, Turn)
|
||||||
|
assert final_event.turn.session_id == session_id
|
||||||
|
assert final_event.turn.input_messages == sample_messages
|
||||||
|
assert isinstance(final_event.turn.output_message, CompletionMessage)
|
||||||
|
assert len(final_event.turn.output_message.content) > 0
|
||||||
|
|
|
@ -21,5 +21,4 @@ providers:
|
||||||
# this is a place to provide such data.
|
# this is a place to provide such data.
|
||||||
provider_data:
|
provider_data:
|
||||||
"test-together":
|
"test-together":
|
||||||
together_api_key:
|
together_api_key: 0xdeadbeefputrealapikeyhere
|
||||||
0xdeadbeefputrealapikeyhere
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ def get_expected_stop_reason(model: str):
|
||||||
scope="session",
|
scope="session",
|
||||||
params=[
|
params=[
|
||||||
{"model": Llama_8B},
|
{"model": Llama_8B},
|
||||||
# {"model": Llama_3B},
|
{"model": Llama_3B},
|
||||||
],
|
],
|
||||||
ids=lambda d: d["model"],
|
ids=lambda d: d["model"],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue