diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index ce788a713..b36ef94e4 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig _CLIENT_CLASSES = {} -async def get_client_impl( - protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any -): - client_class = create_api_client_class(protocol, additional_protocol) +async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any): + client_class = create_api_client_class(protocol) impl = client_class(config.url) await impl.initialize() return impl -def create_api_client_class(protocol, additional_protocol) -> Type: +def create_api_client_class(protocol) -> Type: if protocol in _CLIENT_CLASSES: return _CLIENT_CLASSES[protocol] - protocols = [protocol, additional_protocol] if additional_protocol else [protocol] - class APIClient: def __init__(self, base_url: str): print(f"({protocol.__name__}) Connecting to {base_url}") @@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type: self.routes = {} # Store routes for this protocol - for p in protocols: - for name, method in inspect.getmembers(p): - if hasattr(method, "__webmethod__"): - sig = inspect.signature(method) - self.routes[name] = (method.__webmethod__, sig) + for name, method in inspect.getmembers(protocol): + if hasattr(method, "__webmethod__"): + sig = inspect.signature(method) + self.routes[name] = (method.__webmethod__, sig) async def initialize(self): pass @@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type: return ret # Add protocol methods to the wrapper - for p in protocols: - for name, method in inspect.getmembers(p): - if hasattr(method, "__webmethod__"): + for name, method in inspect.getmembers(protocol): + if hasattr(method, "__webmethod__"): - async def method_impl(self, *args, method_name=name, **kwargs): - return await self.__acall__(method_name, *args, **kwargs) + async def method_impl(self, *args, method_name=name, **kwargs): + return await self.__acall__(method_name, *args, **kwargs) - method_impl.__name__ = name - method_impl.__qualname__ = f"APIClient.{name}" - method_impl.__signature__ = inspect.signature(method) - setattr(APIClient, name, method_impl) + method_impl.__name__ = name + method_impl.__qualname__ = f"APIClient.{name}" + method_impl.__signature__ = inspect.signature(method) + setattr(APIClient, name, method_impl) # Name the class after the protocol APIClient.__name__ = f"{protocol.__name__}Client" diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 3fc3b2d5d..6fc4545c7 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -9,7 +9,7 @@ from typing import Dict, List from pydantic import BaseModel -from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec +from llama_stack.providers.datatypes import Api, ProviderSpec def stack_apis() -> List[Api]: @@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: for api in providable_apis(): name = api.name.lower() module = importlib.import_module(f"llama_stack.providers.registry.{name}") - ret[api] = { - "remote": remote_provider_spec(api), - **{a.provider_type: a for a in module.available_providers()}, - } + ret[api] = {a.provider_type: a for a in module.available_providers()} return ret diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 4e7fa0102..4c74b0d1f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry +from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -59,12 +60,16 @@ def api_protocol_map() -> Dict[Api, Any]: def additional_protocols_map() -> Dict[Api, Any]: return { - Api.inference: (ModelsProtocolPrivate, Models), - Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks), - Api.safety: (ShieldsProtocolPrivate, Shields), - Api.datasetio: (DatasetsProtocolPrivate, Datasets), - Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions), - Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks), + Api.inference: (ModelsProtocolPrivate, Models, Api.models), + Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks), + Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), + Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), + Api.scoring: ( + ScoringFunctionsProtocolPrivate, + ScoringFunctions, + Api.scoring_functions, + ), + Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks), } @@ -73,10 +78,13 @@ class ProviderWithSpec(Provider): spec: ProviderSpec +ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]] + + # TODO: this code is not very straightforward to follow and needs one more round of refactoring async def resolve_impls( run_config: StackRunConfig, - provider_registry: Dict[Api, Dict[str, ProviderSpec]], + provider_registry: ProviderRegistry, dist_registry: DistributionRegistry, ) -> Dict[Api, Any]: """ @@ -273,17 +281,8 @@ async def instantiate_provider( config_type = instantiate_class_type(provider_spec.config_class) config = config_type(**provider.config) - if provider_spec.adapter: - method = "get_adapter_impl" - args = [config, deps] - else: - method = "get_client_impl" - protocol = protocols[provider_spec.api] - if provider_spec.api in additional_protocols: - _, additional_protocol = additional_protocols[provider_spec.api] - else: - additional_protocol = None - args = [protocol, additional_protocol, config, deps] + method = "get_adapter_impl" + args = [config, deps] elif isinstance(provider_spec, AutoRoutedProviderSpec): method = "get_auto_router_impl" @@ -313,7 +312,7 @@ async def instantiate_provider( not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols ): - additional_api, _ = additional_protocols[provider_spec.api] + additional_api, _, _ = additional_protocols[provider_spec.api] check_protocol_compliance(impl, additional_api) return impl @@ -359,3 +358,29 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: raise ValueError( f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}" ) + + +async def resolve_remote_stack_impls( + config: RemoteProviderConfig, + apis: List[str], +) -> Dict[Api, Any]: + protocols = api_protocol_map() + additional_protocols = additional_protocols_map() + + impls = {} + for api_str in apis: + api = Api(api_str) + impls[api] = await get_client_impl( + protocols[api], + config, + {}, + ) + if api in additional_protocols: + _, additional_protocol, additional_api = additional_protocols[api] + impls[additional_api] = await get_client_impl( + additional_protocol, + config, + {}, + ) + + return impls diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 249d3a144..5342728b1 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -33,28 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable api = get_impl_api(p) - if obj.provider_id == "remote": - # TODO: this is broken right now because we use the generic - # { identifier, provider_id, provider_resource_id } tuple here - # but the APIs expect things like ModelInput, ShieldInput, etc. - - # if this is just a passthrough, we want to let the remote - # end actually do the registration with the correct provider - obj = obj.model_copy(deep=True) - obj.provider_id = "" + assert obj.provider_id != "remote", "Remote provider should not be registered" if api == Api.inference: return await p.register_model(obj) elif api == Api.safety: - await p.register_shield(obj) + return await p.register_shield(obj) elif api == Api.memory: - await p.register_memory_bank(obj) + return await p.register_memory_bank(obj) elif api == Api.datasetio: - await p.register_dataset(obj) + return await p.register_dataset(obj) elif api == Api.scoring: - await p.register_scoring_function(obj) + return await p.register_scoring_function(obj) elif api == Api.eval: - await p.register_eval_task(obj) + return await p.register_eval_task(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -82,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable): if cls is None: obj.provider_id = provider_id else: - if provider_id == "remote": - # if this is just a passthrough, we got the *WithProvider object - # so we should just override the provider in-place - obj.provider_id = provider_id - else: - # Create a copy of the model data and explicitly set provider_id - model_data = obj.model_dump() - model_data["provider_id"] = provider_id - obj = cls(**model_data) + # Create a copy of the model data and explicitly set provider_id + model_data = obj.model_dump() + model_data["provider_id"] = provider_id + obj = cls(**model_data) await self.dist_registry.register(obj) # Register all objects from providers @@ -100,18 +87,14 @@ class CommonRoutingTableImpl(RoutingTable): p.model_store = self elif api == Api.safety: p.shield_store = self - elif api == Api.memory: p.memory_bank_store = self - elif api == Api.datasetio: p.dataset_store = self - elif api == Api.scoring: p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() await add_objects(scoring_functions, pid, ScoringFn) - elif api == Api.eval: p.eval_task_store = self diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index bb57e2cc8..05927eef5 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -182,15 +182,6 @@ async def lifespan(app: FastAPI): await impl.shutdown() -def create_dynamic_passthrough( - downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None -): - async def endpoint(request: Request): - return await passthrough(request, downstream_url, downstream_headers) - - return endpoint - - def is_streaming_request(func_name: str, request: Request, **kwargs): # TODO: pass the api method and punt it to the Protocol definition directly return kwargs.get("stream", False) @@ -305,28 +296,19 @@ def main( endpoints = all_endpoints[api] impl = impls[api] - if is_passthrough(impl.__provider_spec__): - for endpoint in endpoints: - url = impl.__provider_config__.url.rstrip("/") + endpoint.route - getattr(app, endpoint.method)(endpoint.route)( - create_dynamic_passthrough(url) - ) - else: - for endpoint in endpoints: - if not hasattr(impl, endpoint.name): - # ideally this should be a typing violation already - raise ValueError( - f"Could not find method {endpoint.name} on {impl}!!" - ) + for endpoint in endpoints: + if not hasattr(impl, endpoint.name): + # ideally this should be a typing violation already + raise ValueError(f"Could not find method {endpoint.name} on {impl}!!") - impl_method = getattr(impl, endpoint.name) + impl_method = getattr(impl, endpoint.name) - getattr(app, endpoint.method)(endpoint.route, response_model=None)( - create_dynamic_typed_route( - impl_method, - endpoint.method, - ) + getattr(app, endpoint.method)(endpoint.route, response_model=None)( + create_dynamic_typed_route( + impl_method, + endpoint.method, ) + ) cprint(f"Serving API {api_str}", "white", attrs=["bold"]) for endpoint in endpoints: diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 1c7325eee..1cffd7749 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import * # noqa: F403 from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.datatypes import Api @@ -58,29 +58,23 @@ class LlamaStack( pass -# Produces a stack of providers for the given run config. Not all APIs may be -# asked for in the run config. -async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]: - dist_registry, _ = await create_dist_registry( - run_config.metadata_store, run_config.image_name - ) +RESOURCES = [ + ("models", Api.models, "register_model", "list_models"), + ("shields", Api.shields, "register_shield", "list_shields"), + ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"), + ("datasets", Api.datasets, "register_dataset", "list_datasets"), + ( + "scoring_fns", + Api.scoring_functions, + "register_scoring_function", + "list_scoring_functions", + ), + ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"), +] - impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) - resources = [ - ("models", Api.models, "register_model", "list_models"), - ("shields", Api.shields, "register_shield", "list_shields"), - ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"), - ("datasets", Api.datasets, "register_dataset", "list_datasets"), - ( - "scoring_fns", - Api.scoring_functions, - "register_scoring_function", - "list_scoring_functions", - ), - ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"), - ] - for rsrc, api, register_method, list_method in resources: +async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): + for rsrc, api, register_method, list_method in RESOURCES: objects = getattr(run_config, rsrc) if api not in impls: continue @@ -96,4 +90,18 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]: ) print("") + + +# Produces a stack of providers for the given run config. Not all APIs may be +# asked for in the run config. +async def construct_stack( + run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None +) -> Dict[Api, Any]: + dist_registry, _ = await create_dist_registry( + run_config.metadata_store, run_config.image_name + ) + impls = await resolve_impls( + run_config, provider_registry or get_provider_registry(), dist_registry + ) + await register_resources(run_config, impls) return impls diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 5a259ae2d..51ff163ab 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -99,6 +99,7 @@ class RoutingTable(Protocol): def get_provider_impl(self, routing_key: str) -> Any: ... +# TODO: this can now be inlined into RemoteProviderSpec @json_schema_type class AdapterSpec(BaseModel): adapter_type: str = Field( @@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel): @json_schema_type class RemoteProviderSpec(ProviderSpec): - adapter: Optional[AdapterSpec] = Field( - default=None, + adapter: AdapterSpec = Field( description=""" If some code is needed to convert the remote responses into Llama Stack compatible -API responses, specify the adapter here. If not specified, it indicates the remote -as being "Llama Stack compatible" +API responses, specify the adapter here. """, ) @@ -186,38 +185,21 @@ as being "Llama Stack compatible" @property def module(self) -> str: - if self.adapter: - return self.adapter.module - return "llama_stack.distribution.client" + return self.adapter.module @property def pip_packages(self) -> List[str]: - if self.adapter: - return self.adapter.pip_packages - return [] + return self.adapter.pip_packages @property def provider_data_validator(self) -> Optional[str]: - if self.adapter: - return self.adapter.provider_data_validator - return None + return self.adapter.provider_data_validator -def is_passthrough(spec: ProviderSpec) -> bool: - return isinstance(spec, RemoteProviderSpec) and spec.adapter is None - - -# Can avoid this by using Pydantic computed_field -def remote_provider_spec( - api: Api, adapter: Optional[AdapterSpec] = None -) -> RemoteProviderSpec: - config_class = ( - adapter.config_class - if adapter and adapter.config_class - else "llama_stack.distribution.datatypes.RemoteProviderConfig" - ) - provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote" - +def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec: return RemoteProviderSpec( - api=api, provider_type=provider_type, config_class=config_class, adapter=adapter + api=api, + provider_type=f"remote::{adapter.adapter_type}", + config_class=adapter.config_class, + adapter=adapter, ) diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 494c1b43e..9950064a4 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -234,7 +234,7 @@ class LlamaGuardShield: # TODO: llama-stack inference protocol has issues with non-streaming inference code content = "" async for chunk in await self.inference_api.chat_completion( - model=self.model, + model_id=self.model, messages=[shield_input_message], stream=True, ): diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index 0b98f3368..ff0926108 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -53,6 +53,7 @@ def available_providers() -> List[ProviderSpec]: adapter_type="chromadb", pip_packages=EMBEDDING_DEPS + ["chromadb-client"], module="llama_stack.providers.remote.memory.chroma", + config_class="llama_stack.distribution.datatypes.RemoteProviderConfig", ), ), remote_provider_spec( diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 99f74572e..3a32125b2 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -164,7 +164,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) - print(f"model={model}") request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index aa3910b39..6ce7913d7 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES -from ..safety.fixtures import SAFETY_FIXTURES +from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield from .fixtures import AGENTS_FIXTURES @@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="together", marks=pytest.mark.together, ), + pytest.param( + { + "inference": "fireworks", + "safety": "llama_guard", + "memory": "faiss", + "agents": "meta_reference", + }, + id="fireworks", + marks=pytest.mark.fireworks, + ), pytest.param( { "inference": "remote", @@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [ def pytest_configure(config): - for mark in ["meta_reference", "ollama", "together", "remote"]: + for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]: config.addinivalue_line( "markers", f"{mark}: marks tests as {mark} specific", @@ -75,28 +85,30 @@ def pytest_addoption(parser): help="Specify the inference model to use for testing", ) parser.addoption( - "--safety-model", + "--safety-shield", action="store", default="Llama-Guard-3-8B", - help="Specify the safety model to use for testing", + help="Specify the safety shield to use for testing", ) def pytest_generate_tests(metafunc): - safety_model = metafunc.config.getoption("--safety-model") - if "safety_model" in metafunc.fixturenames: + shield_id = metafunc.config.getoption("--safety-shield") + if "safety_shield" in metafunc.fixturenames: metafunc.parametrize( - "safety_model", - [pytest.param(safety_model, id="")], + "safety_shield", + [pytest.param(shield_id, id="")], indirect=True, ) if "inference_model" in metafunc.fixturenames: inference_model = metafunc.config.getoption("--inference-model") - models = list(set({inference_model, safety_model})) + models = set({inference_model}) + if safety_model := safety_model_from_shield(shield_id): + models.add(safety_model) metafunc.parametrize( "inference_model", - [pytest.param(models, id="")], + [pytest.param(list(models), id="")], indirect=True, ) if "agents_stack" in metafunc.fixturenames: diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index db157174f..1f89b909a 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -16,10 +16,9 @@ from llama_stack.providers.inline.agents.meta_reference import ( MetaReferenceAgentsImplConfig, ) -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture -from ..safety.fixtures import get_shield_to_register def pick_inference_model(inference_model): @@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"] @pytest_asyncio.fixture(scope="session") -async def agents_stack(request, inference_model, safety_model): +async def agents_stack(request, inference_model, safety_shield): fixture_dict = request.param providers = {} @@ -71,13 +70,10 @@ async def agents_stack(request, inference_model, safety_model): if fixture.provider_data: provider_data.update(fixture.provider_data) - shield_input = get_shield_to_register( - providers["safety"][0].provider_type, safety_model - ) inference_models = ( inference_model if isinstance(inference_model, list) else [inference_model] ) - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.agents, Api.inference, Api.safety, Api.memory], providers, provider_data, @@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model): ) for model in inference_models ], - shields=[shield_input], + shields=[safety_shield], ) - return impls[Api.agents], impls[Api.memory] + return test_stack diff --git a/llama_stack/providers/tests/agents/test_agent_persistence.py b/llama_stack/providers/tests/agents/test_agent_persistence.py deleted file mode 100644 index a15887b33..000000000 --- a/llama_stack/providers/tests/agents/test_agent_persistence.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest -import pytest_asyncio - -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.providers.tests.resolver import resolve_impls_for_test -from llama_stack.providers.datatypes import * # noqa: F403 - -from dotenv import load_dotenv - -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig - -# How to run this test: -# -# 1. Ensure you have a conda environment with the right dependencies installed. -# This includes `pytest` and `pytest-asyncio`. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \ -# --tb=short --disable-warnings -# ``` - -load_dotenv() - - -@pytest_asyncio.fixture(scope="session") -async def agents_settings(): - impls = await resolve_impls_for_test( - Api.agents, deps=[Api.inference, Api.memory, Api.safety] - ) - - return { - "impl": impls[Api.agents], - "memory_impl": impls[Api.memory], - "common_params": { - "model": "Llama3.1-8B-Instruct", - "instructions": "You are a helpful assistant.", - }, - } - - -@pytest.fixture -def sample_messages(): - return [ - UserMessage(content="What's the weather like today?"), - ] - - -@pytest.mark.asyncio -async def test_delete_agents_and_sessions(agents_settings, sample_messages): - agents_impl = agents_settings["impl"] - # First, create an agent - agent_config = AgentConfig( - model=agents_settings["common_params"]["model"], - instructions=agents_settings["common_params"]["instructions"], - enable_session_persistence=True, - sampling_params=SamplingParams(temperature=0.7, top_p=0.95), - input_shields=[], - output_shields=[], - tools=[], - max_infer_iters=5, - ) - - create_response = await agents_impl.create_agent(agent_config) - agent_id = create_response.agent_id - - # Create a session - session_create_response = await agents_impl.create_agent_session( - agent_id, "Test Session" - ) - session_id = session_create_response.session_id - persistence_store = await kvstore_impl(agents_settings["persistence"]) - - await agents_impl.delete_agents_session(agent_id, session_id) - session_response = await persistence_store.get(f"session:{agent_id}:{session_id}") - - await agents_impl.delete_agents(agent_id) - agent_response = await persistence_store.get(f"agent:{agent_id}") - - assert session_response is None - assert agent_response is None - - -async def test_get_agent_turns_and_steps(agents_settings, sample_messages): - agents_impl = agents_settings["impl"] - - # First, create an agent - agent_config = AgentConfig( - model=agents_settings["common_params"]["model"], - instructions=agents_settings["common_params"]["instructions"], - enable_session_persistence=True, - sampling_params=SamplingParams(temperature=0.7, top_p=0.95), - input_shields=[], - output_shields=[], - tools=[], - max_infer_iters=5, - ) - - create_response = await agents_impl.create_agent(agent_config) - agent_id = create_response.agent_id - - # Create a session - session_create_response = await agents_impl.create_agent_session( - agent_id, "Test Session" - ) - session_id = session_create_response.session_id - - # Create and execute a turn - turn_request = dict( - agent_id=agent_id, - session_id=session_id, - messages=sample_messages, - stream=True, - ) - - turn_response = [ - chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) - ] - - final_event = turn_response[-1].event.payload - turn_id = final_event.turn.turn_id - persistence_store = await kvstore_impl(SqliteKVStoreConfig()) - turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") - response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id) - - assert isinstance(response, Turn) - assert response == final_event.turn - assert turn == final_event.turn - - steps = final_event.turn.steps - step_id = steps[0].step_id - step_response = await agents_impl.get_agents_step( - agent_id, session_id, turn_id, step_id - ) - - assert isinstance(step_response.step, Step) - assert step_response.step == steps[0] diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 47e5a751f..60c047058 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import * # noqa: F403 # -m "meta_reference" from .fixtures import pick_inference_model +from .utils import create_agent_session @pytest.fixture @@ -67,31 +68,19 @@ def query_attachment_messages(): ] -async def create_agent_session(agents_impl, agent_config): - create_response = await agents_impl.create_agent(agent_config) - agent_id = create_response.agent_id - - # Create a session - session_create_response = await agents_impl.create_agent_session( - agent_id, "Test Session" - ) - session_id = session_create_response.session_id - return agent_id, session_id - - class TestAgents: @pytest.mark.asyncio async def test_agent_turns_with_safety( - self, safety_model, agents_stack, common_params + self, safety_shield, agents_stack, common_params ): - agents_impl, _ = agents_stack + agents_impl = agents_stack.impls[Api.agents] agent_id, session_id = await create_agent_session( agents_impl, AgentConfig( **{ **common_params, - "input_shields": [safety_model], - "output_shields": [safety_model], + "input_shields": [safety_shield.shield_id], + "output_shields": [safety_shield.shield_id], } ), ) @@ -127,7 +116,7 @@ class TestAgents: async def test_create_agent_turn( self, agents_stack, sample_messages, common_params ): - agents_impl, _ = agents_stack + agents_impl = agents_stack.impls[Api.agents] agent_id, session_id = await create_agent_session( agents_impl, AgentConfig(**common_params) @@ -158,7 +147,7 @@ class TestAgents: query_attachment_messages, common_params, ): - agents_impl, _ = agents_stack + agents_impl = agents_stack.impls[Api.agents] urls = [ "memory_optimizations.rst", "chat.rst", @@ -226,7 +215,7 @@ class TestAgents: async def test_create_agent_turn_with_brave_search( self, agents_stack, search_query_messages, common_params ): - agents_impl, _ = agents_stack + agents_impl = agents_stack.impls[Api.agents] if "BRAVE_SEARCH_API_KEY" not in os.environ: pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") diff --git a/llama_stack/providers/tests/agents/test_persistence.py b/llama_stack/providers/tests/agents/test_persistence.py new file mode 100644 index 000000000..97094cd7a --- /dev/null +++ b/llama_stack/providers/tests/agents/test_persistence.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.providers.datatypes import * # noqa: F403 + +from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig +from .fixtures import pick_inference_model + +from .utils import create_agent_session + + +@pytest.fixture +def sample_messages(): + return [ + UserMessage(content="What's the weather like today?"), + ] + + +@pytest.fixture +def common_params(inference_model): + inference_model = pick_inference_model(inference_model) + + return dict( + model=inference_model, + instructions="You are a helpful assistant.", + enable_session_persistence=True, + sampling_params=SamplingParams(temperature=0.7, top_p=0.95), + input_shields=[], + output_shields=[], + tools=[], + max_infer_iters=5, + ) + + +class TestAgentPersistence: + @pytest.mark.asyncio + async def test_delete_agents_and_sessions(self, agents_stack, common_params): + agents_impl = agents_stack.impls[Api.agents] + agent_id, session_id = await create_agent_session( + agents_impl, + AgentConfig( + **{ + **common_params, + "input_shields": [], + "output_shields": [], + } + ), + ) + + run_config = agents_stack.run_config + provider_config = run_config.providers["agents"][0].config + persistence_store = await kvstore_impl( + SqliteKVStoreConfig(**provider_config["persistence_store"]) + ) + + await agents_impl.delete_agents_session(agent_id, session_id) + session_response = await persistence_store.get( + f"session:{agent_id}:{session_id}" + ) + + await agents_impl.delete_agents(agent_id) + agent_response = await persistence_store.get(f"agent:{agent_id}") + + assert session_response is None + assert agent_response is None + + @pytest.mark.asyncio + async def test_get_agent_turns_and_steps( + self, agents_stack, sample_messages, common_params + ): + agents_impl = agents_stack.impls[Api.agents] + + agent_id, session_id = await create_agent_session( + agents_impl, + AgentConfig( + **{ + **common_params, + "input_shields": [], + "output_shields": [], + } + ), + ) + + # Create and execute a turn + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=sample_messages, + stream=True, + ) + + turn_response = [ + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) + ] + + final_event = turn_response[-1].event.payload + turn_id = final_event.turn.turn_id + + provider_config = agents_stack.run_config.providers["agents"][0].config + persistence_store = await kvstore_impl( + SqliteKVStoreConfig(**provider_config["persistence_store"]) + ) + turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") + response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id) + + assert isinstance(response, Turn) + assert response == final_event.turn + assert turn == final_event.turn.model_dump_json() + + steps = final_event.turn.steps + step_id = steps[0].step_id + step_response = await agents_impl.get_agents_step( + agent_id, session_id, turn_id, step_id + ) + + assert step_response.step == steps[0] diff --git a/llama_stack/providers/tests/agents/utils.py b/llama_stack/providers/tests/agents/utils.py new file mode 100644 index 000000000..048877991 --- /dev/null +++ b/llama_stack/providers/tests/agents/utils.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +async def create_agent_session(agents_impl, agent_config): + create_response = await agents_impl.create_agent(agent_config) + agent_id = create_response.agent_id + + # Create a session + session_create_response = await agents_impl.create_agent_session( + agent_id, "Test Session" + ) + session_id = session_create_response.session_id + return agent_id, session_id diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 3bec2d11d..8b73500d0 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -35,8 +35,8 @@ def remote_stack_fixture() -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="remote", - provider_type="remote", + provider_id="test::remote", + provider_type="test::remote", config=config.model_dump(), ) ], diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py index 6f20bf96a..60f89de46 100644 --- a/llama_stack/providers/tests/datasetio/fixtures.py +++ b/llama_stack/providers/tests/datasetio/fixtures.py @@ -9,7 +9,7 @@ import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture @@ -52,10 +52,10 @@ async def datasetio_stack(request): fixture_name = request.param fixture = request.getfixturevalue(f"datasetio_{fixture_name}") - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.datasetio], {"datasetio": fixture.providers}, fixture.provider_data, ) - return impls[Api.datasetio], impls[Api.datasets] + return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets] diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py index 4a359213b..a6b404d0c 100644 --- a/llama_stack/providers/tests/eval/fixtures.py +++ b/llama_stack/providers/tests/eval/fixtures.py @@ -9,7 +9,7 @@ import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture @@ -46,10 +46,10 @@ async def eval_stack(request): if fixture.provider_data: provider_data.update(fixture.provider_data) - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.eval, Api.datasetio, Api.inference, Api.scoring], providers, provider_data, ) - return impls + return test_stack.impls diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index f6f2a30e8..a53ddf639 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -21,7 +21,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -182,15 +182,11 @@ INFERENCE_FIXTURES = [ async def inference_stack(request, inference_model): fixture_name = request.param inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.inference], {"inference": inference_fixture.providers}, inference_fixture.provider_data, - models=[ - ModelInput( - model_id=inference_model, - ) - ], + models=[ModelInput(model_id=inference_model)], ) - return (impls[Api.inference], impls[Api.models]) + return test_stack.impls[Api.inference], test_stack.impls[Api.models] diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 70047a61f..7b7aca5bd 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -147,9 +147,9 @@ class TestInference: user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003." response = await inference_impl.completion( + model_id=inference_model, content=user_input, stream=False, - model=inference_model, sampling_params=SamplingParams( max_tokens=50, ), diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index 3e785b757..c5db04cca 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -55,7 +55,7 @@ class TestVisionModelInference: ) response = await inference_impl.chat_completion( - model=inference_model, + model_id=inference_model, messages=[ UserMessage(content="You are a helpful assistant."), UserMessage(content=[image, "Describe this image in two sentences."]), @@ -102,7 +102,7 @@ class TestVisionModelInference: response = [ r async for r in await inference_impl.chat_completion( - model=inference_model, + model_id=inference_model, messages=[ UserMessage(content="You are a helpful assistant."), UserMessage( diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 456e354b2..c9559b61c 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConf from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -101,10 +101,10 @@ async def memory_stack(request): fixture_name = request.param fixture = request.getfixturevalue(f"memory_{fixture_name}") - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.memory], {"memory": fixture.providers}, fixture.provider_data, ) - return impls[Api.memory], impls[Api.memory_banks] + return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 1353fc71b..df927926e 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -5,33 +5,36 @@ # the root directory of this source tree. import json -import os import tempfile from datetime import datetime from typing import Any, Dict, List, Optional -import yaml - from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data +from llama_stack.distribution.resolver import resolve_remote_stack_impls from llama_stack.distribution.stack import construct_stack from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig -async def resolve_impls_for_test_v2( +class TestStack(BaseModel): + impls: Dict[Api, Any] + run_config: StackRunConfig + + +async def construct_stack_for_test( apis: List[Api], providers: Dict[str, List[Provider]], provider_data: Optional[Dict[str, Any]] = None, - models: Optional[List[Model]] = None, - shields: Optional[List[Shield]] = None, - memory_banks: Optional[List[MemoryBank]] = None, - datasets: Optional[List[Dataset]] = None, - scoring_fns: Optional[List[ScoringFn]] = None, - eval_tasks: Optional[List[EvalTask]] = None, -): + models: Optional[List[ModelInput]] = None, + shields: Optional[List[ShieldInput]] = None, + memory_banks: Optional[List[MemoryBankInput]] = None, + datasets: Optional[List[DatasetInput]] = None, + scoring_fns: Optional[List[ScoringFnInput]] = None, + eval_tasks: Optional[List[EvalTaskInput]] = None, +) -> TestStack: sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") run_config = dict( built_at=datetime.now(), @@ -48,7 +51,18 @@ async def resolve_impls_for_test_v2( ) run_config = parse_and_maybe_upgrade_config(run_config) try: - impls = await construct_stack(run_config) + remote_config = remote_provider_config(run_config) + if not remote_config: + # TODO: add to provider registry by creating interesting mocks or fakes + impls = await construct_stack(run_config, get_provider_registry()) + else: + # we don't register resources for a remote stack as part of the fixture setup + # because the stack is already "up". if a test needs to register resources, it + # can do so manually always. + + impls = await resolve_remote_stack_impls(remote_config, run_config.apis) + + test_stack = TestStack(impls=impls, run_config=run_config) except ModuleNotFoundError as e: print_pip_install_help(providers) raise e @@ -58,91 +72,22 @@ async def resolve_impls_for_test_v2( {"X-LlamaStack-ProviderData": json.dumps(provider_data)} ) - return impls + return test_stack -async def resolve_impls_for_test(api: Api, deps: List[Api] = None): - if "PROVIDER_CONFIG" not in os.environ: - raise ValueError( - "You must set PROVIDER_CONFIG to a YAML file containing provider config" - ) +def remote_provider_config( + run_config: StackRunConfig, +) -> Optional[RemoteProviderConfig]: + remote_config = None + has_non_remote = False + for api_providers in run_config.providers.values(): + for provider in api_providers: + if provider.provider_type == "test::remote": + remote_config = RemoteProviderConfig(**provider.config) + else: + has_non_remote = True - with open(os.environ["PROVIDER_CONFIG"], "r") as f: - config_dict = yaml.safe_load(f) + if remote_config: + assert not has_non_remote, "Remote stack cannot have non-remote providers" - providers = read_providers(api, config_dict) - - chosen = choose_providers(providers, api, deps) - run_config = dict( - built_at=datetime.now(), - image_name="test-fixture", - apis=[api] + (deps or []), - providers=chosen, - ) - run_config = parse_and_maybe_upgrade_config(run_config) - try: - impls = await resolve_impls(run_config, get_provider_registry()) - except ModuleNotFoundError as e: - print_pip_install_help(providers) - raise e - - if "provider_data" in config_dict: - provider_id = chosen[api.value][0].provider_id - provider_data = config_dict["provider_data"].get(provider_id, {}) - if provider_data: - set_request_provider_data( - {"X-LlamaStack-ProviderData": json.dumps(provider_data)} - ) - - return impls - - -def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]: - if "providers" not in config_dict: - raise ValueError("Config file should contain a `providers` key") - - providers = config_dict["providers"] - if isinstance(providers, dict): - return providers - elif isinstance(providers, list): - return { - api.value: providers, - } - else: - raise ValueError( - "Config file should contain a list of providers or dict(api to providers)" - ) - - -def choose_providers( - providers: Dict[str, Any], api: Api, deps: List[Api] = None -) -> Dict[str, Provider]: - chosen = {} - if api.value not in providers: - raise ValueError(f"No providers found for `{api}`?") - chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")] - - for dep in deps or []: - if dep.value not in providers: - raise ValueError(f"No providers specified for `{dep}` in config?") - chosen[dep.value] = [Provider(**x) for x in providers[dep.value]] - - return chosen - - -def pick_provider(api: Api, providers: List[Any], key: str) -> Provider: - providers_by_id = {x["provider_id"]: x for x in providers} - if len(providers_by_id) == 0: - raise ValueError(f"No providers found for `{api}` in config file") - - if key in os.environ: - provider_id = os.environ[key] - if provider_id not in providers_by_id: - raise ValueError(f"Provider ID {provider_id} not found in config file") - provider = providers_by_id[provider_id] - else: - provider = list(providers_by_id.values())[0] - provider_id = provider["provider_id"] - print(f"No provider ID specified, picking first `{provider_id}`") - - return Provider(**provider) + return remote_config diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index cb380ce57..76eb418ea 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -66,14 +66,14 @@ def pytest_configure(config): def pytest_addoption(parser): parser.addoption( - "--safety-model", + "--safety-shield", action="store", default=None, - help="Specify the safety model to use for testing", + help="Specify the safety shield to use for testing", ) -SAFETY_MODEL_PARAMS = [ +SAFETY_SHIELD_PARAMS = [ pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"), ] @@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc): # But a user can also pass in a custom combination via the CLI by doing # `--providers inference=together,safety=meta_reference` - if "safety_model" in metafunc.fixturenames: - model = metafunc.config.getoption("--safety-model") - if model: - params = [pytest.param(model, id="")] + if "safety_shield" in metafunc.fixturenames: + shield_id = metafunc.config.getoption("--safety-shield") + if shield_id: + params = [pytest.param(shield_id, id="")] else: - params = SAFETY_MODEL_PARAMS - for fixture in ["inference_model", "safety_model"]: + params = SAFETY_SHIELD_PARAMS + for fixture in ["inference_model", "safety_shield"]: metafunc.parametrize( fixture, params, diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index b73c2d798..a706316dd 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture: return remote_stack_fixture() +def safety_model_from_shield(shield_id): + if shield_id in ("Bedrock", "CodeScanner", "CodeShield"): + return None + + return shield_id + + @pytest.fixture(scope="session") -def safety_model(request): +def safety_shield(request): if hasattr(request, "param"): - return request.param - return request.config.getoption("--safety-model", None) + shield_id = request.param + else: + shield_id = request.config.getoption("--safety-shield", None) + + if shield_id == "bedrock": + shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") + params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")} + else: + params = {} + + return ShieldInput( + shield_id=shield_id, + params=params, + ) @pytest.fixture(scope="session") -def safety_llama_guard(safety_model) -> ProviderFixture: +def safety_llama_guard() -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="inline::llama-guard", + provider_id="llama-guard", provider_type="inline::llama-guard", config=LlamaGuardConfig().model_dump(), ) @@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="inline::prompt-guard", + provider_id="prompt-guard", provider_type="inline::prompt-guard", config=PromptGuardConfig().model_dump(), ) @@ -80,50 +99,25 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"] @pytest_asyncio.fixture(scope="session") -async def safety_stack(inference_model, safety_model, request): +async def safety_stack(inference_model, safety_shield, request): # We need an inference + safety fixture to test safety fixture_dict = request.param - inference_fixture = request.getfixturevalue( - f"inference_{fixture_dict['inference']}" - ) - safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}") - providers = { - "inference": inference_fixture.providers, - "safety": safety_fixture.providers, - } + providers = {} provider_data = {} - if inference_fixture.provider_data: - provider_data.update(inference_fixture.provider_data) - if safety_fixture.provider_data: - provider_data.update(safety_fixture.provider_data) + for key in ["inference", "safety"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) - shield_provider_type = safety_fixture.providers[0].provider_type - shield_input = get_shield_to_register(shield_provider_type, safety_model) - - print(f"inference_model: {inference_model}") - print(f"shield_input = {shield_input}") - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.safety, Api.shields, Api.inference], providers, provider_data, models=[ModelInput(model_id=inference_model)], - shields=[shield_input], + shields=[safety_shield], ) - shield = await impls[Api.shields].get_shield(shield_input.shield_id) - return impls[Api.safety], impls[Api.shields], shield - - -def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput: - if provider_type == "remote::bedrock": - identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER") - params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")} - else: - params = {} - identifier = safety_model - - return ShieldInput( - shield_id=identifier, - params=params, - ) + shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id) + return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index 9daa7bf40..2b3e2d2f5 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -18,13 +18,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403 class TestSafety: - @pytest.mark.asyncio - async def test_new_shield(self, safety_stack): - _, shields_impl, shield = safety_stack - assert shield is not None - assert shield.provider_resource_id == shield.identifier - assert shield.provider_id is not None - @pytest.mark.asyncio async def test_shield_list(self, safety_stack): _, shields_impl, _ = safety_stack diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py index ee6999043..d89b211ef 100644 --- a/llama_stack/providers/tests/scoring/fixtures.py +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -11,7 +11,7 @@ from llama_stack.apis.models import ModelInput from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture @@ -74,7 +74,7 @@ async def scoring_stack(request, inference_model): if fixture.provider_data: provider_data.update(fixture.provider_data) - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.scoring, Api.datasetio, Api.inference], providers, provider_data, @@ -88,4 +88,4 @@ async def scoring_stack(request, inference_model): ], ) - return impls + return test_stack.impls