mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
Kill the notion of a "remote" / "passthrough" provider
This commit is contained in:
parent
59a65e34d3
commit
743da9690b
6 changed files with 95 additions and 87 deletions
|
@@ -9,7 +9,7 @@ from typing import Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def stack_apis() -> List[Api]:
|
def stack_apis() -> List[Api]:
|
||||||
|
@@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||||
for api in providable_apis():
|
for api in providable_apis():
|
||||||
name = api.name.lower()
|
name = api.name.lower()
|
||||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||||
ret[api] = {
|
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||||
"remote": remote_provider_spec(api),
|
|
||||||
**{a.provider_type: a for a in module.available_providers()},
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@@ -273,17 +273,8 @@ async def instantiate_provider(
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
config = config_type(**provider.config)
|
config = config_type(**provider.config)
|
||||||
|
|
||||||
if provider_spec.adapter:
|
|
||||||
method = "get_adapter_impl"
|
method = "get_adapter_impl"
|
||||||
args = [config, deps]
|
args = [config, deps]
|
||||||
else:
|
|
||||||
method = "get_client_impl"
|
|
||||||
protocol = protocols[provider_spec.api]
|
|
||||||
if provider_spec.api in additional_protocols:
|
|
||||||
_, additional_protocol = additional_protocols[provider_spec.api]
|
|
||||||
else:
|
|
||||||
additional_protocol = None
|
|
||||||
args = [protocol, additional_protocol, config, deps]
|
|
||||||
|
|
||||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||||
method = "get_auto_router_impl"
|
method = "get_auto_router_impl"
|
||||||
|
|
|
@@ -33,27 +33,82 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
|
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
|
|
||||||
if obj.provider_id == "remote":
|
is_remote = obj.provider_id == "remote"
|
||||||
# TODO: this is broken right now because we use the generic
|
if is_remote:
|
||||||
# { identifier, provider_id, provider_resource_id } tuple here
|
# TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
|
||||||
# but the APIs expect things like ModelInput, ShieldInput, etc.
|
# and (b) MemoryBankInput is missing BankParams
|
||||||
|
if isinstance(obj, Model):
|
||||||
# if this is just a passthrough, we want to let the remote
|
obj = ModelInput(
|
||||||
# end actually do the registration with the correct provider
|
model_id=obj.identifier,
|
||||||
obj = obj.model_copy(deep=True)
|
metadata=obj.metadata,
|
||||||
obj.provider_id = ""
|
provider_model_id=obj.provider_resource_id,
|
||||||
|
)
|
||||||
|
elif isinstance(obj, Shield):
|
||||||
|
obj = ShieldInput(
|
||||||
|
shield_id=obj.identifier,
|
||||||
|
params=obj.params,
|
||||||
|
provider_shield_id=obj.provider_resource_id,
|
||||||
|
)
|
||||||
|
elif isinstance(obj, MemoryBank):
|
||||||
|
# need to calculate params here
|
||||||
|
obj = MemoryBankInput(
|
||||||
|
memory_bank_id=obj.identifier,
|
||||||
|
provider_memory_bank_id=obj.provider_resource_id,
|
||||||
|
)
|
||||||
|
elif isinstance(obj, ScoringFn):
|
||||||
|
obj = ScoringFnInput(
|
||||||
|
scoring_fn_id=obj.identifier,
|
||||||
|
provider_scoring_fn_id=obj.provider_resource_id,
|
||||||
|
description=obj.description,
|
||||||
|
metadata=obj.metadata,
|
||||||
|
return_type=obj.return_type,
|
||||||
|
params=obj.params,
|
||||||
|
)
|
||||||
|
elif isinstance(obj, EvalTask):
|
||||||
|
obj = EvalTaskInput(
|
||||||
|
eval_task_id=obj.identifier,
|
||||||
|
provider_eval_task_id=obj.provider_resource_id,
|
||||||
|
dataset_id=obj.dataset_id,
|
||||||
|
scoring_function_id=obj.scoring_functions,
|
||||||
|
metadata=obj.metadata,
|
||||||
|
)
|
||||||
|
elif isinstance(obj, Dataset):
|
||||||
|
obj = DatasetInput(
|
||||||
|
dataset_id=obj.identifier,
|
||||||
|
provider_dataset_id=obj.provider_resource_id,
|
||||||
|
schema=obj.schema,
|
||||||
|
url=obj.url,
|
||||||
|
metadata=obj.metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown object type {type(obj)}")
|
||||||
|
|
||||||
if api == Api.inference:
|
if api == Api.inference:
|
||||||
return await p.register_model(obj)
|
return await p.register_model(obj)
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
|
if is_remote:
|
||||||
|
await p.register_shield(**obj.model_dump())
|
||||||
|
else:
|
||||||
await p.register_shield(obj)
|
await p.register_shield(obj)
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
|
if is_remote:
|
||||||
|
await p.register_memory_bank(**obj.model_dump())
|
||||||
|
else:
|
||||||
await p.register_memory_bank(obj)
|
await p.register_memory_bank(obj)
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
|
if is_remote:
|
||||||
|
await p.register_dataset(**obj.model_dump())
|
||||||
|
else:
|
||||||
await p.register_dataset(obj)
|
await p.register_dataset(obj)
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
|
if is_remote:
|
||||||
|
await p.register_scoring_function(**obj.model_dump())
|
||||||
|
else:
|
||||||
await p.register_scoring_function(obj)
|
await p.register_scoring_function(obj)
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
|
if is_remote:
|
||||||
|
await p.register_eval_task(**obj.model_dump())
|
||||||
|
else:
|
||||||
await p.register_eval_task(obj)
|
await p.register_eval_task(obj)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||||
|
|
|
@@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
|
||||||
await impl.shutdown()
|
await impl.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_passthrough(
|
|
||||||
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
|
|
||||||
):
|
|
||||||
async def endpoint(request: Request):
|
|
||||||
return await passthrough(request, downstream_url, downstream_headers)
|
|
||||||
|
|
||||||
return endpoint
|
|
||||||
|
|
||||||
|
|
||||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||||
return kwargs.get("stream", False)
|
return kwargs.get("stream", False)
|
||||||
|
@@ -305,19 +296,10 @@ def main(
|
||||||
endpoints = all_endpoints[api]
|
endpoints = all_endpoints[api]
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
|
|
||||||
if is_passthrough(impl.__provider_spec__):
|
|
||||||
for endpoint in endpoints:
|
|
||||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
|
||||||
getattr(app, endpoint.method)(endpoint.route)(
|
|
||||||
create_dynamic_passthrough(url)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
if not hasattr(impl, endpoint.name):
|
if not hasattr(impl, endpoint.name):
|
||||||
# ideally this should be a typing violation already
|
# ideally this should be a typing violation already
|
||||||
raise ValueError(
|
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
||||||
f"Could not find method {endpoint.name} on {impl}!!"
|
|
||||||
)
|
|
||||||
|
|
||||||
impl_method = getattr(impl, endpoint.name)
|
impl_method = getattr(impl, endpoint.name)
|
||||||
|
|
||||||
|
|
|
@@ -99,6 +99,7 @@ class RoutingTable(Protocol):
|
||||||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: this can now be inlined into RemoteProviderSpec
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AdapterSpec(BaseModel):
|
class AdapterSpec(BaseModel):
|
||||||
adapter_type: str = Field(
|
adapter_type: str = Field(
|
||||||
|
@@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RemoteProviderSpec(ProviderSpec):
|
class RemoteProviderSpec(ProviderSpec):
|
||||||
adapter: Optional[AdapterSpec] = Field(
|
adapter: AdapterSpec = Field(
|
||||||
default=None,
|
|
||||||
description="""
|
description="""
|
||||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
API responses, specify the adapter here.
|
||||||
as being "Llama Stack compatible"
|
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -186,38 +185,21 @@ as being "Llama Stack compatible"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def module(self) -> str:
|
def module(self) -> str:
|
||||||
if self.adapter:
|
|
||||||
return self.adapter.module
|
return self.adapter.module
|
||||||
return "llama_stack.distribution.client"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pip_packages(self) -> List[str]:
|
def pip_packages(self) -> List[str]:
|
||||||
if self.adapter:
|
|
||||||
return self.adapter.pip_packages
|
return self.adapter.pip_packages
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_data_validator(self) -> Optional[str]:
|
def provider_data_validator(self) -> Optional[str]:
|
||||||
if self.adapter:
|
|
||||||
return self.adapter.provider_data_validator
|
return self.adapter.provider_data_validator
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
|
||||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
|
||||||
|
|
||||||
|
|
||||||
# Can avoid this by using Pydantic computed_field
|
|
||||||
def remote_provider_spec(
|
|
||||||
api: Api, adapter: Optional[AdapterSpec] = None
|
|
||||||
) -> RemoteProviderSpec:
|
|
||||||
config_class = (
|
|
||||||
adapter.config_class
|
|
||||||
if adapter and adapter.config_class
|
|
||||||
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
|
||||||
)
|
|
||||||
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
|
|
||||||
|
|
||||||
return RemoteProviderSpec(
|
return RemoteProviderSpec(
|
||||||
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
|
api=api,
|
||||||
|
provider_type=f"remote::{adapter.adapter_type}",
|
||||||
|
config_class=adapter.config_class,
|
||||||
|
adapter=adapter,
|
||||||
)
|
)
|
||||||
|
|
|
@@ -189,6 +189,7 @@ async def inference_stack(request, inference_model):
|
||||||
models=[
|
models=[
|
||||||
ModelInput(
|
ModelInput(
|
||||||
model_id=inference_model,
|
model_id=inference_model,
|
||||||
|
provider_id=inference_fixture.providers[0].provider_id,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue