mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 23:28:53 +00:00
Kill the notion of a "remote" / "passthrough" provider
This commit is contained in:
parent
59a65e34d3
commit
743da9690b
6 changed files with 95 additions and 87 deletions
|
@ -99,6 +99,7 @@ class RoutingTable(Protocol):
|
|||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
# TODO: this can now be inlined into RemoteProviderSpec
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_type: str = Field(
|
||||
|
@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
adapter: Optional[AdapterSpec] = Field(
|
||||
default=None,
|
||||
adapter: AdapterSpec = Field(
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
API responses, specify the adapter here.
|
||||
""",
|
||||
)
|
||||
|
||||
|
@ -186,38 +185,21 @@ as being "Llama Stack compatible"
|
|||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
if self.adapter:
|
||||
return self.adapter.module
|
||||
return "llama_stack.distribution.client"
|
||||
return self.adapter.module
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.pip_packages
|
||||
return []
|
||||
return self.adapter.pip_packages
|
||||
|
||||
@property
|
||||
def provider_data_validator(self) -> Optional[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.provider_data_validator
|
||||
return None
|
||||
return self.adapter.provider_data_validator
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
|
||||
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
) -> RemoteProviderSpec:
|
||||
config_class = (
|
||||
adapter.config_class
|
||||
if adapter and adapter.config_class
|
||||
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
||||
)
|
||||
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
|
||||
|
||||
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
|
||||
api=api,
|
||||
provider_type=f"remote::{adapter.adapter_type}",
|
||||
config_class=adapter.config_class,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
|
|
@ -189,6 +189,7 @@ async def inference_stack(request, inference_model):
|
|||
models=[
|
||||
ModelInput(
|
||||
model_id=inference_model,
|
||||
provider_id=inference_fixture.providers[0].provider_id,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue