From 0121114a5d79cb28550ff4bf18ee510956a3af44 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Nov 2024 19:17:58 -0800 Subject: [PATCH] remote tests are functional! --- llama_stack/distribution/resolver.py | 31 ++++- .../distribution/routers/routing_tables.py | 3 - llama_stack/distribution/stack.py | 108 +++++------------- llama_stack/providers/tests/conftest.py | 4 +- llama_stack/providers/tests/resolver.py | 36 +++++- 5 files changed, 93 insertions(+), 89 deletions(-) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 6f677f268..d00aedb5c 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry +from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -308,7 +309,7 @@ async def instantiate_provider( not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols ): - additional_api, _ = additional_protocols[provider_spec.api] + additional_api, _, _ = additional_protocols[provider_spec.api] check_protocol_compliance(impl, additional_api) return impl @@ -354,3 +355,31 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: raise ValueError( f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}" ) + + +async def resolve_remote_stack_impls( + config: RemoteProviderConfig, + apis: List[str], +) -> Dict[Api, Any]: + protocols = api_protocol_map() + additional_protocols = additional_protocols_map() + + impls = {} + for api_str in apis: + api = Api(api_str) + impls[api] = await get_client_impl( + protocols[api], + None, + config, + {}, + ) + if api in additional_protocols: + _, additional_protocol, additional_api = additional_protocols[api] + impls[additional_api] = await get_client_impl( + additional_protocol, + None, + config, + {}, + ) + + return impls diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 393581b41..4bdeb608a 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -87,13 +87,10 @@ class CommonRoutingTableImpl(RoutingTable): p.model_store = self elif api == Api.safety: p.shield_store = self - elif api == Api.memory: p.memory_bank_store = self - elif api == Api.datasetio: p.dataset_store = self - elif api == Api.scoring: p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 1aca27d99..6a80d7a48 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -28,16 +28,11 @@ from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.inspect import * # noqa: F403 from llama_stack.apis.eval_tasks import * # noqa: F403 -from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.resolver import ( - additional_protocols_map, - api_protocol_map, - resolve_impls, -) +from llama_stack.distribution.resolver import resolve_impls from llama_stack.distribution.store.registry import create_dist_registry -from llama_stack.providers.datatypes import Api, RemoteProviderConfig +from llama_stack.providers.datatypes import Api class LlamaStack( @@ -63,31 +58,23 @@ class LlamaStack( pass -# Produces a stack of providers for the given run config. Not all APIs may be -# asked for in the run config. -async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]: - dist_registry, _ = await create_dist_registry( - run_config.metadata_store, run_config.image_name - ) +RESOURCES = [ + ("models", Api.models, "register_model", "list_models"), + ("shields", Api.shields, "register_shield", "list_shields"), + ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"), + ("datasets", Api.datasets, "register_dataset", "list_datasets"), + ( + "scoring_fns", + Api.scoring_functions, + "register_scoring_function", + "list_scoring_functions", + ), + ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"), +] - impls = await maybe_get_remote_stack_impls(run_config) - if impls is None: - impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) - resources = [ - ("models", Api.models, "register_model", "list_models"), - ("shields", Api.shields, "register_shield", "list_shields"), - ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"), - ("datasets", Api.datasets, "register_dataset", "list_datasets"), - ( - "scoring_fns", - Api.scoring_functions, - "register_scoring_function", - "list_scoring_functions", - ), - ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"), - ] - for rsrc, api, register_method, list_method in resources: +async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): + for rsrc, api, register_method, list_method in RESOURCES: objects = getattr(run_config, rsrc) if api not in impls: continue @@ -103,55 +90,14 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]: ) print("") + + +# Produces a stack of providers for the given run config. Not all APIs may be +# asked for in the run config. +async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]: + dist_registry, _ = await create_dist_registry( + run_config.metadata_store, run_config.image_name + ) + impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) + await register_resources(run_config, impls) return impls - - -# NOTE: this code path is really for the tests so you can send HTTP requests -# to the remote stack without needing to use llama-stack-client -async def maybe_get_remote_stack_impls( - run_config: StackRunConfig, -) -> Optional[Dict[Api, Any]]: - remote_config = remote_provider_config(run_config) - if not remote_config: - return None - - protocols = api_protocol_map() - additional_protocols = additional_protocols_map() - - impls = {} - for api_str in run_config.apis: - api = Api(api_str) - impls[api] = await get_client_impl( - protocols[api], - None, - remote_config, - {}, - ) - if api in additional_protocols: - _, additional_protocol, additional_api = additional_protocols[api] - impls[additional_api] = await get_client_impl( - additional_protocol, - None, - remote_config, - {}, - ) - - return impls - - -def remote_provider_config( - run_config: StackRunConfig, -) -> Optional[RemoteProviderConfig]: - remote_config = None - has_non_remote = False - for api_providers in run_config.providers.values(): - for provider in api_providers: - if provider.provider_type == "remote": - remote_config = RemoteProviderConfig(**provider.config) - else: - has_non_remote = True - - if remote_config: - assert not has_non_remote, "Remote stack cannot have non-remote providers" - - return remote_config diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 3bec2d11d..8b73500d0 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -35,8 +35,8 @@ def remote_stack_fixture() -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="remote", - provider_type="remote", + provider_id="test::remote", + provider_type="test::remote", config=config.model_dump(), ) ], diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 84f0520dc..46e0435af 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -17,11 +17,25 @@ from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.resolver import resolve_impls, resolve_remote_stack_impls from llama_stack.distribution.stack import construct_stack from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig +async def construct_stack_for_test(run_config: StackRunConfig): + remote_config = remote_provider_config(run_config) + if not remote_config: + return await construct_stack(run_config) + + impls = await resolve_remote_stack_impls(remote_config, run_config.apis) + + # we don't register resources for a remote stack as part of the fixture setup + # because the stack is already "up". if a test needs to register resources, it + # can do so manually always. + + return impls + + async def resolve_impls_for_test_v2( apis: List[Api], providers: Dict[str, List[Provider]], @@ -49,7 +63,7 @@ async def resolve_impls_for_test_v2( ) run_config = parse_and_maybe_upgrade_config(run_config) try: - impls = await construct_stack(run_config) + impls = await construct_stack_for_test(run_config) except ModuleNotFoundError as e: print_pip_install_help(providers) raise e @@ -62,6 +76,24 @@ async def resolve_impls_for_test_v2( return impls +def remote_provider_config( + run_config: StackRunConfig, +) -> Optional[RemoteProviderConfig]: + remote_config = None + has_non_remote = False + for api_providers in run_config.providers.values(): + for provider in api_providers: + if provider.provider_type == "test::remote": + remote_config = RemoteProviderConfig(**provider.config) + else: + has_non_remote = True + + if remote_config: + assert not has_non_remote, "Remote stack cannot have non-remote providers" + + return remote_config + + async def resolve_impls_for_test(api: Api, deps: List[Api] = None): if "PROVIDER_CONFIG" not in os.environ: raise ValueError(