Kill the notion of a "remote" / "passthrough" provider

This commit is contained in:
Ashwin Bharambe 2024-11-12 15:30:59 -08:00
parent 59a65e34d3
commit 743da9690b
6 changed files with 95 additions and 87 deletions

View file

@ -9,7 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]:
@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
for api in providable_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_type: a for a in module.available_providers()},
}
ret[api] = {a.provider_type: a for a in module.available_providers()}
return ret

View file

@ -273,17 +273,8 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
if provider_spec.adapter:
method = "get_adapter_impl"
args = [config, deps]
else:
method = "get_client_impl"
protocol = protocols[provider_spec.api]
if provider_spec.api in additional_protocols:
_, additional_protocol = additional_protocols[provider_spec.api]
else:
additional_protocol = None
args = [protocol, additional_protocol, config, deps]
method = "get_adapter_impl"
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"

View file

@ -33,28 +33,83 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
api = get_impl_api(p)
if obj.provider_id == "remote":
# TODO: this is broken right now because we use the generic
# { identifier, provider_id, provider_resource_id } tuple here
# but the APIs expect things like ModelInput, ShieldInput, etc.
# if this is just a passthrough, we want to let the remote
# end actually do the registration with the correct provider
obj = obj.model_copy(deep=True)
obj.provider_id = ""
is_remote = obj.provider_id == "remote"
if is_remote:
# TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
# and (b) MemoryBankInput is missing BankParams
if isinstance(obj, Model):
obj = ModelInput(
model_id=obj.identifier,
metadata=obj.metadata,
provider_model_id=obj.provider_resource_id,
)
elif isinstance(obj, Shield):
obj = ShieldInput(
shield_id=obj.identifier,
params=obj.params,
provider_shield_id=obj.provider_resource_id,
)
elif isinstance(obj, MemoryBank):
# need to calculate params here
obj = MemoryBankInput(
memory_bank_id=obj.identifier,
provider_memory_bank_id=obj.provider_resource_id,
)
elif isinstance(obj, ScoringFn):
obj = ScoringFnInput(
scoring_fn_id=obj.identifier,
provider_scoring_fn_id=obj.provider_resource_id,
description=obj.description,
metadata=obj.metadata,
return_type=obj.return_type,
params=obj.params,
)
elif isinstance(obj, EvalTask):
obj = EvalTaskInput(
eval_task_id=obj.identifier,
provider_eval_task_id=obj.provider_resource_id,
dataset_id=obj.dataset_id,
scoring_function_id=obj.scoring_functions,
metadata=obj.metadata,
)
elif isinstance(obj, Dataset):
obj = DatasetInput(
dataset_id=obj.identifier,
provider_dataset_id=obj.provider_resource_id,
schema=obj.schema,
url=obj.url,
metadata=obj.metadata,
)
else:
raise ValueError(f"Unknown object type {type(obj)}")
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
if is_remote:
await p.register_shield(**obj.model_dump())
else:
await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
if is_remote:
await p.register_memory_bank(**obj.model_dump())
else:
await p.register_memory_bank(obj)
elif api == Api.datasetio:
await p.register_dataset(obj)
if is_remote:
await p.register_dataset(**obj.model_dump())
else:
await p.register_dataset(obj)
elif api == Api.scoring:
await p.register_scoring_function(obj)
if is_remote:
await p.register_scoring_function(**obj.model_dump())
else:
await p.register_scoring_function(obj)
elif api == Api.eval:
await p.register_eval_task(obj)
if is_remote:
await p.register_eval_task(**obj.model_dump())
else:
await p.register_eval_task(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")

View file

@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
await impl.shutdown()
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
return endpoint
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
@ -305,28 +296,19 @@ def main(
endpoints = all_endpoints[api]
impl = impls[api]
if is_passthrough(impl.__provider_spec__):
for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
impl_method = getattr(impl, endpoint.name)
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
)
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
for endpoint in endpoints: