Kill the notion of a "remote" / "passthrough" provider

This commit is contained in:
Ashwin Bharambe 2024-11-12 15:30:59 -08:00
parent 59a65e34d3
commit 743da9690b
6 changed files with 95 additions and 87 deletions

View file

@ -9,7 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]: def stack_apis() -> List[Api]:
@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
for api in providable_apis(): for api in providable_apis():
name = api.name.lower() name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}") module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = { ret[api] = {a.provider_type: a for a in module.available_providers()}
"remote": remote_provider_spec(api),
**{a.provider_type: a for a in module.available_providers()},
}
return ret return ret

View file

@ -273,17 +273,8 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config) config = config_type(**provider.config)
if provider_spec.adapter: method = "get_adapter_impl"
method = "get_adapter_impl" args = [config, deps]
args = [config, deps]
else:
method = "get_client_impl"
protocol = protocols[provider_spec.api]
if provider_spec.api in additional_protocols:
_, additional_protocol = additional_protocols[provider_spec.api]
else:
additional_protocol = None
args = [protocol, additional_protocol, config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec): elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl" method = "get_auto_router_impl"

View file

@ -33,28 +33,83 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
api = get_impl_api(p) api = get_impl_api(p)
if obj.provider_id == "remote": is_remote = obj.provider_id == "remote"
# TODO: this is broken right now because we use the generic if is_remote:
# { identifier, provider_id, provider_resource_id } tuple here # TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
# but the APIs expect things like ModelInput, ShieldInput, etc. # and (b) MemoryBankInput is missing BankParams
if isinstance(obj, Model):
# if this is just a passthrough, we want to let the remote obj = ModelInput(
# end actually do the registration with the correct provider model_id=obj.identifier,
obj = obj.model_copy(deep=True) metadata=obj.metadata,
obj.provider_id = "" provider_model_id=obj.provider_resource_id,
)
elif isinstance(obj, Shield):
obj = ShieldInput(
shield_id=obj.identifier,
params=obj.params,
provider_shield_id=obj.provider_resource_id,
)
elif isinstance(obj, MemoryBank):
# need to calculate params here
obj = MemoryBankInput(
memory_bank_id=obj.identifier,
provider_memory_bank_id=obj.provider_resource_id,
)
elif isinstance(obj, ScoringFn):
obj = ScoringFnInput(
scoring_fn_id=obj.identifier,
provider_scoring_fn_id=obj.provider_resource_id,
description=obj.description,
metadata=obj.metadata,
return_type=obj.return_type,
params=obj.params,
)
elif isinstance(obj, EvalTask):
obj = EvalTaskInput(
eval_task_id=obj.identifier,
provider_eval_task_id=obj.provider_resource_id,
dataset_id=obj.dataset_id,
scoring_function_id=obj.scoring_functions,
metadata=obj.metadata,
)
elif isinstance(obj, Dataset):
obj = DatasetInput(
dataset_id=obj.identifier,
provider_dataset_id=obj.provider_resource_id,
schema=obj.schema,
url=obj.url,
metadata=obj.metadata,
)
else:
raise ValueError(f"Unknown object type {type(obj)}")
if api == Api.inference: if api == Api.inference:
return await p.register_model(obj) return await p.register_model(obj)
elif api == Api.safety: elif api == Api.safety:
await p.register_shield(obj) if is_remote:
await p.register_shield(**obj.model_dump())
else:
await p.register_shield(obj)
elif api == Api.memory: elif api == Api.memory:
await p.register_memory_bank(obj) if is_remote:
await p.register_memory_bank(**obj.model_dump())
else:
await p.register_memory_bank(obj)
elif api == Api.datasetio: elif api == Api.datasetio:
await p.register_dataset(obj) if is_remote:
await p.register_dataset(**obj.model_dump())
else:
await p.register_dataset(obj)
elif api == Api.scoring: elif api == Api.scoring:
await p.register_scoring_function(obj) if is_remote:
await p.register_scoring_function(**obj.model_dump())
else:
await p.register_scoring_function(obj)
elif api == Api.eval: elif api == Api.eval:
await p.register_eval_task(obj) if is_remote:
await p.register_eval_task(**obj.model_dump())
else:
await p.register_eval_task(obj)
else: else:
raise ValueError(f"Unknown API {api} for registering object with provider") raise ValueError(f"Unknown API {api} for registering object with provider")

View file

@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
await impl.shutdown() await impl.shutdown()
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
return endpoint
def is_streaming_request(func_name: str, request: Request, **kwargs): def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly # TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False) return kwargs.get("stream", False)
@ -305,28 +296,19 @@ def main(
endpoints = all_endpoints[api] endpoints = all_endpoints[api]
impl = impls[api] impl = impls[api]
if is_passthrough(impl.__provider_spec__): for endpoint in endpoints:
for endpoint in endpoints: if not hasattr(impl, endpoint.name):
url = impl.__provider_config__.url.rstrip("/") + endpoint.route # ideally this should be a typing violation already
getattr(app, endpoint.method)(endpoint.route)( raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
create_dynamic_passthrough(url)
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
impl_method = getattr(impl, endpoint.name) impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)( getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route( create_dynamic_typed_route(
impl_method, impl_method,
endpoint.method, endpoint.method,
)
) )
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"]) cprint(f"Serving API {api_str}", "white", attrs=["bold"])
for endpoint in endpoints: for endpoint in endpoints:

View file

@ -99,6 +99,7 @@ class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ... def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type @json_schema_type
class AdapterSpec(BaseModel): class AdapterSpec(BaseModel):
adapter_type: str = Field( adapter_type: str = Field(
@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
@json_schema_type @json_schema_type
class RemoteProviderSpec(ProviderSpec): class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field( adapter: AdapterSpec = Field(
default=None,
description=""" description="""
If some code is needed to convert the remote responses into Llama Stack compatible If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote API responses, specify the adapter here.
as being "Llama Stack compatible"
""", """,
) )
@ -186,38 +185,21 @@ as being "Llama Stack compatible"
@property @property
def module(self) -> str: def module(self) -> str:
if self.adapter: return self.adapter.module
return self.adapter.module
return "llama_stack.distribution.client"
@property @property
def pip_packages(self) -> List[str]: def pip_packages(self) -> List[str]:
if self.adapter: return self.adapter.pip_packages
return self.adapter.pip_packages
return []
@property @property
def provider_data_validator(self) -> Optional[str]: def provider_data_validator(self) -> Optional[str]:
if self.adapter: return self.adapter.provider_data_validator
return self.adapter.provider_data_validator
return None
def is_passthrough(spec: ProviderSpec) -> bool: def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
return RemoteProviderSpec( return RemoteProviderSpec(
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
adapter=adapter,
) )

View file

@ -189,6 +189,7 @@ async def inference_stack(request, inference_model):
models=[ models=[
ModelInput( ModelInput(
model_id=inference_model, model_id=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
) )
], ],
) )