diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 3fc3b2d5d..6fc4545c7 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -9,7 +9,7 @@ from typing import Dict, List from pydantic import BaseModel -from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec +from llama_stack.providers.datatypes import Api, ProviderSpec def stack_apis() -> List[Api]: @@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: for api in providable_apis(): name = api.name.lower() module = importlib.import_module(f"llama_stack.providers.registry.{name}") - ret[api] = { - "remote": remote_provider_spec(api), - **{a.provider_type: a for a in module.available_providers()}, - } + ret[api] = {a.provider_type: a for a in module.available_providers()} return ret diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 4e7fa0102..9aa202fff 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -273,17 +273,8 @@ async def instantiate_provider( config_type = instantiate_class_type(provider_spec.config_class) config = config_type(**provider.config) - if provider_spec.adapter: - method = "get_adapter_impl" - args = [config, deps] - else: - method = "get_client_impl" - protocol = protocols[provider_spec.api] - if provider_spec.api in additional_protocols: - _, additional_protocol = additional_protocols[provider_spec.api] - else: - additional_protocol = None - args = [protocol, additional_protocol, config, deps] + method = "get_adapter_impl" + args = [config, deps] elif isinstance(provider_spec, AutoRoutedProviderSpec): method = "get_auto_router_impl" diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 249d3a144..3ae030554 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -33,28 +33,83 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable api = get_impl_api(p) - if obj.provider_id == "remote": - # TODO: this is broken right now because we use the generic - # { identifier, provider_id, provider_resource_id } tuple here - # but the APIs expect things like ModelInput, ShieldInput, etc. - - # if this is just a passthrough, we want to let the remote - # end actually do the registration with the correct provider - obj = obj.model_copy(deep=True) - obj.provider_id = "" + is_remote = obj.provider_id == "remote" + if is_remote: + # TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break + # and (b) MemoryBankInput is missing BankParams + if isinstance(obj, Model): + obj = ModelInput( + model_id=obj.identifier, + metadata=obj.metadata, + provider_model_id=obj.provider_resource_id, + ) + elif isinstance(obj, Shield): + obj = ShieldInput( + shield_id=obj.identifier, + params=obj.params, + provider_shield_id=obj.provider_resource_id, + ) + elif isinstance(obj, MemoryBank): + # need to calculate params here + obj = MemoryBankInput( + memory_bank_id=obj.identifier, + provider_memory_bank_id=obj.provider_resource_id, + ) + elif isinstance(obj, ScoringFn): + obj = ScoringFnInput( + scoring_fn_id=obj.identifier, + provider_scoring_fn_id=obj.provider_resource_id, + description=obj.description, + metadata=obj.metadata, + return_type=obj.return_type, + params=obj.params, + ) + elif isinstance(obj, EvalTask): + obj = EvalTaskInput( + eval_task_id=obj.identifier, + provider_eval_task_id=obj.provider_resource_id, + dataset_id=obj.dataset_id, + scoring_function_id=obj.scoring_functions, + metadata=obj.metadata, + ) + elif isinstance(obj, Dataset): + obj = DatasetInput( + dataset_id=obj.identifier, + provider_dataset_id=obj.provider_resource_id, + schema=obj.schema, + url=obj.url, + metadata=obj.metadata, + ) + else: + raise ValueError(f"Unknown object type {type(obj)}") if api == Api.inference: return await p.register_model(obj) elif api == Api.safety: - await p.register_shield(obj) + if is_remote: + await p.register_shield(**obj.model_dump()) + else: + await p.register_shield(obj) elif api == Api.memory: - await p.register_memory_bank(obj) + if is_remote: + await p.register_memory_bank(**obj.model_dump()) + else: + await p.register_memory_bank(obj) elif api == Api.datasetio: - await p.register_dataset(obj) + if is_remote: + await p.register_dataset(**obj.model_dump()) + else: + await p.register_dataset(obj) elif api == Api.scoring: - await p.register_scoring_function(obj) + if is_remote: + await p.register_scoring_function(**obj.model_dump()) + else: + await p.register_scoring_function(obj) elif api == Api.eval: - await p.register_eval_task(obj) + if is_remote: + await p.register_eval_task(**obj.model_dump()) + else: + await p.register_eval_task(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index bb57e2cc8..05927eef5 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -182,15 +182,6 @@ async def lifespan(app: FastAPI): await impl.shutdown() -def create_dynamic_passthrough( - downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None -): - async def endpoint(request: Request): - return await passthrough(request, downstream_url, downstream_headers) - - return endpoint - - def is_streaming_request(func_name: str, request: Request, **kwargs): # TODO: pass the api method and punt it to the Protocol definition directly return kwargs.get("stream", False) @@ -305,28 +296,19 @@ def main( endpoints = all_endpoints[api] impl = impls[api] - if is_passthrough(impl.__provider_spec__): - for endpoint in endpoints: - url = impl.__provider_config__.url.rstrip("/") + endpoint.route - getattr(app, endpoint.method)(endpoint.route)( - create_dynamic_passthrough(url) - ) - else: - for endpoint in endpoints: - if not hasattr(impl, endpoint.name): - # ideally this should be a typing violation already - raise ValueError( - f"Could not find method {endpoint.name} on {impl}!!" - ) + for endpoint in endpoints: + if not hasattr(impl, endpoint.name): + # ideally this should be a typing violation already + raise ValueError(f"Could not find method {endpoint.name} on {impl}!!") - impl_method = getattr(impl, endpoint.name) + impl_method = getattr(impl, endpoint.name) - getattr(app, endpoint.method)(endpoint.route, response_model=None)( - create_dynamic_typed_route( - impl_method, - endpoint.method, - ) + getattr(app, endpoint.method)(endpoint.route, response_model=None)( + create_dynamic_typed_route( + impl_method, + endpoint.method, ) + ) cprint(f"Serving API {api_str}", "white", attrs=["bold"]) for endpoint in endpoints: diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 5a259ae2d..51ff163ab 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -99,6 +99,7 @@ class RoutingTable(Protocol): def get_provider_impl(self, routing_key: str) -> Any: ... +# TODO: this can now be inlined into RemoteProviderSpec @json_schema_type class AdapterSpec(BaseModel): adapter_type: str = Field( @@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel): @json_schema_type class RemoteProviderSpec(ProviderSpec): - adapter: Optional[AdapterSpec] = Field( - default=None, + adapter: AdapterSpec = Field( description=""" If some code is needed to convert the remote responses into Llama Stack compatible -API responses, specify the adapter here. If not specified, it indicates the remote -as being "Llama Stack compatible" +API responses, specify the adapter here. """, ) @@ -186,38 +185,21 @@ as being "Llama Stack compatible" @property def module(self) -> str: - if self.adapter: - return self.adapter.module - return "llama_stack.distribution.client" + return self.adapter.module @property def pip_packages(self) -> List[str]: - if self.adapter: - return self.adapter.pip_packages - return [] + return self.adapter.pip_packages @property def provider_data_validator(self) -> Optional[str]: - if self.adapter: - return self.adapter.provider_data_validator - return None + return self.adapter.provider_data_validator -def is_passthrough(spec: ProviderSpec) -> bool: - return isinstance(spec, RemoteProviderSpec) and spec.adapter is None - - -# Can avoid this by using Pydantic computed_field -def remote_provider_spec( - api: Api, adapter: Optional[AdapterSpec] = None -) -> RemoteProviderSpec: - config_class = ( - adapter.config_class - if adapter and adapter.config_class - else "llama_stack.distribution.datatypes.RemoteProviderConfig" - ) - provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote" - +def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec: return RemoteProviderSpec( - api=api, provider_type=provider_type, config_class=config_class, adapter=adapter + api=api, + provider_type=f"remote::{adapter.adapter_type}", + config_class=adapter.config_class, + adapter=adapter, ) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index f6f2a30e8..7db21ac2a 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -189,6 +189,7 @@ async def inference_stack(request, inference_model): models=[ ModelInput( model_id=inference_model, + provider_id=inference_fixture.providers[0].provider_id, ) ], )