forked from phoenix-oss/llama-stack-mirror
Kill "remote" providers and fix testing with a remote stack properly (#435)
# What does this PR do? This PR kills the notion of "pure passthrough" remote providers. You cannot specify a single provider you must specify a whole distribution (stack) as remote. This PR also significantly fixes / upgrades testing infrastructure so you can now test against a remotely hosted stack server by just doing ```bash pytest -s -v -m remote test_agents.py \ --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \ --env REMOTE_STACK_URL=http://localhost:5001 ``` Also fixed `test_agents_persistence.py` (which was broken) and killed some deprecated testing functions. ## Test Plan All the tests.
This commit is contained in:
parent
59a65e34d3
commit
12947ac19e
28 changed files with 406 additions and 519 deletions
|
@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import * # noqa: F403
|
|||
# -m "meta_reference"
|
||||
|
||||
from .fixtures import pick_inference_model
|
||||
from .utils import create_agent_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -67,31 +68,19 @@ def query_attachment_messages():
|
|||
]
|
||||
|
||||
|
||||
async def create_agent_session(agents_impl, agent_config):
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
return agent_id, session_id
|
||||
|
||||
|
||||
class TestAgents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_turns_with_safety(
|
||||
self, safety_model, agents_stack, common_params
|
||||
self, safety_shield, agents_stack, common_params
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_impl,
|
||||
AgentConfig(
|
||||
**{
|
||||
**common_params,
|
||||
"input_shields": [safety_model],
|
||||
"output_shields": [safety_model],
|
||||
"input_shields": [safety_shield.shield_id],
|
||||
"output_shields": [safety_shield.shield_id],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
@ -127,7 +116,7 @@ class TestAgents:
|
|||
async def test_create_agent_turn(
|
||||
self, agents_stack, sample_messages, common_params
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
|
||||
agent_id, session_id = await create_agent_session(
|
||||
agents_impl, AgentConfig(**common_params)
|
||||
|
@ -158,7 +147,7 @@ class TestAgents:
|
|||
query_attachment_messages,
|
||||
common_params,
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
|
@ -226,7 +215,7 @@ class TestAgents:
|
|||
async def test_create_agent_turn_with_brave_search(
|
||||
self, agents_stack, search_query_messages, common_params
|
||||
):
|
||||
agents_impl, _ = agents_stack
|
||||
agents_impl = agents_stack.impls[Api.agents]
|
||||
|
||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue