diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 9aa202fff..6f677f268 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -59,12 +59,16 @@ def api_protocol_map() -> Dict[Api, Any]:
 
 def additional_protocols_map() -> Dict[Api, Any]:
     return {
-        Api.inference: (ModelsProtocolPrivate, Models),
-        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
-        Api.safety: (ShieldsProtocolPrivate, Shields),
-        Api.datasetio: (DatasetsProtocolPrivate, Datasets),
-        Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
-        Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
+        Api.inference: (ModelsProtocolPrivate, Models, Api.models),
+        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
+        Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
+        Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
+        Api.scoring: (
+            ScoringFunctionsProtocolPrivate,
+            ScoringFunctions,
+            Api.scoring_functions,
+        ),
+        Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
     }
 
 
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 3ae030554..393581b41 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -33,83 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
     api = get_impl_api(p)
 
-    is_remote = obj.provider_id == "remote"
-    if is_remote:
-        # TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
-        # and (b) MemoryBankInput is missing BankParams
-        if isinstance(obj, Model):
-            obj = ModelInput(
-                model_id=obj.identifier,
-                metadata=obj.metadata,
-                provider_model_id=obj.provider_resource_id,
-            )
-        elif isinstance(obj, Shield):
-            obj = ShieldInput(
-                shield_id=obj.identifier,
-                params=obj.params,
-                provider_shield_id=obj.provider_resource_id,
-            )
-        elif isinstance(obj, MemoryBank):
-            # need to calculate params here
-            obj = MemoryBankInput(
-                memory_bank_id=obj.identifier,
-                provider_memory_bank_id=obj.provider_resource_id,
-            )
-        elif isinstance(obj, ScoringFn):
-            obj = ScoringFnInput(
-                scoring_fn_id=obj.identifier,
-                provider_scoring_fn_id=obj.provider_resource_id,
-                description=obj.description,
-                metadata=obj.metadata,
-                return_type=obj.return_type,
-                params=obj.params,
-            )
-        elif isinstance(obj, EvalTask):
-            obj = EvalTaskInput(
-                eval_task_id=obj.identifier,
-                provider_eval_task_id=obj.provider_resource_id,
-                dataset_id=obj.dataset_id,
-                scoring_function_id=obj.scoring_functions,
-                metadata=obj.metadata,
-            )
-        elif isinstance(obj, Dataset):
-            obj = DatasetInput(
-                dataset_id=obj.identifier,
-                provider_dataset_id=obj.provider_resource_id,
-                schema=obj.schema,
-                url=obj.url,
-                metadata=obj.metadata,
-            )
-        else:
-            raise ValueError(f"Unknown object type {type(obj)}")
+    assert obj.provider_id != "remote", "Remote provider should not be registered"
 
     if api == Api.inference:
         return await p.register_model(obj)
     elif api == Api.safety:
-        if is_remote:
-            await p.register_shield(**obj.model_dump())
-        else:
-            await p.register_shield(obj)
+        await p.register_shield(**obj.model_dump())
     elif api == Api.memory:
-        if is_remote:
-            await p.register_memory_bank(**obj.model_dump())
-        else:
-            await p.register_memory_bank(obj)
+        await p.register_memory_bank(**obj.model_dump())
     elif api == Api.datasetio:
-        if is_remote:
-            await p.register_dataset(**obj.model_dump())
-        else:
-            await p.register_dataset(obj)
+        await p.register_dataset(**obj.model_dump())
     elif api == Api.scoring:
-        if is_remote:
-            await p.register_scoring_function(**obj.model_dump())
-        else:
-            await p.register_scoring_function(obj)
+        await p.register_scoring_function(**obj.model_dump())
     elif api == Api.eval:
-        if is_remote:
-            await p.register_eval_task(**obj.model_dump())
-        else:
-            await p.register_eval_task(obj)
+        await p.register_eval_task(**obj.model_dump())
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
 
 
@@ -137,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
                 if cls is None:
                     obj.provider_id = provider_id
                 else:
-                    if provider_id == "remote":
-                        # if this is just a passthrough, we got the *WithProvider object
-                        # so we should just override the provider in-place
-                        obj.provider_id = provider_id
-                    else:
-                        # Create a copy of the model data and explicitly set provider_id
-                        model_data = obj.model_dump()
-                        model_data["provider_id"] = provider_id
-                        obj = cls(**model_data)
+                    # Create a copy of the model data and explicitly set provider_id
+                    model_data = obj.model_dump()
+                    model_data["provider_id"] = provider_id
+                    obj = cls(**model_data)
                 await self.dist_registry.register(obj)
 
         # Register all objects from providers
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 1c7325eee..1aca27d99 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -28,11 +28,16 @@ from llama_stack.apis.shields import *  # noqa: F403
 from llama_stack.apis.inspect import *  # noqa: F403
 from llama_stack.apis.eval_tasks import *  # noqa: F403
 
