mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 10:02:38 +00:00
fixed agent persistence test, more cleanup
This commit is contained in:
parent
4f3b009980
commit
22aedd0277
14 changed files with 202 additions and 310 deletions
|
|
@ -78,10 +78,13 @@ class ProviderWithSpec(Provider):
|
||||||
spec: ProviderSpec
|
spec: ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
|
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
||||||
|
|
||||||
|
|
||||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||||
async def resolve_impls(
|
async def resolve_impls(
|
||||||
run_config: StackRunConfig,
|
run_config: StackRunConfig,
|
||||||
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
|
provider_registry: ProviderRegistry,
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||||
from llama_stack.distribution.store.registry import create_dist_registry
|
from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
@ -94,10 +94,14 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
||||||
|
|
||||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||||
# asked for in the run config.
|
# asked for in the run config.
|
||||||
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
async def construct_stack(
|
||||||
|
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||||
|
) -> Dict[Api, Any]:
|
||||||
dist_registry, _ = await create_dist_registry(
|
dist_registry, _ = await create_dist_registry(
|
||||||
run_config.metadata_store, run_config.image_name
|
run_config.metadata_store, run_config.image_name
|
||||||
)
|
)
|
||||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
impls = await resolve_impls(
|
||||||
|
run_config, provider_registry or get_provider_registry(), dist_registry
|
||||||
|
)
|
||||||
await register_resources(run_config, impls)
|
await register_resources(run_config, impls)
|
||||||
return impls
|
return impls
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from llama_stack.providers.inline.agents.meta_reference import (
|
||||||
MetaReferenceAgentsImplConfig,
|
MetaReferenceAgentsImplConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
@ -73,7 +73,7 @@ async def agents_stack(request, inference_model, safety_shield):
|
||||||
inference_models = (
|
inference_models = (
|
||||||
inference_model if isinstance(inference_model, list) else [inference_model]
|
inference_model if isinstance(inference_model, list) else [inference_model]
|
||||||
)
|
)
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
|
|
@ -85,5 +85,4 @@ async def agents_stack(request, inference_model, safety_shield):
|
||||||
],
|
],
|
||||||
shields=[safety_shield],
|
shields=[safety_shield],
|
||||||
)
|
)
|
||||||
|
return test_stack
|
||||||
return impls[Api.agents], impls[Api.memory]
|
|
||||||
|
|
|
||||||
|
|
@ -1,148 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
|
||||||
from llama_stack.providers.datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
|
||||||
|
|
||||||
# How to run this test:
|
|
||||||
#
|
|
||||||
# 1. Ensure you have a conda environment with the right dependencies installed.
|
|
||||||
# This includes `pytest` and `pytest-asyncio`.
|
|
||||||
#
|
|
||||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
|
||||||
#
|
|
||||||
# 3. Run:
|
|
||||||
#
|
|
||||||
# ```bash
|
|
||||||
# PROVIDER_ID=<your_provider> \
|
|
||||||
# PROVIDER_CONFIG=provider_config.yaml \
|
|
||||||
# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
|
|
||||||
# --tb=short --disable-warnings
|
|
||||||
# ```
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
|
||||||
async def agents_settings():
|
|
||||||
impls = await resolve_impls_for_test(
|
|
||||||
Api.agents, deps=[Api.inference, Api.memory, Api.safety]
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"impl": impls[Api.agents],
|
|
||||||
"memory_impl": impls[Api.memory],
|
|
||||||
"common_params": {
|
|
||||||
"model": "Llama3.1-8B-Instruct",
|
|
||||||
"instructions": "You are a helpful assistant.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_messages():
|
|
||||||
return [
|
|
||||||
UserMessage(content="What's the weather like today?"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_agents_and_sessions(agents_settings, sample_messages):
|
|
||||||
agents_impl = agents_settings["impl"]
|
|
||||||
# First, create an agent
|
|
||||||
agent_config = AgentConfig(
|
|
||||||
model=agents_settings["common_params"]["model"],
|
|
||||||
instructions=agents_settings["common_params"]["instructions"],
|
|
||||||
enable_session_persistence=True,
|
|
||||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
|
||||||
input_shields=[],
|
|
||||||
output_shields=[],
|
|
||||||
tools=[],
|
|
||||||
max_infer_iters=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
create_response = await agents_impl.create_agent(agent_config)
|
|
||||||
agent_id = create_response.agent_id
|
|
||||||
|
|
||||||
# Create a session
|
|
||||||
session_create_response = await agents_impl.create_agent_session(
|
|
||||||
agent_id, "Test Session"
|
|
||||||
)
|
|
||||||
session_id = session_create_response.session_id
|
|
||||||
persistence_store = await kvstore_impl(agents_settings["persistence"])
|
|
||||||
|
|
||||||
await agents_impl.delete_agents_session(agent_id, session_id)
|
|
||||||
session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
|
|
||||||
|
|
||||||
await agents_impl.delete_agents(agent_id)
|
|
||||||
agent_response = await persistence_store.get(f"agent:{agent_id}")
|
|
||||||
|
|
||||||
assert session_response is None
|
|
||||||
assert agent_response is None
|
|
||||||
|
|
||||||
|
|
||||||
async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
|
|
||||||
agents_impl = agents_settings["impl"]
|
|
||||||
|
|
||||||
# First, create an agent
|
|
||||||
agent_config = AgentConfig(
|
|
||||||
model=agents_settings["common_params"]["model"],
|
|
||||||
instructions=agents_settings["common_params"]["instructions"],
|
|
||||||
enable_session_persistence=True,
|
|
||||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
|
||||||
input_shields=[],
|
|
||||||
output_shields=[],
|
|
||||||
tools=[],
|
|
||||||
max_infer_iters=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
create_response = await agents_impl.create_agent(agent_config)
|
|
||||||
agent_id = create_response.agent_id
|
|
||||||
|
|
||||||
# Create a session
|
|
||||||
session_create_response = await agents_impl.create_agent_session(
|
|
||||||
agent_id, "Test Session"
|
|
||||||
)
|
|
||||||
session_id = session_create_response.session_id
|
|
||||||
|
|
||||||
# Create and execute a turn
|
|
||||||
turn_request = dict(
|
|
||||||
agent_id=agent_id,
|
|
||||||
session_id=session_id,
|
|
||||||
messages=sample_messages,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
turn_response = [
|
|
||||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
|
||||||
]
|
|
||||||
|
|
||||||
final_event = turn_response[-1].event.payload
|
|
||||||
turn_id = final_event.turn.turn_id
|
|
||||||
persistence_store = await kvstore_impl(SqliteKVStoreConfig())
|
|
||||||
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
|
||||||
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
|
|
||||||
|
|
||||||
assert isinstance(response, Turn)
|
|
||||||
assert response == final_event.turn
|
|
||||||
assert turn == final_event.turn
|
|
||||||
|
|
||||||
steps = final_event.turn.steps
|
|
||||||
step_id = steps[0].step_id
|
|
||||||
step_response = await agents_impl.get_agents_step(
|
|
||||||
agent_id, session_id, turn_id, step_id
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(step_response.step, Step)
|
|
||||||
assert step_response.step == steps[0]
|
|
||||||
|
|
@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
# -m "meta_reference"
|
# -m "meta_reference"
|
||||||
|
|
||||||
from .fixtures import pick_inference_model
|
from .fixtures import pick_inference_model
|
||||||
|
from .utils import create_agent_session
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -67,24 +68,12 @@ def query_attachment_messages():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def create_agent_session(agents_impl, agent_config):
|
|
||||||
create_response = await agents_impl.create_agent(agent_config)
|
|
||||||
agent_id = create_response.agent_id
|
|
||||||
|
|
||||||
# Create a session
|
|
||||||
session_create_response = await agents_impl.create_agent_session(
|
|
||||||
agent_id, "Test Session"
|
|
||||||
)
|
|
||||||
session_id = session_create_response.session_id
|
|
||||||
return agent_id, session_id
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgents:
|
class TestAgents:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_turns_with_safety(
|
async def test_agent_turns_with_safety(
|
||||||
self, safety_shield, agents_stack, common_params
|
self, safety_shield, agents_stack, common_params
|
||||||
):
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
agent_id, session_id = await create_agent_session(
|
agent_id, session_id = await create_agent_session(
|
||||||
agents_impl,
|
agents_impl,
|
||||||
AgentConfig(
|
AgentConfig(
|
||||||
|
|
@ -127,7 +116,7 @@ class TestAgents:
|
||||||
async def test_create_agent_turn(
|
async def test_create_agent_turn(
|
||||||
self, agents_stack, sample_messages, common_params
|
self, agents_stack, sample_messages, common_params
|
||||||
):
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
|
||||||
agent_id, session_id = await create_agent_session(
|
agent_id, session_id = await create_agent_session(
|
||||||
agents_impl, AgentConfig(**common_params)
|
agents_impl, AgentConfig(**common_params)
|
||||||
|
|
@ -158,7 +147,7 @@ class TestAgents:
|
||||||
query_attachment_messages,
|
query_attachment_messages,
|
||||||
common_params,
|
common_params,
|
||||||
):
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
urls = [
|
urls = [
|
||||||
"memory_optimizations.rst",
|
"memory_optimizations.rst",
|
||||||
"chat.rst",
|
"chat.rst",
|
||||||
|
|
@ -226,7 +215,7 @@ class TestAgents:
|
||||||
async def test_create_agent_turn_with_brave_search(
|
async def test_create_agent_turn_with_brave_search(
|
||||||
self, agents_stack, search_query_messages, common_params
|
self, agents_stack, search_query_messages, common_params
|
||||||
):
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
|
||||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
|
||||||
122
llama_stack/providers/tests/agents/test_persistence.py
Normal file
122
llama_stack/providers/tests/agents/test_persistence.py
Normal file
|
|
@ -0,0 +1,122 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||||
|
from .fixtures import pick_inference_model
|
||||||
|
|
||||||
|
from .utils import create_agent_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_messages():
|
||||||
|
return [
|
||||||
|
UserMessage(content="What's the weather like today?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def common_params(inference_model):
|
||||||
|
inference_model = pick_inference_model(inference_model)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
model=inference_model,
|
||||||
|
instructions="You are a helpful assistant.",
|
||||||
|
enable_session_persistence=True,
|
||||||
|
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||||
|
input_shields=[],
|
||||||
|
output_shields=[],
|
||||||
|
tools=[],
|
||||||
|
max_infer_iters=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentPersistence:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_agents_and_sessions(self, agents_stack, common_params):
|
||||||
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
agent_id, session_id = await create_agent_session(
|
||||||
|
agents_impl,
|
||||||
|
AgentConfig(
|
||||||
|
**{
|
||||||
|
**common_params,
|
||||||
|
"input_shields": [],
|
||||||
|
"output_shields": [],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
run_config = agents_stack.run_config
|
||||||
|
provider_config = run_config.providers["agents"][0].config
|
||||||
|
persistence_store = await kvstore_impl(
|
||||||
|
SqliteKVStoreConfig(**provider_config["persistence_store"])
|
||||||
|
)
|
||||||
|
|
||||||
|
await agents_impl.delete_agents_session(agent_id, session_id)
|
||||||
|
session_response = await persistence_store.get(
|
||||||
|
f"session:{agent_id}:{session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
await agents_impl.delete_agents(agent_id)
|
||||||
|
agent_response = await persistence_store.get(f"agent:{agent_id}")
|
||||||
|
|
||||||
|
assert session_response is None
|
||||||
|
assert agent_response is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent_turns_and_steps(
|
||||||
|
self, agents_stack, sample_messages, common_params
|
||||||
|
):
|
||||||
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
|
||||||
|
agent_id, session_id = await create_agent_session(
|
||||||
|
agents_impl,
|
||||||
|
AgentConfig(
|
||||||
|
**{
|
||||||
|
**common_params,
|
||||||
|
"input_shields": [],
|
||||||
|
"output_shields": [],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and execute a turn
|
||||||
|
turn_request = dict(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=sample_messages,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_response = [
|
||||||
|
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||||
|
]
|
||||||
|
|
||||||
|
final_event = turn_response[-1].event.payload
|
||||||
|
turn_id = final_event.turn.turn_id
|
||||||
|
|
||||||
|
provider_config = agents_stack.run_config.providers["agents"][0].config
|
||||||
|
persistence_store = await kvstore_impl(
|
||||||
|
SqliteKVStoreConfig(**provider_config["persistence_store"])
|
||||||
|
)
|
||||||
|
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||||
|
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
|
||||||
|
|
||||||
|
assert isinstance(response, Turn)
|
||||||
|
assert response == final_event.turn
|
||||||
|
assert turn == final_event.turn.model_dump_json()
|
||||||
|
|
||||||
|
steps = final_event.turn.steps
|
||||||
|
step_id = steps[0].step_id
|
||||||
|
step_response = await agents_impl.get_agents_step(
|
||||||
|
agent_id, session_id, turn_id, step_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert step_response.step == steps[0]
|
||||||
17
llama_stack/providers/tests/agents/utils.py
Normal file
17
llama_stack/providers/tests/agents/utils.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
async def create_agent_session(agents_impl, agent_config):
|
||||||
|
create_response = await agents_impl.create_agent(agent_config)
|
||||||
|
agent_id = create_response.agent_id
|
||||||
|
|
||||||
|
# Create a session
|
||||||
|
session_create_response = await agents_impl.create_agent_session(
|
||||||
|
agent_id, "Test Session"
|
||||||
|
)
|
||||||
|
session_id = session_create_response.session_id
|
||||||
|
return agent_id, session_id
|
||||||
|
|
@ -9,7 +9,7 @@ import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -52,10 +52,10 @@ async def datasetio_stack(request):
|
||||||
fixture_name = request.param
|
fixture_name = request.param
|
||||||
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
|
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.datasetio],
|
[Api.datasetio],
|
||||||
{"datasetio": fixture.providers},
|
{"datasetio": fixture.providers},
|
||||||
fixture.provider_data,
|
fixture.provider_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls[Api.datasetio], impls[Api.datasets]
|
return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets]
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -46,10 +46,10 @@ async def eval_stack(request):
|
||||||
if fixture.provider_data:
|
if fixture.provider_data:
|
||||||
provider_data.update(fixture.provider_data)
|
provider_data.update(fixture.provider_data)
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
|
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls
|
return test_stack.impls
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||||
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
||||||
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
|
@ -182,11 +182,11 @@ INFERENCE_FIXTURES = [
|
||||||
async def inference_stack(request, inference_model):
|
async def inference_stack(request, inference_model):
|
||||||
fixture_name = request.param
|
fixture_name = request.param
|
||||||
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.inference],
|
[Api.inference],
|
||||||
{"inference": inference_fixture.providers},
|
{"inference": inference_fixture.providers},
|
||||||
inference_fixture.provider_data,
|
inference_fixture.provider_data,
|
||||||
models=[ModelInput(model_id=inference_model)],
|
models=[ModelInput(model_id=inference_model)],
|
||||||
)
|
)
|
||||||
|
|
||||||
return (impls[Api.inference], impls[Api.models])
|
return test_stack.impls[Api.inference], test_stack.impls[Api.models]
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConf
|
||||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
|
@ -101,10 +101,10 @@ async def memory_stack(request):
|
||||||
fixture_name = request.param
|
fixture_name = request.param
|
||||||
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.memory],
|
[Api.memory],
|
||||||
{"memory": fixture.providers},
|
{"memory": fixture.providers},
|
||||||
fixture.provider_data,
|
fixture.provider_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls[Api.memory], impls[Api.memory_banks]
|
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
|
||||||
|
|
|
||||||
|
|
@ -5,38 +5,26 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.build import print_pip_install_help
|
from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls, resolve_remote_stack_impls
|
from llama_stack.distribution.resolver import resolve_remote_stack_impls
|
||||||
from llama_stack.distribution.stack import construct_stack
|
from llama_stack.distribution.stack import construct_stack
|
||||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
async def construct_stack_for_test(run_config: StackRunConfig):
|
class TestStack(BaseModel):
|
||||||
remote_config = remote_provider_config(run_config)
|
impls: Dict[Api, Any]
|
||||||
if not remote_config:
|
run_config: StackRunConfig
|
||||||
return await construct_stack(run_config)
|
|
||||||
|
|
||||||
impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
|
|
||||||
|
|
||||||
# we don't register resources for a remote stack as part of the fixture setup
|
|
||||||
# because the stack is already "up". if a test needs to register resources, it
|
|
||||||
# can do so manually always.
|
|
||||||
|
|
||||||
return impls
|
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_for_test_v2(
|
async def construct_stack_for_test(
|
||||||
apis: List[Api],
|
apis: List[Api],
|
||||||
providers: Dict[str, List[Provider]],
|
providers: Dict[str, List[Provider]],
|
||||||
provider_data: Optional[Dict[str, Any]] = None,
|
provider_data: Optional[Dict[str, Any]] = None,
|
||||||
|
|
@ -46,7 +34,7 @@ async def resolve_impls_for_test_v2(
|
||||||
datasets: Optional[List[DatasetInput]] = None,
|
datasets: Optional[List[DatasetInput]] = None,
|
||||||
scoring_fns: Optional[List[ScoringFnInput]] = None,
|
scoring_fns: Optional[List[ScoringFnInput]] = None,
|
||||||
eval_tasks: Optional[List[EvalTaskInput]] = None,
|
eval_tasks: Optional[List[EvalTaskInput]] = None,
|
||||||
):
|
) -> TestStack:
|
||||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||||
run_config = dict(
|
run_config = dict(
|
||||||
built_at=datetime.now(),
|
built_at=datetime.now(),
|
||||||
|
|
@ -63,7 +51,18 @@ async def resolve_impls_for_test_v2(
|
||||||
)
|
)
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||||
try:
|
try:
|
||||||
impls = await construct_stack_for_test(run_config)
|
remote_config = remote_provider_config(run_config)
|
||||||
|
if not remote_config:
|
||||||
|
# TODO: add to provider registry by creating interesting mocks or fakes
|
||||||
|
impls = await construct_stack(run_config, get_provider_registry())
|
||||||
|
else:
|
||||||
|
# we don't register resources for a remote stack as part of the fixture setup
|
||||||
|
# because the stack is already "up". if a test needs to register resources, it
|
||||||
|
# can do so manually always.
|
||||||
|
|
||||||
|
impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
|
||||||
|
|
||||||
|
test_stack = TestStack(impls=impls, run_config=run_config)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
print_pip_install_help(providers)
|
print_pip_install_help(providers)
|
||||||
raise e
|
raise e
|
||||||
|
|
@ -73,7 +72,7 @@ async def resolve_impls_for_test_v2(
|
||||||
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls
|
return test_stack
|
||||||
|
|
||||||
|
|
||||||
def remote_provider_config(
|
def remote_provider_config(
|
||||||
|
|
@ -92,90 +91,3 @@ def remote_provider_config(
|
||||||
assert not has_non_remote, "Remote stack cannot have non-remote providers"
|
assert not has_non_remote, "Remote stack cannot have non-remote providers"
|
||||||
|
|
||||||
return remote_config
|
return remote_config
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
|
||||||
if "PROVIDER_CONFIG" not in os.environ:
|
|
||||||
raise ValueError(
|
|
||||||
"You must set PROVIDER_CONFIG to a YAML file containing provider config"
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(os.environ["PROVIDER_CONFIG"], "r") as f:
|
|
||||||
config_dict = yaml.safe_load(f)
|
|
||||||
|
|
||||||
providers = read_providers(api, config_dict)
|
|
||||||
|
|
||||||
chosen = choose_providers(providers, api, deps)
|
|
||||||
run_config = dict(
|
|
||||||
built_at=datetime.now(),
|
|
||||||
image_name="test-fixture",
|
|
||||||
apis=[api] + (deps or []),
|
|
||||||
providers=chosen,
|
|
||||||
)
|
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
|
||||||
try:
|
|
||||||
impls = await resolve_impls(run_config, get_provider_registry())
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
print_pip_install_help(providers)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
if "provider_data" in config_dict:
|
|
||||||
provider_id = chosen[api.value][0].provider_id
|
|
||||||
provider_data = config_dict["provider_data"].get(provider_id, {})
|
|
||||||
if provider_data:
|
|
||||||
set_request_provider_data(
|
|
||||||
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
|
||||||
)
|
|
||||||
|
|
||||||
return impls
|
|
||||||
|
|
||||||
|
|
||||||
def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
if "providers" not in config_dict:
|
|
||||||
raise ValueError("Config file should contain a `providers` key")
|
|
||||||
|
|
||||||
providers = config_dict["providers"]
|
|
||||||
if isinstance(providers, dict):
|
|
||||||
return providers
|
|
||||||
elif isinstance(providers, list):
|
|
||||||
return {
|
|
||||||
api.value: providers,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Config file should contain a list of providers or dict(api to providers)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def choose_providers(
|
|
||||||
providers: Dict[str, Any], api: Api, deps: List[Api] = None
|
|
||||||
) -> Dict[str, Provider]:
|
|
||||||
chosen = {}
|
|
||||||
if api.value not in providers:
|
|
||||||
raise ValueError(f"No providers found for `{api}`?")
|
|
||||||
chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")]
|
|
||||||
|
|
||||||
for dep in deps or []:
|
|
||||||
if dep.value not in providers:
|
|
||||||
raise ValueError(f"No providers specified for `{dep}` in config?")
|
|
||||||
chosen[dep.value] = [Provider(**x) for x in providers[dep.value]]
|
|
||||||
|
|
||||||
return chosen
|
|
||||||
|
|
||||||
|
|
||||||
def pick_provider(api: Api, providers: List[Any], key: str) -> Provider:
|
|
||||||
providers_by_id = {x["provider_id"]: x for x in providers}
|
|
||||||
if len(providers_by_id) == 0:
|
|
||||||
raise ValueError(f"No providers found for `{api}` in config file")
|
|
||||||
|
|
||||||
if key in os.environ:
|
|
||||||
provider_id = os.environ[key]
|
|
||||||
if provider_id not in providers_by_id:
|
|
||||||
raise ValueError(f"Provider ID {provider_id} not found in config file")
|
|
||||||
provider = providers_by_id[provider_id]
|
|
||||||
else:
|
|
||||||
provider = list(providers_by_id.values())[0]
|
|
||||||
provider_id = provider["provider_id"]
|
|
||||||
print(f"No provider ID specified, picking first `{provider_id}`")
|
|
||||||
|
|
||||||
return Provider(**provider)
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||||
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
||||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
|
@ -102,22 +102,16 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
|
||||||
async def safety_stack(inference_model, safety_shield, request):
|
async def safety_stack(inference_model, safety_shield, request):
|
||||||
# We need an inference + safety fixture to test safety
|
# We need an inference + safety fixture to test safety
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
inference_fixture = request.getfixturevalue(
|
|
||||||
f"inference_{fixture_dict['inference']}"
|
|
||||||
)
|
|
||||||
safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
|
|
||||||
|
|
||||||
providers = {
|
providers = {}
|
||||||
"inference": inference_fixture.providers,
|
|
||||||
"safety": safety_fixture.providers,
|
|
||||||
}
|
|
||||||
provider_data = {}
|
provider_data = {}
|
||||||
if inference_fixture.provider_data:
|
for key in ["inference", "safety"]:
|
||||||
provider_data.update(inference_fixture.provider_data)
|
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||||
if safety_fixture.provider_data:
|
providers[key] = fixture.providers
|
||||||
provider_data.update(safety_fixture.provider_data)
|
if fixture.provider_data:
|
||||||
|
provider_data.update(fixture.provider_data)
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.safety, Api.shields, Api.inference],
|
[Api.safety, Api.shields, Api.inference],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
|
|
@ -125,5 +119,5 @@ async def safety_stack(inference_model, safety_shield, request):
|
||||||
shields=[safety_shield],
|
shields=[safety_shield],
|
||||||
)
|
)
|
||||||
|
|
||||||
shield = await impls[Api.shields].get_shield(safety_shield.shield_id)
|
shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
|
||||||
return impls[Api.safety], impls[Api.shields], shield
|
return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.models import ModelInput
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ async def scoring_stack(request, inference_model):
|
||||||
if fixture.provider_data:
|
if fixture.provider_data:
|
||||||
provider_data.update(fixture.provider_data)
|
provider_data.update(fixture.provider_data)
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.scoring, Api.datasetio, Api.inference],
|
[Api.scoring, Api.datasetio, Api.inference],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
|
|
@ -88,4 +88,4 @@ async def scoring_stack(request, inference_model):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls
|
return test_stack.impls
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue