Kill the notion of a "remote" / "passthrough" provider

This commit is contained in:
Ashwin Bharambe 2024-11-12 15:30:59 -08:00
parent 59a65e34d3
commit 743da9690b
6 changed files with 95 additions and 87 deletions

View file

@ -9,7 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]:
@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
for api in providable_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_type: a for a in module.available_providers()},
}
ret[api] = {a.provider_type: a for a in module.available_providers()}
return ret

View file

@ -273,17 +273,8 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
if provider_spec.adapter:
method = "get_adapter_impl"
args = [config, deps]
else:
method = "get_client_impl"
protocol = protocols[provider_spec.api]
if provider_spec.api in additional_protocols:
_, additional_protocol = additional_protocols[provider_spec.api]
else:
additional_protocol = None
args = [protocol, additional_protocol, config, deps]
method = "get_adapter_impl"
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"

View file

@ -33,28 +33,83 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
api = get_impl_api(p)
if obj.provider_id == "remote":
# TODO: this is broken right now because we use the generic
# { identifier, provider_id, provider_resource_id } tuple here
# but the APIs expect things like ModelInput, ShieldInput, etc.
# if this is just a passthrough, we want to let the remote
# end actually do the registration with the correct provider
obj = obj.model_copy(deep=True)
obj.provider_id = ""
is_remote = obj.provider_id == "remote"
if is_remote:
# TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
# and (b) MemoryBankInput is missing BankParams
if isinstance(obj, Model):
obj = ModelInput(
model_id=obj.identifier,
metadata=obj.metadata,
provider_model_id=obj.provider_resource_id,
)
elif isinstance(obj, Shield):
obj = ShieldInput(
shield_id=obj.identifier,
params=obj.params,
provider_shield_id=obj.provider_resource_id,
)
elif isinstance(obj, MemoryBank):
# need to calculate params here
obj = MemoryBankInput(
memory_bank_id=obj.identifier,
provider_memory_bank_id=obj.provider_resource_id,
)
elif isinstance(obj, ScoringFn):
obj = ScoringFnInput(
scoring_fn_id=obj.identifier,
provider_scoring_fn_id=obj.provider_resource_id,
description=obj.description,
metadata=obj.metadata,
return_type=obj.return_type,
params=obj.params,
)
elif isinstance(obj, EvalTask):
obj = EvalTaskInput(
eval_task_id=obj.identifier,
provider_eval_task_id=obj.provider_resource_id,
dataset_id=obj.dataset_id,
scoring_function_id=obj.scoring_functions,
metadata=obj.metadata,
)
elif isinstance(obj, Dataset):
obj = DatasetInput(
dataset_id=obj.identifier,
provider_dataset_id=obj.provider_resource_id,
schema=obj.schema,
url=obj.url,
metadata=obj.metadata,
)
else:
raise ValueError(f"Unknown object type {type(obj)}")
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
if is_remote:
await p.register_shield(**obj.model_dump())
else:
await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
if is_remote:
await p.register_memory_bank(**obj.model_dump())
else:
await p.register_memory_bank(obj)
elif api == Api.datasetio:
await p.register_dataset(obj)
if is_remote:
await p.register_dataset(**obj.model_dump())
else:
await p.register_dataset(obj)
elif api == Api.scoring:
await p.register_scoring_function(obj)
if is_remote:
await p.register_scoring_function(**obj.model_dump())
else:
await p.register_scoring_function(obj)
elif api == Api.eval:
await p.register_eval_task(obj)
if is_remote:
await p.register_eval_task(**obj.model_dump())
else:
await p.register_eval_task(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")

View file

@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
await impl.shutdown()
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
return endpoint
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
@ -305,28 +296,19 @@ def main(
endpoints = all_endpoints[api]
impl = impls[api]
if is_passthrough(impl.__provider_spec__):
for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
impl_method = getattr(impl, endpoint.name)
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
)
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
for endpoint in endpoints:

View file

@ -99,6 +99,7 @@ class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
adapter: AdapterSpec = Field(
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
API responses, specify the adapter here.
""",
)
@ -186,38 +185,21 @@ as being "Llama Stack compatible"
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return "llama_stack.distribution.client"
return self.adapter.module
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
return self.adapter.pip_packages
@property
def provider_data_validator(self) -> Optional[str]:
if self.adapter:
return self.adapter.provider_data_validator
return None
return self.adapter.provider_data_validator
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
adapter=adapter,
)

View file

@ -189,6 +189,7 @@ async def inference_stack(request, inference_model):
models=[
ModelInput(
model_id=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
)
],
)