Kill "remote" providers and fix testing with a remote stack properly (#435)

# What does this PR do?

This PR kills the notion of "pure passthrough" remote providers: you can
no longer specify a single provider as remote; you must specify a whole
distribution (stack) as remote.

This PR also significantly fixes and upgrades the testing infrastructure,
so you can now test against a remotely hosted stack server with a single command:

```bash
pytest -s -v -m remote test_agents.py \
  --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \
  --env REMOTE_STACK_URL=http://localhost:5001
```

Also fixed `test_agents_persistence.py` (which was broken) and killed
some deprecated testing functions.

## Test Plan

All the tests.
Commit 12947ac19e (parent 59a65e34d3) by Ashwin Bharambe, 2024-11-12 21:51:29 -08:00
28 changed files with 406 additions and 519 deletions

**`llama_stack/distribution/client.py`**

```diff
@@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
 _CLIENT_CLASSES = {}
 
 
-async def get_client_impl(
-    protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
-):
-    client_class = create_api_client_class(protocol, additional_protocol)
+async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
+    client_class = create_api_client_class(protocol)
     impl = client_class(config.url)
     await impl.initialize()
     return impl
 
 
-def create_api_client_class(protocol, additional_protocol) -> Type:
+def create_api_client_class(protocol) -> Type:
     if protocol in _CLIENT_CLASSES:
         return _CLIENT_CLASSES[protocol]
 
-    protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
-
     class APIClient:
         def __init__(self, base_url: str):
             print(f"({protocol.__name__}) Connecting to {base_url}")
@@ -42,8 +38,7 @@ def create_api_client_class(protocol) -> Type:
             self.routes = {}
 
             # Store routes for this protocol
-            for p in protocols:
-                for name, method in inspect.getmembers(p):
-                    if hasattr(method, "__webmethod__"):
-                        sig = inspect.signature(method)
-                        self.routes[name] = (method.__webmethod__, sig)
+            for name, method in inspect.getmembers(protocol):
+                if hasattr(method, "__webmethod__"):
+                    sig = inspect.signature(method)
+                    self.routes[name] = (method.__webmethod__, sig)
@@ -160,8 +155,7 @@ def create_api_client_class(protocol) -> Type:
             return ret
 
     # Add protocol methods to the wrapper
-    for p in protocols:
-        for name, method in inspect.getmembers(p):
-            if hasattr(method, "__webmethod__"):
+    for name, method in inspect.getmembers(protocol):
+        if hasattr(method, "__webmethod__"):
 
             async def method_impl(self, *args, method_name=name, **kwargs):
```
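With pure passthrough gone, the client only ever wraps a single protocol. The route-discovery trick it relies on is worth spelling out; here is a minimal, self-contained sketch, assuming a `@webmethod`-style decorator that stamps a `__webmethod__` attribute on each API method (the decorator below is a hypothetical stand-in, not the real one):

```python
import inspect

# Hypothetical stand-in for the real @webmethod decorator: it stamps routing
# metadata onto the function so the client can discover it later.
def webmethod(route: str):
    def wrap(fn):
        fn.__webmethod__ = {"route": route}
        return fn
    return wrap

class ExampleProtocol:
    @webmethod(route="/example/ping")
    async def ping(self, message: str) -> str: ...

# The same discovery loop the PR simplifies down to a single protocol:
routes = {}
for name, method in inspect.getmembers(ExampleProtocol):
    if hasattr(method, "__webmethod__"):
        routes[name] = (method.__webmethod__, inspect.signature(method))

print(routes)  # {'ping': ({'route': '/example/ping'}, <Signature (self, message: str) -> str>)}
```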

**`llama_stack/distribution/distribution.py`**

```diff
@@ -9,7 +9,7 @@ from typing import Dict, List
 from pydantic import BaseModel
 
-from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
+from llama_stack.providers.datatypes import Api, ProviderSpec
 
 
 def stack_apis() -> List[Api]:
@@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
     for api in providable_apis():
         name = api.name.lower()
         module = importlib.import_module(f"llama_stack.providers.registry.{name}")
-        ret[api] = {
-            "remote": remote_provider_spec(api),
-            **{a.provider_type: a for a in module.available_providers()},
-        }
+        ret[api] = {a.provider_type: a for a in module.available_providers()}
     return ret
```

**`llama_stack/distribution/resolver.py`**

```diff
@@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
+from llama_stack.distribution.client import get_client_impl
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -59,12 +60,16 @@ def api_protocol_map() -> Dict[Api, Any]:
 def additional_protocols_map() -> Dict[Api, Any]:
     return {
-        Api.inference: (ModelsProtocolPrivate, Models),
-        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
-        Api.safety: (ShieldsProtocolPrivate, Shields),
-        Api.datasetio: (DatasetsProtocolPrivate, Datasets),
-        Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
-        Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
+        Api.inference: (ModelsProtocolPrivate, Models, Api.models),
+        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
+        Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
+        Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
+        Api.scoring: (
+            ScoringFunctionsProtocolPrivate,
+            ScoringFunctions,
+            Api.scoring_functions,
+        ),
+        Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
     }
@@ -73,10 +78,13 @@ class ProviderWithSpec(Provider):
     spec: ProviderSpec
 
 
+ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
+
+
 # TODO: this code is not very straightforward to follow and needs one more round of refactoring
 async def resolve_impls(
     run_config: StackRunConfig,
-    provider_registry: Dict[Api, Dict[str, ProviderSpec]],
+    provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
 ) -> Dict[Api, Any]:
     """
@@ -273,17 +281,8 @@ async def instantiate_provider(
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider.config)
 
-        if provider_spec.adapter:
-            method = "get_adapter_impl"
-            args = [config, deps]
-        else:
-            method = "get_client_impl"
-            protocol = protocols[provider_spec.api]
-            if provider_spec.api in additional_protocols:
-                _, additional_protocol = additional_protocols[provider_spec.api]
-            else:
-                additional_protocol = None
-            args = [protocol, additional_protocol, config, deps]
+        method = "get_adapter_impl"
+        args = [config, deps]
 
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
@@ -313,7 +312,7 @@ async def instantiate_provider(
         not isinstance(provider_spec, AutoRoutedProviderSpec)
        and provider_spec.api in additional_protocols
     ):
-        additional_api, _ = additional_protocols[provider_spec.api]
+        additional_api, _, _ = additional_protocols[provider_spec.api]
         check_protocol_compliance(impl, additional_api)
 
     return impl
@@ -359,3 +358,29 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
         raise ValueError(
             f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
         )
+
+
+async def resolve_remote_stack_impls(
+    config: RemoteProviderConfig,
+    apis: List[str],
+) -> Dict[Api, Any]:
+    protocols = api_protocol_map()
+    additional_protocols = additional_protocols_map()
+
+    impls = {}
+    for api_str in apis:
+        api = Api(api_str)
+        impls[api] = await get_client_impl(
+            protocols[api],
+            config,
+            {},
+        )
+        if api in additional_protocols:
+            _, additional_protocol, additional_api = additional_protocols[api]
+            impls[additional_api] = await get_client_impl(
+                additional_protocol,
+                config,
+                {},
+            )
+
+    return impls
```
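Consumers get a whole remote stack in one call. A usage sketch, assuming a stack server is already listening on the URL (the diff only shows that `RemoteProviderConfig` exposes `.url`; construct it per your version):

```python
import asyncio

from llama_stack.distribution.resolver import resolve_remote_stack_impls
from llama_stack.providers.datatypes import Api, RemoteProviderConfig

async def main():
    # Point every requested API at the already-running stack server.
    # Assumes a `url` field; adjust if your RemoteProviderConfig differs.
    config = RemoteProviderConfig(url="http://localhost:5001")
    impls = await resolve_remote_stack_impls(config, ["inference", "safety"])

    # Both the primary protocols and their "additional" registration
    # protocols (models, shields, ...) are now backed by remote clients.
    models = await impls[Api.models].list_models()
    print(models)

asyncio.run(main())
```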

**`llama_stack/distribution/routers/routing_tables.py`**

```diff
@@ -33,28 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
     api = get_impl_api(p)
 
-    if obj.provider_id == "remote":
-        # TODO: this is broken right now because we use the generic
-        # { identifier, provider_id, provider_resource_id } tuple here
-        # but the APIs expect things like ModelInput, ShieldInput, etc.
-
-        # if this is just a passthrough, we want to let the remote
-        # end actually do the registration with the correct provider
-        obj = obj.model_copy(deep=True)
-        obj.provider_id = ""
+    assert obj.provider_id != "remote", "Remote provider should not be registered"
 
     if api == Api.inference:
         return await p.register_model(obj)
     elif api == Api.safety:
-        await p.register_shield(obj)
+        return await p.register_shield(obj)
     elif api == Api.memory:
-        await p.register_memory_bank(obj)
+        return await p.register_memory_bank(obj)
     elif api == Api.datasetio:
-        await p.register_dataset(obj)
+        return await p.register_dataset(obj)
     elif api == Api.scoring:
-        await p.register_scoring_function(obj)
+        return await p.register_scoring_function(obj)
     elif api == Api.eval:
-        await p.register_eval_task(obj)
+        return await p.register_eval_task(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
@@ -81,11 +73,6 @@ class CommonRoutingTableImpl(RoutingTable):
         for obj in objs:
             if cls is None:
                 obj.provider_id = provider_id
-            else:
-                if provider_id == "remote":
-                    # if this is just a passthrough, we got the *WithProvider object
-                    # so we should just override the provider in-place
-                    obj.provider_id = provider_id
             else:
                 # Create a copy of the model data and explicitly set provider_id
                 model_data = obj.model_dump()
@@ -100,18 +87,14 @@ class CommonRoutingTableImpl(RoutingTable):
                 p.model_store = self
             elif api == Api.safety:
                 p.shield_store = self
             elif api == Api.memory:
                 p.memory_bank_store = self
             elif api == Api.datasetio:
                 p.dataset_store = self
             elif api == Api.scoring:
                 p.scoring_function_store = self
                 scoring_functions = await p.list_scoring_functions()
                 await add_objects(scoring_functions, pid, ScoringFn)
             elif api == Api.eval:
                 p.eval_task_store = self
```

**`llama_stack/distribution/server/server.py`**

```diff
@@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
         await impl.shutdown()
 
 
-def create_dynamic_passthrough(
-    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
-):
-    async def endpoint(request: Request):
-        return await passthrough(request, downstream_url, downstream_headers)
-
-    return endpoint
-
-
 def is_streaming_request(func_name: str, request: Request, **kwargs):
     # TODO: pass the api method and punt it to the Protocol definition directly
     return kwargs.get("stream", False)
@@ -305,19 +296,10 @@ def main(
         endpoints = all_endpoints[api]
         impl = impls[api]
 
-        if is_passthrough(impl.__provider_spec__):
-            for endpoint in endpoints:
-                url = impl.__provider_config__.url.rstrip("/") + endpoint.route
-                getattr(app, endpoint.method)(endpoint.route)(
-                    create_dynamic_passthrough(url)
-                )
-        else:
-            for endpoint in endpoints:
-                if not hasattr(impl, endpoint.name):
-                    # ideally this should be a typing violation already
-                    raise ValueError(
-                        f"Could not find method {endpoint.name} on {impl}!!"
-                    )
+        for endpoint in endpoints:
+            if not hasattr(impl, endpoint.name):
+                # ideally this should be a typing violation already
+                raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
 
             impl_method = getattr(impl, endpoint.name)
```

**`llama_stack/distribution/stack.py`**

```diff
@@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import *  # noqa: F403
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
-from llama_stack.distribution.resolver import resolve_impls
+from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api
@@ -58,16 +58,7 @@ class LlamaStack(
     pass
 
 
-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
-    dist_registry, _ = await create_dist_registry(
-        run_config.metadata_store, run_config.image_name
-    )
-
-    impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
-
-    resources = [
+RESOURCES = [
     ("models", Api.models, "register_model", "list_models"),
     ("shields", Api.shields, "register_shield", "list_shields"),
     ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
@@ -79,8 +70,11 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
         "list_scoring_functions",
     ),
     ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
 ]
 
-    for rsrc, api, register_method, list_method in resources:
+
+async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
+    for rsrc, api, register_method, list_method in RESOURCES:
         objects = getattr(run_config, rsrc)
         if api not in impls:
             continue
@@ -96,4 +90,18 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
         )
 
     print("")
+
+
+# Produces a stack of providers for the given run config. Not all APIs may be
+# asked for in the run config.
+async def construct_stack(
+    run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
+) -> Dict[Api, Any]:
+    dist_registry, _ = await create_dist_registry(
+        run_config.metadata_store, run_config.image_name
+    )
+    impls = await resolve_impls(
+        run_config, provider_registry or get_provider_registry(), dist_registry
+    )
+    await register_resources(run_config, impls)
     return impls
```
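`construct_stack` now takes an optional provider registry, which is what lets the test resolver (and, per the TODO below, future mocks or fakes) inject behavior. A hedged sketch of what that enables; the override helper here is hypothetical, not part of this PR:

```python
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.stack import construct_stack

async def start_stack_with_overrides(run_config, overrides=None):
    # Start from the real registry, then let callers (e.g. tests) swap in
    # mocks or fakes for individual provider types. Names are illustrative.
    registry = get_provider_registry()
    for api, specs in (overrides or {}).items():
        registry[api].update(specs)
    return await construct_stack(run_config, registry)
```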

**`llama_stack/providers/datatypes.py`**

```diff
@@ -99,6 +99,7 @@ class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
+# TODO: this can now be inlined into RemoteProviderSpec
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_type: str = Field(
@@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: Optional[AdapterSpec] = Field(
-        default=None,
+    adapter: AdapterSpec = Field(
         description="""
 If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here. If not specified, it indicates the remote
-as being "Llama Stack compatible"
+API responses, specify the adapter here.
 """,
     )
@@ -186,38 +185,21 @@ as being "Llama Stack compatible"
     @property
     def module(self) -> str:
-        if self.adapter:
-            return self.adapter.module
-        return "llama_stack.distribution.client"
+        return self.adapter.module
 
     @property
     def pip_packages(self) -> List[str]:
-        if self.adapter:
-            return self.adapter.pip_packages
-        return []
+        return self.adapter.pip_packages
 
     @property
     def provider_data_validator(self) -> Optional[str]:
-        if self.adapter:
-            return self.adapter.provider_data_validator
-        return None
+        return self.adapter.provider_data_validator
 
 
-def is_passthrough(spec: ProviderSpec) -> bool:
-    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
-
-
 # Can avoid this by using Pydantic computed_field
-def remote_provider_spec(
-    api: Api, adapter: Optional[AdapterSpec] = None
-) -> RemoteProviderSpec:
-    config_class = (
-        adapter.config_class
-        if adapter and adapter.config_class
-        else "llama_stack.distribution.datatypes.RemoteProviderConfig"
-    )
-    provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
-
+def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
     return RemoteProviderSpec(
-        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
+        api=api,
+        provider_type=f"remote::{adapter.adapter_type}",
+        config_class=adapter.config_class,
+        adapter=adapter,
     )
```
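Since `adapter` is now required, every remote provider registration goes through `remote_provider_spec` with a fully-specified `AdapterSpec`. For example (the "acme" adapter is hypothetical; the fields mirror the chromadb entry in the memory registry below):

```python
from llama_stack.providers.datatypes import AdapterSpec, Api, remote_provider_spec

# Illustrative registry entry for a hypothetical remote inference adapter.
# config_class is now mandatory information carried by the AdapterSpec.
spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_type="acme",
        pip_packages=["acme-client"],
        module="llama_stack.providers.remote.inference.acme",
        config_class="llama_stack.providers.remote.inference.acme.AcmeImplConfig",
    ),
)
assert spec.provider_type == "remote::acme"
```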

**`llama_stack/providers/inline/safety/llama_guard/llama_guard.py`**

```diff
@@ -234,7 +234,7 @@ class LlamaGuardShield:
         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
         async for chunk in await self.inference_api.chat_completion(
-            model=self.model,
+            model_id=self.model,
             messages=[shield_input_message],
             stream=True,
         ):
```

**`llama_stack/providers/registry/memory.py`**

```diff
@@ -53,6 +53,7 @@ def available_providers() -> List[ProviderSpec]:
             adapter_type="chromadb",
             pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
             module="llama_stack.providers.remote.memory.chroma",
+            config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
         ),
     ),
     remote_provider_spec(
```

**`llama_stack/providers/remote/inference/ollama/ollama.py`**

```diff
@@ -164,7 +164,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
-        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
```

**`llama_stack/providers/tests/agents/conftest.py`**

```diff
@@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
-from ..safety.fixtures import SAFETY_FIXTURES
+from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
 from .fixtures import AGENTS_FIXTURES
@@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="together",
         marks=pytest.mark.together,
     ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "safety": "llama_guard",
+            "memory": "faiss",
+            "agents": "meta_reference",
+        },
+        id="fireworks",
+        marks=pytest.mark.fireworks,
+    ),
     pytest.param(
         {
             "inference": "remote",
@@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
 def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together", "remote"]:
+    for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]:
         config.addinivalue_line(
             "markers",
             f"{mark}: marks tests as {mark} specific",
@@ -75,28 +85,30 @@ def pytest_addoption(parser):
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
-        "--safety-model",
+        "--safety-shield",
         action="store",
         default="Llama-Guard-3-8B",
-        help="Specify the safety model to use for testing",
+        help="Specify the safety shield to use for testing",
     )
 
 
 def pytest_generate_tests(metafunc):
-    safety_model = metafunc.config.getoption("--safety-model")
-    if "safety_model" in metafunc.fixturenames:
+    shield_id = metafunc.config.getoption("--safety-shield")
+    if "safety_shield" in metafunc.fixturenames:
         metafunc.parametrize(
-            "safety_model",
-            [pytest.param(safety_model, id="")],
+            "safety_shield",
+            [pytest.param(shield_id, id="")],
             indirect=True,
         )
     if "inference_model" in metafunc.fixturenames:
         inference_model = metafunc.config.getoption("--inference-model")
-        models = list(set({inference_model, safety_model}))
+        models = set({inference_model})
+        if safety_model := safety_model_from_shield(shield_id):
+            models.add(safety_model)
+
         metafunc.parametrize(
             "inference_model",
-            [pytest.param(models, id="")],
+            [pytest.param(list(models), id="")],
             indirect=True,
         )
     if "agents_stack" in metafunc.fixturenames:
```

**`llama_stack/providers/tests/agents/fixtures.py`**

```diff
@@ -16,10 +16,9 @@ from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )
 
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 from ..conftest import ProviderFixture, remote_stack_fixture
-from ..safety.fixtures import get_shield_to_register
 
 
 def pick_inference_model(inference_model):
@@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
 
 @pytest_asyncio.fixture(scope="session")
-async def agents_stack(request, inference_model, safety_model):
+async def agents_stack(request, inference_model, safety_shield):
     fixture_dict = request.param
 
     providers = {}
@@ -71,13 +70,10 @@ async def agents_stack(request, inference_model, safety_shield):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
-    shield_input = get_shield_to_register(
-        providers["safety"][0].provider_type, safety_model
-    )
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.agents, Api.inference, Api.safety, Api.memory],
         providers,
         provider_data,
@@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_shield):
             )
             for model in inference_models
         ],
-        shields=[shield_input],
+        shields=[safety_shield],
     )
-    return impls[Api.agents], impls[Api.memory]
+    return test_stack
```

**`llama_stack/providers/tests/agents/test_agent_persistence.py`** (deleted)

```diff
@@ -1,148 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
-from llama_stack.providers.datatypes import *  # noqa: F403
-
-from dotenv import load_dotenv
-
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
-
-# How to run this test:
-#
-# 1. Ensure you have a conda environment with the right dependencies installed.
-#    This includes `pytest` and `pytest-asyncio`.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-#  PROVIDER_CONFIG=provider_config.yaml \
-#  pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
-#  --tb=short --disable-warnings
-# ```
-
-load_dotenv()
-
-
-@pytest_asyncio.fixture(scope="session")
-async def agents_settings():
-    impls = await resolve_impls_for_test(
-        Api.agents, deps=[Api.inference, Api.memory, Api.safety]
-    )
-
-    return {
-        "impl": impls[Api.agents],
-        "memory_impl": impls[Api.memory],
-        "common_params": {
-            "model": "Llama3.1-8B-Instruct",
-            "instructions": "You are a helpful assistant.",
-        },
-    }
-
-
-@pytest.fixture
-def sample_messages():
-    return [
-        UserMessage(content="What's the weather like today?"),
-    ]
-
-
-@pytest.mark.asyncio
-async def test_delete_agents_and_sessions(agents_settings, sample_messages):
-    agents_impl = agents_settings["impl"]
-    # First, create an agent
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[],
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-
-    persistence_store = await kvstore_impl(agents_settings["persistence"])
-
-    await agents_impl.delete_agents_session(agent_id, session_id)
-    session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
-
-    await agents_impl.delete_agents(agent_id)
-    agent_response = await persistence_store.get(f"agent:{agent_id}")
-
-    assert session_response is None
-    assert agent_response is None
-
-
-async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
-    agents_impl = agents_settings["impl"]
-
-    # First, create an agent
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[],
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-
-    # Create and execute a turn
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=sample_messages,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-    ]
-
-    final_event = turn_response[-1].event.payload
-    turn_id = final_event.turn.turn_id
-    persistence_store = await kvstore_impl(SqliteKVStoreConfig())
-    turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
-    response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
-
-    assert isinstance(response, Turn)
-    assert response == final_event.turn
-    assert turn == final_event.turn
-
-    steps = final_event.turn.steps
-    step_id = steps[0].step_id
-    step_response = await agents_impl.get_agents_step(
-        agent_id, session_id, turn_id, step_id
-    )
-
-    assert isinstance(step_response.step, Step)
-    assert step_response.step == steps[0]
```

**`llama_stack/providers/tests/agents/test_agents.py`**

```diff
@@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import *  # noqa: F403
 # -m "meta_reference"
 
 from .fixtures import pick_inference_model
+from .utils import create_agent_session
 
 
 @pytest.fixture
@@ -67,31 +68,19 @@ def query_attachment_messages():
     ]
 
 
-async def create_agent_session(agents_impl, agent_config):
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-
-    return agent_id, session_id
-
-
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(
-        self, safety_model, agents_stack, common_params
+        self, safety_shield, agents_stack, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
         agent_id, session_id = await create_agent_session(
             agents_impl,
             AgentConfig(
                 **{
                     **common_params,
-                    "input_shields": [safety_model],
-                    "output_shields": [safety_model],
+                    "input_shields": [safety_shield.shield_id],
+                    "output_shields": [safety_shield.shield_id],
                 }
             ),
         )
@@ -127,7 +116,7 @@ class TestAgents:
     async def test_create_agent_turn(
         self, agents_stack, sample_messages, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
 
         agent_id, session_id = await create_agent_session(
             agents_impl, AgentConfig(**common_params)
@@ -158,7 +147,7 @@ class TestAgents:
         query_attachment_messages,
         common_params,
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
         urls = [
             "memory_optimizations.rst",
             "chat.rst",
@@ -226,7 +215,7 @@ class TestAgents:
     async def test_create_agent_turn_with_brave_search(
         self, agents_stack, search_query_messages, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
 
         if "BRAVE_SEARCH_API_KEY" not in os.environ:
             pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
```

**`llama_stack/providers/tests/agents/test_persistence.py`** (new)

```diff
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.providers.datatypes import *  # noqa: F403
+from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+
+from .fixtures import pick_inference_model
+from .utils import create_agent_session
+
+
+@pytest.fixture
+def sample_messages():
+    return [
+        UserMessage(content="What's the weather like today?"),
+    ]
+
+
+@pytest.fixture
+def common_params(inference_model):
+    inference_model = pick_inference_model(inference_model)
+
+    return dict(
+        model=inference_model,
+        instructions="You are a helpful assistant.",
+        enable_session_persistence=True,
+        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        input_shields=[],
+        output_shields=[],
+        tools=[],
+        max_infer_iters=5,
+    )
+
+
+class TestAgentPersistence:
+    @pytest.mark.asyncio
+    async def test_delete_agents_and_sessions(self, agents_stack, common_params):
+        agents_impl = agents_stack.impls[Api.agents]
+        agent_id, session_id = await create_agent_session(
+            agents_impl,
+            AgentConfig(
+                **{
+                    **common_params,
+                    "input_shields": [],
+                    "output_shields": [],
+                }
+            ),
+        )
+
+        run_config = agents_stack.run_config
+        provider_config = run_config.providers["agents"][0].config
+        persistence_store = await kvstore_impl(
+            SqliteKVStoreConfig(**provider_config["persistence_store"])
+        )
+
+        await agents_impl.delete_agents_session(agent_id, session_id)
+        session_response = await persistence_store.get(
+            f"session:{agent_id}:{session_id}"
+        )
+
+        await agents_impl.delete_agents(agent_id)
+        agent_response = await persistence_store.get(f"agent:{agent_id}")
+
+        assert session_response is None
+        assert agent_response is None
+
+    @pytest.mark.asyncio
+    async def test_get_agent_turns_and_steps(
+        self, agents_stack, sample_messages, common_params
+    ):
+        agents_impl = agents_stack.impls[Api.agents]
+
+        agent_id, session_id = await create_agent_session(
+            agents_impl,
+            AgentConfig(
+                **{
+                    **common_params,
+                    "input_shields": [],
+                    "output_shields": [],
+                }
+            ),
+        )
+
+        # Create and execute a turn
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=sample_messages,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+        ]
+
+        final_event = turn_response[-1].event.payload
+        turn_id = final_event.turn.turn_id
+
+        provider_config = agents_stack.run_config.providers["agents"][0].config
+        persistence_store = await kvstore_impl(
+            SqliteKVStoreConfig(**provider_config["persistence_store"])
+        )
+
+        turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
+        response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
+
+        assert isinstance(response, Turn)
+        assert response == final_event.turn
+        assert turn == final_event.turn.model_dump_json()
+
+        steps = final_event.turn.steps
+        step_id = steps[0].step_id
+        step_response = await agents_impl.get_agents_step(
+            agent_id, session_id, turn_id, step_id
+        )
+
+        assert step_response.step == steps[0]
```

**`llama_stack/providers/tests/agents/utils.py`** (new)

```diff
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+async def create_agent_session(agents_impl, agent_config):
+    create_response = await agents_impl.create_agent(agent_config)
+    agent_id = create_response.agent_id
+
+    # Create a session
+    session_create_response = await agents_impl.create_agent_session(
+        agent_id, "Test Session"
+    )
+    session_id = session_create_response.session_id
+    return agent_id, session_id
```

**`llama_stack/providers/tests/conftest.py`**

```diff
@@ -35,8 +35,8 @@ def remote_stack_fixture() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="remote",
-                provider_type="remote",
+                provider_id="test::remote",
+                provider_type="test::remote",
                 config=config.model_dump(),
             )
         ],
```
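For reference, the fixture's `config` (defined just above this hunk and not shown) presumably carries the URL from `REMOTE_STACK_URL`. A sketch under that assumption; the real code may build the URL from separate host/port fields:

```python
from llama_stack.distribution.datatypes import Provider, RemoteProviderConfig
from llama_stack.providers.tests.env import get_env_or_fail

def build_remote_stack_provider() -> Provider:
    # Assumes RemoteProviderConfig accepts the URL directly; adjust if your
    # version splits it into host and port.
    config = RemoteProviderConfig(url=get_env_or_fail("REMOTE_STACK_URL"))
    return Provider(
        provider_id="test::remote",
        provider_type="test::remote",
        config=config.model_dump(),
    )
```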

**`llama_stack/providers/tests/datasetio/fixtures.py`**

```diff
@@ -9,7 +9,7 @@ import pytest_asyncio
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -52,10 +52,10 @@ async def datasetio_stack(request):
     fixture_name = request.param
     fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
 
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.datasetio],
         {"datasetio": fixture.providers},
         fixture.provider_data,
     )
 
-    return impls[Api.datasetio], impls[Api.datasets]
+    return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets]
```

**`llama_stack/providers/tests/eval/fixtures.py`**

```diff
@@ -9,7 +9,7 @@ import pytest_asyncio
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -46,10 +46,10 @@ async def eval_stack(request):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.eval, Api.datasetio, Api.inference, Api.scoring],
         providers,
         provider_data,
     )
 
-    return impls
+    return test_stack.impls
```

**`llama_stack/providers/tests/inference/fixtures.py`**

```diff
@@ -21,7 +21,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -182,15 +182,11 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-        models=[
-            ModelInput(
-                model_id=inference_model,
-            )
-        ],
+        models=[ModelInput(model_id=inference_model)],
     )
 
-    return (impls[Api.inference], impls[Api.models])
+    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
```

**`llama_stack/providers/tests/inference/test_text_inference.py`**

```diff
@@ -147,9 +147,9 @@ class TestInference:
         user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
         response = await inference_impl.completion(
+            model_id=inference_model,
             content=user_input,
             stream=False,
-            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
```

**`llama_stack/providers/tests/inference/test_vision_inference.py`**

```diff
@@ -55,7 +55,7 @@ class TestVisionModelInference:
         )
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
             messages=[
                 UserMessage(content="You are a helpful assistant."),
                 UserMessage(content=[image, "Describe this image in two sentences."]),
@@ -102,7 +102,7 @@ class TestVisionModelInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=inference_model,
                 messages=[
                     UserMessage(content="You are a helpful assistant."),
                     UserMessage(
```

**`llama_stack/providers/tests/memory/fixtures.py`**

```diff
@@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
 
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -101,10 +101,10 @@ async def memory_stack(request):
     fixture_name = request.param
     fixture = request.getfixturevalue(f"memory_{fixture_name}")
 
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.memory],
         {"memory": fixture.providers},
         fixture.provider_data,
     )
 
-    return impls[Api.memory], impls[Api.memory_banks]
+    return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
```

**`llama_stack/providers/tests/resolver.py`**

```diff
@@ -5,33 +5,36 @@
 # the root directory of this source tree.
 import json
-import os
 import tempfile
 from datetime import datetime
 from typing import Any, Dict, List, Optional
 
-import yaml
-
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.resolver import resolve_remote_stack_impls
 from llama_stack.distribution.stack import construct_stack
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
 
 
-async def resolve_impls_for_test_v2(
+class TestStack(BaseModel):
+    impls: Dict[Api, Any]
+    run_config: StackRunConfig
+
+
+async def construct_stack_for_test(
     apis: List[Api],
     providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
-    models: Optional[List[Model]] = None,
-    shields: Optional[List[Shield]] = None,
-    memory_banks: Optional[List[MemoryBank]] = None,
-    datasets: Optional[List[Dataset]] = None,
-    scoring_fns: Optional[List[ScoringFn]] = None,
-    eval_tasks: Optional[List[EvalTask]] = None,
-):
+    models: Optional[List[ModelInput]] = None,
+    shields: Optional[List[ShieldInput]] = None,
+    memory_banks: Optional[List[MemoryBankInput]] = None,
+    datasets: Optional[List[DatasetInput]] = None,
+    scoring_fns: Optional[List[ScoringFnInput]] = None,
+    eval_tasks: Optional[List[EvalTaskInput]] = None,
+) -> TestStack:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     run_config = dict(
         built_at=datetime.now(),
@@ -48,7 +51,18 @@
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
     try:
-        impls = await construct_stack(run_config)
+        remote_config = remote_provider_config(run_config)
+        if not remote_config:
+            # TODO: add to provider registry by creating interesting mocks or fakes
+            impls = await construct_stack(run_config, get_provider_registry())
+        else:
+            # we don't register resources for a remote stack as part of the fixture setup
+            # because the stack is already "up". if a test needs to register resources, it
+            # can do so manually always.
+            impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
+
+        test_stack = TestStack(impls=impls, run_config=run_config)
     except ModuleNotFoundError as e:
         print_pip_install_help(providers)
         raise e
@@ -58,91 +72,22 @@
             {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
         )
 
-    return impls
-
-
-async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
-    if "PROVIDER_CONFIG" not in os.environ:
-        raise ValueError(
-            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
-        )
-
-    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
-        config_dict = yaml.safe_load(f)
-
-    providers = read_providers(api, config_dict)
-    chosen = choose_providers(providers, api, deps)
-    run_config = dict(
-        built_at=datetime.now(),
-        image_name="test-fixture",
-        apis=[api] + (deps or []),
-        providers=chosen,
-    )
-    run_config = parse_and_maybe_upgrade_config(run_config)
-    try:
-        impls = await resolve_impls(run_config, get_provider_registry())
-    except ModuleNotFoundError as e:
-        print_pip_install_help(providers)
-        raise e
-
-    if "provider_data" in config_dict:
-        provider_id = chosen[api.value][0].provider_id
-        provider_data = config_dict["provider_data"].get(provider_id, {})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
-            )
-
-    return impls
-
-
-def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:
-    if "providers" not in config_dict:
-        raise ValueError("Config file should contain a `providers` key")
-
-    providers = config_dict["providers"]
-    if isinstance(providers, dict):
-        return providers
-    elif isinstance(providers, list):
-        return {
-            api.value: providers,
-        }
-    else:
-        raise ValueError(
-            "Config file should contain a list of providers or dict(api to providers)"
-        )
-
-
-def choose_providers(
-    providers: Dict[str, Any], api: Api, deps: List[Api] = None
-) -> Dict[str, Provider]:
-    chosen = {}
-    if api.value not in providers:
-        raise ValueError(f"No providers found for `{api}`?")
-    chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")]
-
-    for dep in deps or []:
-        if dep.value not in providers:
-            raise ValueError(f"No providers specified for `{dep}` in config?")
-        chosen[dep.value] = [Provider(**x) for x in providers[dep.value]]
-
-    return chosen
-
-
-def pick_provider(api: Api, providers: List[Any], key: str) -> Provider:
-    providers_by_id = {x["provider_id"]: x for x in providers}
-    if len(providers_by_id) == 0:
-        raise ValueError(f"No providers found for `{api}` in config file")
-
-    if key in os.environ:
-        provider_id = os.environ[key]
-        if provider_id not in providers_by_id:
-            raise ValueError(f"Provider ID {provider_id} not found in config file")
-        provider = providers_by_id[provider_id]
-    else:
-        provider = list(providers_by_id.values())[0]
-        provider_id = provider["provider_id"]
-        print(f"No provider ID specified, picking first `{provider_id}`")
-
-    return Provider(**provider)
+    return test_stack
+
+
+def remote_provider_config(
+    run_config: StackRunConfig,
+) -> Optional[RemoteProviderConfig]:
+    remote_config = None
+    has_non_remote = False
+    for api_providers in run_config.providers.values():
+        for provider in api_providers:
+            if provider.provider_type == "test::remote":
+                remote_config = RemoteProviderConfig(**provider.config)
+            else:
+                has_non_remote = True
+
+    if remote_config:
+        assert not has_non_remote, "Remote stack cannot have non-remote providers"
+
+    return remote_config
```
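Call sites now unpack a `TestStack` instead of a raw impls dict. A condensed sketch of the pattern the fixtures in this PR follow (the fixture wiring is elided; `providers` is whatever a `ProviderFixture` supplies):

```python
from llama_stack.apis.models import ModelInput
from llama_stack.providers.datatypes import Api
from llama_stack.providers.tests.resolver import construct_stack_for_test

async def make_inference_stack(providers, provider_data, inference_model):
    # Returns a TestStack, so tests can reach both the resolved impls and the
    # generated run_config (test_persistence.py uses the latter to locate the
    # agents provider's persistence_store config).
    test_stack = await construct_stack_for_test(
        [Api.inference],
        {"inference": providers},
        provider_data,
        models=[ModelInput(model_id=inference_model)],
    )
    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
```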

**`llama_stack/providers/tests/safety/conftest.py`**

```diff
@@ -66,14 +66,14 @@ def pytest_configure(config):
 def pytest_addoption(parser):
     parser.addoption(
-        "--safety-model",
+        "--safety-shield",
         action="store",
         default=None,
-        help="Specify the safety model to use for testing",
+        help="Specify the safety shield to use for testing",
     )
 
 
-SAFETY_MODEL_PARAMS = [
+SAFETY_SHIELD_PARAMS = [
     pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
 ]
@@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc):
     # But a user can also pass in a custom combination via the CLI by doing
     # `--providers inference=together,safety=meta_reference`
 
-    if "safety_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--safety-model")
-        if model:
-            params = [pytest.param(model, id="")]
+    if "safety_shield" in metafunc.fixturenames:
+        shield_id = metafunc.config.getoption("--safety-shield")
+        if shield_id:
+            params = [pytest.param(shield_id, id="")]
         else:
-            params = SAFETY_MODEL_PARAMS
-        for fixture in ["inference_model", "safety_model"]:
+            params = SAFETY_SHIELD_PARAMS
+        for fixture in ["inference_model", "safety_shield"]:
             metafunc.parametrize(
                 fixture,
                 params,
```

**`llama_stack/providers/tests/safety/fixtures.py`**

```diff
@@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
 from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
 from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture:
     return remote_stack_fixture()
 
 
+def safety_model_from_shield(shield_id):
+    if shield_id in ("Bedrock", "CodeScanner", "CodeShield"):
+        return None
+
+    return shield_id
+
+
 @pytest.fixture(scope="session")
-def safety_model(request):
+def safety_shield(request):
     if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--safety-model", None)
+        shield_id = request.param
+    else:
+        shield_id = request.config.getoption("--safety-shield", None)
+
+    if shield_id == "bedrock":
+        shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
+        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
+    else:
+        params = {}
+
+    return ShieldInput(
+        shield_id=shield_id,
+        params=params,
+    )
 
 
 @pytest.fixture(scope="session")
-def safety_llama_guard(safety_model) -> ProviderFixture:
+def safety_llama_guard() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="inline::llama-guard",
+                provider_id="llama-guard",
                 provider_type="inline::llama-guard",
                 config=LlamaGuardConfig().model_dump(),
             )
@@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="inline::prompt-guard",
+                provider_id="prompt-guard",
                 provider_type="inline::prompt-guard",
                 config=PromptGuardConfig().model_dump(),
             )
@@ -80,50 +99,25 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
 @pytest_asyncio.fixture(scope="session")
-async def safety_stack(inference_model, safety_model, request):
+async def safety_stack(inference_model, safety_shield, request):
     # We need an inference + safety fixture to test safety
     fixture_dict = request.param
-    inference_fixture = request.getfixturevalue(
-        f"inference_{fixture_dict['inference']}"
-    )
-    safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
 
-    providers = {
-        "inference": inference_fixture.providers,
-        "safety": safety_fixture.providers,
-    }
+    providers = {}
     provider_data = {}
-    if inference_fixture.provider_data:
-        provider_data.update(inference_fixture.provider_data)
-    if safety_fixture.provider_data:
-        provider_data.update(safety_fixture.provider_data)
+    for key in ["inference", "safety"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
 
-    shield_provider_type = safety_fixture.providers[0].provider_type
-    shield_input = get_shield_to_register(shield_provider_type, safety_model)
-
-    print(f"inference_model: {inference_model}")
-    print(f"shield_input = {shield_input}")
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.safety, Api.shields, Api.inference],
         providers,
         provider_data,
         models=[ModelInput(model_id=inference_model)],
-        shields=[shield_input],
+        shields=[safety_shield],
     )
 
-    shield = await impls[Api.shields].get_shield(shield_input.shield_id)
-    return impls[Api.safety], impls[Api.shields], shield
-
-
-def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
-    if provider_type == "remote::bedrock":
-        identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
-        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
-    else:
-        params = {}
-        identifier = safety_model
-
-    return ShieldInput(
-        shield_id=identifier,
-        params=params,
-    )
+    shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
+    return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield
```

**`llama_stack/providers/tests/safety/test_safety.py`**

```diff
@@ -18,13 +18,6 @@ from llama_stack.distribution.datatypes import *  # noqa: F403
 class TestSafety:
-    @pytest.mark.asyncio
-    async def test_new_shield(self, safety_stack):
-        _, shields_impl, shield = safety_stack
-        assert shield is not None
-        assert shield.provider_resource_id == shield.identifier
-        assert shield.provider_id is not None
-
     @pytest.mark.asyncio
     async def test_shield_list(self, safety_stack):
         _, shields_impl, _ = safety_stack
```

**`llama_stack/providers/tests/scoring/fixtures.py`**

```diff
@@ -11,7 +11,7 @@ from llama_stack.apis.models import ModelInput
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -74,7 +74,7 @@ async def scoring_stack(request, inference_model):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.scoring, Api.datasetio, Api.inference],
         providers,
         provider_data,
@@ -88,4 +88,4 @@ async def scoring_stack(request, inference_model):
         ],
     )
 
-    return impls
+    return test_stack.impls
```