+from llama_stack.distribution.client import get_client_impl
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
-from llama_stack.distribution.resolver import resolve_impls
+from llama_stack.distribution.resolver import (
+    additional_protocols_map,
+    api_protocol_map,
+    resolve_impls,
+)
 from llama_stack.distribution.store.registry import create_dist_registry
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, RemoteProviderConfig
 
 
 class LlamaStack(
@@ -65,7 +70,9 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
         run_config.metadata_store, run_config.image_name
     )
 
-    impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
+    impls = await maybe_get_remote_stack_impls(run_config)
+    if impls is None:
+        impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
 
     resources = [
         ("models", Api.models, "register_model", "list_models"),
@@ -97,3 +104,54 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
         print("")
 
     return impls
+
+
+# NOTE: this code path is really for the tests so you can send HTTP requests
+# to the remote stack without needing to use llama-stack-client
+async def maybe_get_remote_stack_impls(
+    run_config: StackRunConfig,
+) -> Optional[Dict[Api, Any]]:
+    remote_config = remote_provider_config(run_config)
+    if not remote_config:
+        return None
+
+    protocols = api_protocol_map()
+    additional_protocols = additional_protocols_map()
+
+    impls = {}
+    for api_str in run_config.apis:
+        api = Api(api_str)
+        impls[api] = await get_client_impl(
+            protocols[api],
+            None,
+            remote_config,
+            {},
+        )
+        if api in additional_protocols:
+            _, additional_protocol, additional_api = additional_protocols[api]
+            impls[additional_api] = await get_client_impl(
+                additional_protocol,
+                None,
+                remote_config,
+                {},
+            )
+
+    return impls
+
+
+def remote_provider_config(
+    run_config: StackRunConfig,
+) -> Optional[RemoteProviderConfig]:
+    remote_config = None
+    has_non_remote = False
+    for api_providers in run_config.providers.values():
+        for provider in api_providers:
+            if provider.provider_type == "remote":
+                remote_config = RemoteProviderConfig(**provider.config)
+            else:
+                has_non_remote = True
+
+    if remote_config:
+        assert not has_non_remote, "Remote stack cannot have non-remote providers"
+
+    return remote_config
diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py
index 0b98f3368..ff0926108 100644
--- a/llama_stack/providers/registry/memory.py
+++ b/llama_stack/providers/registry/memory.py
@@ -53,6 +53,7 @@ def available_providers() -> List[ProviderSpec]:
             adapter_type="chromadb",
             pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
             module="llama_stack.providers.remote.memory.chroma",
+            config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
         ),
     ),
     remote_provider_spec(
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index c58741f62..322cae798 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -85,4 +85,5 @@ async def agents_stack(request, inference_model, safety_shield):
         ],
         shields=[safety_shield],
     )
+
     return impls[Api.agents], impls[Api.memory]
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index 7db21ac2a..01d7e4892 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -186,12 +186,7 @@ async def inference_stack(request, inference_model):
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-        models=[
-            ModelInput(
-                model_id=inference_model,
-                provider_id=inference_fixture.providers[0].provider_id,
-            )
-        ],
+        models=[ModelInput(model_id=inference_model)],
     )
 
     return (impls[Api.inference], impls[Api.models])
diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py
index 1353fc71b..84f0520dc 100644
--- a/llama_stack/providers/tests/resolver.py
+++ b/llama_stack/providers/tests/resolver.py
@@ -17,6 +17,7 @@ from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.resolver import resolve_impls
 from llama_stack.distribution.stack import construct_stack
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
 
@@ -25,12 +26,12 @@ async def resolve_impls_for_test_v2(
     apis: List[Api],
     providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
-    models: Optional[List[Model]] = None,
-    shields: Optional[List[Shield]] = None,
-    memory_banks: Optional[List[MemoryBank]] = None,
-    datasets: Optional[List[Dataset]] = None,
-    scoring_fns: Optional[List[ScoringFn]] = None,
-    eval_tasks: Optional[List[EvalTask]] = None,
+    models: Optional[List[ModelInput]] = None,
+    shields: Optional[List[ShieldInput]] = None,
+    memory_banks: Optional[List[MemoryBankInput]] = None,
+    datasets: Optional[List[DatasetInput]] = None,
+    scoring_fns: Optional[List[ScoringFnInput]] = None,
+    eval_tasks: Optional[List[EvalTaskInput]] = None,
 ):
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     run_config = dict(