forked from phoenix-oss/llama-stack-mirror
Kill "remote" providers and fix testing with a remote stack properly (#435)
# What does this PR do? This PR kills the notion of "pure passthrough" remote providers. You cannot specify a single provider you must specify a whole distribution (stack) as remote. This PR also significantly fixes / upgrades testing infrastructure so you can now test against a remotely hosted stack server by just doing ```bash pytest -s -v -m remote test_agents.py \ --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \ --env REMOTE_STACK_URL=http://localhost:5001 ``` Also fixed `test_agents_persistence.py` (which was broken) and killed some deprecated testing functions. ## Test Plan All the tests.
This commit is contained in:
parent
59a65e34d3
commit
12947ac19e
28 changed files with 406 additions and 519 deletions
|
@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
|
|||
|
||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||
from ..memory.fixtures import MEMORY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES
|
||||
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||
from .fixtures import AGENTS_FIXTURES
|
||||
|
||||
|
||||
|
@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "fireworks",
|
||||
"safety": "llama_guard",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="fireworks",
|
||||
marks=pytest.mark.fireworks,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "remote",
|
||||
|
@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
|
||||
|
||||
def pytest_configure(config):
|
||||
for mark in ["meta_reference", "ollama", "together", "remote"]:
|
||||
for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
f"{mark}: marks tests as {mark} specific",
|
||||
|
@ -75,28 +85,30 @@ def pytest_addoption(parser):
|
|||
help="Specify the inference model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-model",
|
||||
"--safety-shield",
|
||||
action="store",
|
||||
default="Llama-Guard-3-8B",
|
||||
help="Specify the safety model to use for testing",
|
||||
help="Specify the safety shield to use for testing",
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
safety_model = metafunc.config.getoption("--safety-model")
|
||||
if "safety_model" in metafunc.fixturenames:
|
||||
shield_id = metafunc.config.getoption("--safety-shield")
|
||||
if "safety_shield" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"safety_model",
|
||||
[pytest.param(safety_model, id="")],
|
||||
"safety_shield",
|
||||
[pytest.param(shield_id, id="")],
|
||||
indirect=True,
|
||||
)
|
||||
if "inference_model" in metafunc.fixturenames:
|
||||
inference_model = metafunc.config.getoption("--inference-model")
|
||||
models = list(set({inference_model, safety_model}))
|
||||
models = set({inference_model})
|
||||
if safety_model := safety_model_from_shield(shield_id):
|
||||
models.add(safety_model)
|
||||
|
||||
metafunc.parametrize(
|
||||
"inference_model",
|
||||
[pytest.param(models, id="")],
|
||||
[pytest.param(list(models), id="")],
|
||||
indirect=True,
|
||||
)
|
||||
if "agents_stack" in metafunc.fixturenames:
|
||||
|
|
|
@ -16,10 +16,9 @@ from llama_stack.providers.inline.agents.meta_reference import (
|
|||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..safety.fixtures import get_shield_to_register
|
||||
|
||||
|
||||
def pick_inference_model(inference_model):
|
||||
|
@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_stack(request, inference_model, safety_model):
|
||||
async def agents_stack(request, inference_model, safety_shield):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
|
@ -71,13 +70,10 @@ async def agents_stack(request, inference_model, safety_model):
|
|||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
shield_input = get_shield_to_register(
|
||||
providers["safety"][0].provider_type, safety_model
|
||||
)
|
||||
inference_models = (
|
||||
inference_model if isinstance(inference_model, list) else [inference_model]
|
||||
)
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||
providers,
|
||||
provider_data,
|
||||
|
@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model):
|
|||
)
|
||||
for model in inference_models
|
||||
],
|
||||
shields=[shield_input],
|
||||
shields=[safety_shield],
|
||||
)
|
||||
return impls[Api.agents], impls[Api.memory]
|
||||
return test_stack
|
||||
|
|
|
@ -1,148 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda environment with the right dependencies installed.
|
||||
# This includes `pytest` and `pytest-asyncio`.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_settings():
|
||||
impls = await resolve_impls_for_test(
|
||||
Api.agents, deps=[Api.inference, Api.memory, Api.safety]
|
||||
)
|
||||
|
||||
return {
|
||||
"impl": impls[Api.agents],
|
||||
"memory_impl": impls[Api.memory],
|
||||
"common_params": {
|
||||
"model": "Llama3.1-8B-Instruct",
|
||||
"instructions": "You are a helpful assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages():
|
||||
return [
|
||||
UserMessage(content="What's the weather like today?"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_agents_and_sessions(agents_settings, sample_messages):
|
||||
agents_impl = agents_settings["impl"]
|
||||
# First, create an agent
|
||||
agent_config = AgentConfig(
|
||||
model=agents_settings["common_params"]["model"],
|
||||
instructions=agents_settings["common_params"]["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
persistence_store = await kvstore_impl(agents_settings["persistence"])
|
||||
|
||||
await agents_impl.delete_agents_session(agent_id, session_id)
|
||||
session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
|
||||
|
||||
await agents_impl.delete_agents(agent_id)
|
||||
agent_response = await persistence_store.get(f"agent:{agent_id}")
|
||||
|
||||
assert session_response is None
|
||||
assert agent_response is None
|
||||
|
||||
|
||||
async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
|
||||
agents_impl = agents_settings["impl"]
|
||||
|
||||
# First, create an agent
|
||||
agent_config = AgentConfig(
|
||||
model=agents_settings["common_params"]["model"],
|
||||
instructions=agents_settings["common_params"]["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
final_event = turn_response[-1].event.payload
|
||||
turn_id = final_event.turn.turn_id
|
||||
persistence_store = await kvstore_impl(SqliteKVStoreConfig())
|
||||
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
|
||||
|
||||
assert isinstance(response, Turn)
|
||||
assert response == final_event.turn
|
||||
assert turn == final_event.turn
|
||||
|
||||
steps = final_event.turn.steps
|
||||
step_id = steps[0].step_id
|
||||
step_response = await agents_impl.get_agents_step(
|
||||
agent_id, session_id, turn_id, step_id
|
||||
)
|
||||
|
||||
assert isinstance(step_response.step, Step)
|
||||
assert step_response.step == steps[0]
|
|
@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import * # noqa: F403
|
|||
# -m "meta_reference"
|
||||
|
||||
from .fixtures import pick_inference_model
|
||||
from .utils import create_agent_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -67,31 +68,19 @@ def query_attachment_messages():
|
|||
]
|
||||
|
||||
|
||||
async def create_agent_session(agents_impl, agent_config):
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
return agent_id, session_id
|
||||
|
||||
|
||||
class TestAgents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_turns_with_safety(
|
||||
self, safety_model, agents_stack, common_params
|
||||
self, safety_shield, agents_stack, common_params
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_impl,
|
||||
AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"input_shields": [safety_model],
|
||||
"output_shields": [safety_model],
|
||||
"input_shields": [safety_shield.shield_id],
|
||||
"output_shields": [safety_shield.shield_id],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
@ -127,7 +116,7 @@ class TestAgents:
|
|||
async def test_create_agent_turn(
|
||||
self, agents_stack, sample_messages, common_params
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_impl, AgentConfig(**common_params)
|
||||
|
@ -158,7 +147,7 @@ class TestAgents:
|
|||
query_attachment_messages,
|
||||
common_params,
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
|
@ -226,7 +215,7 @@ class TestAgents:
|
|||
async def test_create_agent_turn_with_brave_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
|
||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||
|
|
122
llama_stack/providers/tests/agents/test_persistence.py
Normal file
122
llama_stack/providers/tests/agents/test_persistence.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||
from .fixtures import pick_inference_model
|
||||
|
||||
from .utils import create_agent_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages():
|
||||
return [
|
||||
UserMessage(content="What's the weather like today?"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def common_params(inference_model):
|
||||
inference_model = pick_inference_model(inference_model)
|
||||
|
||||
return dict(
|
||||
model=inference_model,
|
||||
instructions="You are a helpful assistant.",
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentPersistence:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_agents_and_sessions(self, agents_stack, common_params):
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_impl,
|
||||
AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"input_shields": [],
|
||||
"output_shields": [],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
run_config = agents_stack.run_config
|
||||
provider_config = run_config.providers["agents"][0].config
|
||||
persistence_store = await kvstore_impl(
|
||||
SqliteKVStoreConfig(**provider_config["persistence_store"])
|
||||
)
|
||||
|
||||
await agents_impl.delete_agents_session(agent_id, session_id)
|
||||
session_response = await persistence_store.get(
|
||||
f"session:{agent_id}:{session_id}"
|
||||
)
|
||||
|
||||
await agents_impl.delete_agents(agent_id)
|
||||
agent_response = await persistence_store.get(f"agent:{agent_id}")
|
||||
|
||||
assert session_response is None
|
||||
assert agent_response is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_turns_and_steps(
|
||||
self, agents_stack, sample_messages, common_params
|
||||
):
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_impl,
|
||||
AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"input_shields": [],
|
||||
"output_shields": [],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
final_event = turn_response[-1].event.payload
|
||||
turn_id = final_event.turn.turn_id
|
||||
|
||||
provider_config = agents_stack.run_config.providers["agents"][0].config
|
||||
persistence_store = await kvstore_impl(
|
||||
SqliteKVStoreConfig(**provider_config["persistence_store"])
|
||||
)
|
||||
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
|
||||
|
||||
assert isinstance(response, Turn)
|
||||
assert response == final_event.turn
|
||||
assert turn == final_event.turn.model_dump_json()
|
||||
|
||||
steps = final_event.turn.steps
|
||||
step_id = steps[0].step_id
|
||||
step_response = await agents_impl.get_agents_step(
|
||||
agent_id, session_id, turn_id, step_id
|
||||
)
|
||||
|
||||
assert step_response.step == steps[0]
|
17
llama_stack/providers/tests/agents/utils.py
Normal file
17
llama_stack/providers/tests/agents/utils.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
async def create_agent_session(agents_impl, agent_config):
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
return agent_id, session_id
|
Loading…
Add table
Add a link
Reference in a new issue