mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
Kill the notion of a "remote" / "passthrough" provider
This commit is contained in:
parent
59a65e34d3
commit
743da9690b
6 changed files with 95 additions and 87 deletions
|
@@ -9,7 +9,7 @@ from typing import Dict, List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
|
@@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {
|
||||
"remote": remote_provider_spec(api),
|
||||
**{a.provider_type: a for a in module.available_providers()},
|
||||
}
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
|
||||
return ret
|
||||
|
|
|
@@ -273,17 +273,8 @@ async def instantiate_provider(
|
|||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider.config)
|
||||
|
||||
if provider_spec.adapter:
|
||||
method = "get_adapter_impl"
|
||||
args = [config, deps]
|
||||
else:
|
||||
method = "get_client_impl"
|
||||
protocol = protocols[provider_spec.api]
|
||||
if provider_spec.api in additional_protocols:
|
||||
_, additional_protocol = additional_protocols[provider_spec.api]
|
||||
else:
|
||||
additional_protocol = None
|
||||
args = [protocol, additional_protocol, config, deps]
|
||||
method = "get_adapter_impl"
|
||||
args = [config, deps]
|
||||
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
|
|
@@ -33,28 +33,83 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
|
||||
api = get_impl_api(p)
|
||||
|
||||
if obj.provider_id == "remote":
|
||||
# TODO: this is broken right now because we use the generic
|
||||
# { identifier, provider_id, provider_resource_id } tuple here
|
||||
# but the APIs expect things like ModelInput, ShieldInput, etc.
|
||||
|
||||
# if this is just a passthrough, we want to let the remote
|
||||
# end actually do the registration with the correct provider
|
||||
obj = obj.model_copy(deep=True)
|
||||
obj.provider_id = ""
|
||||
is_remote = obj.provider_id == "remote"
|
||||
if is_remote:
|
||||
# TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
|
||||
# and (b) MemoryBankInput is missing BankParams
|
||||
if isinstance(obj, Model):
|
||||
obj = ModelInput(
|
||||
model_id=obj.identifier,
|
||||
metadata=obj.metadata,
|
||||
provider_model_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, Shield):
|
||||
obj = ShieldInput(
|
||||
shield_id=obj.identifier,
|
||||
params=obj.params,
|
||||
provider_shield_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, MemoryBank):
|
||||
# need to calculate params here
|
||||
obj = MemoryBankInput(
|
||||
memory_bank_id=obj.identifier,
|
||||
provider_memory_bank_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, ScoringFn):
|
||||
obj = ScoringFnInput(
|
||||
scoring_fn_id=obj.identifier,
|
||||
provider_scoring_fn_id=obj.provider_resource_id,
|
||||
description=obj.description,
|
||||
metadata=obj.metadata,
|
||||
return_type=obj.return_type,
|
||||
params=obj.params,
|
||||
)
|
||||
elif isinstance(obj, EvalTask):
|
||||
obj = EvalTaskInput(
|
||||
eval_task_id=obj.identifier,
|
||||
provider_eval_task_id=obj.provider_resource_id,
|
||||
dataset_id=obj.dataset_id,
|
||||
scoring_function_id=obj.scoring_functions,
|
||||
metadata=obj.metadata,
|
||||
)
|
||||
elif isinstance(obj, Dataset):
|
||||
obj = DatasetInput(
|
||||
dataset_id=obj.identifier,
|
||||
provider_dataset_id=obj.provider_resource_id,
|
||||
schema=obj.schema,
|
||||
url=obj.url,
|
||||
metadata=obj.metadata,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown object type {type(obj)}")
|
||||
|
||||
if api == Api.inference:
|
||||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
if is_remote:
|
||||
await p.register_shield(**obj.model_dump())
|
||||
else:
|
||||
await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
if is_remote:
|
||||
await p.register_memory_bank(**obj.model_dump())
|
||||
else:
|
||||
await p.register_memory_bank(obj)
|
||||
elif api == Api.datasetio:
|
||||
await p.register_dataset(obj)
|
||||
if is_remote:
|
||||
await p.register_dataset(**obj.model_dump())
|
||||
else:
|
||||
await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
await p.register_scoring_function(obj)
|
||||
if is_remote:
|
||||
await p.register_scoring_function(**obj.model_dump())
|
||||
else:
|
||||
await p.register_scoring_function(obj)
|
||||
elif api == Api.eval:
|
||||
await p.register_eval_task(obj)
|
||||
if is_remote:
|
||||
await p.register_eval_task(**obj.model_dump())
|
||||
else:
|
||||
await p.register_eval_task(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
|
|
@@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
|
|||
await impl.shutdown()
|
||||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
|
||||
):
|
||||
async def endpoint(request: Request):
|
||||
return await passthrough(request, downstream_url, downstream_headers)
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
|
@@ -305,28 +296,19 @@ def main(
|
|||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
if is_passthrough(impl.__provider_spec__):
|
||||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
else:
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(
|
||||
f"Could not find method {endpoint.name} on {impl}!!"
|
||||
)
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
)
|
||||
)
|
||||
|
||||
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
|
||||
for endpoint in endpoints:
|
||||
|
|
|
@@ -99,6 +99,7 @@ class RoutingTable(Protocol):
|
|||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
# TODO: this can now be inlined into RemoteProviderSpec
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_type: str = Field(
|
||||
|
@@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
adapter: Optional[AdapterSpec] = Field(
|
||||
default=None,
|
||||
adapter: AdapterSpec = Field(
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
API responses, specify the adapter here.
|
||||
""",
|
||||
)
|
||||
|
||||
|
@@ -186,38 +185,21 @@ as being "Llama Stack compatible"
|
|||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
if self.adapter:
|
||||
return self.adapter.module
|
||||
return "llama_stack.distribution.client"
|
||||
return self.adapter.module
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.pip_packages
|
||||
return []
|
||||
return self.adapter.pip_packages
|
||||
|
||||
@property
|
||||
def provider_data_validator(self) -> Optional[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.provider_data_validator
|
||||
return None
|
||||
return self.adapter.provider_data_validator
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
|
||||
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
) -> RemoteProviderSpec:
|
||||
config_class = (
|
||||
adapter.config_class
|
||||
if adapter and adapter.config_class
|
||||
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
||||
)
|
||||
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
|
||||
|
||||
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
|
||||
api=api,
|
||||
provider_type=f"remote::{adapter.adapter_type}",
|
||||
config_class=adapter.config_class,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
|
|
@@ -189,6 +189,7 @@ async def inference_stack(request, inference_model):
|
|||
models=[
|
||||
ModelInput(
|
||||
model_id=inference_model,
|
||||
provider_id=inference_fixture.providers[0].provider_id,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue