Add an introspection "Api.inspect" API

This commit is contained in:
Ashwin Bharambe 2024-10-02 15:13:24 -07:00
parent 01d93be948
commit 8d049000e3
14 changed files with 619 additions and 174 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@ -24,6 +24,9 @@ class Api(Enum):
shields = "shields"
memory_banks = "memory_banks"
# built-in API
inspect = "inspect"
@json_schema_type
class ProviderSpec(BaseModel):
@ -55,68 +58,6 @@ class RoutableProvider(Protocol):
async def validate_routing_keys(self, keys: List[str]) -> None: ...
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
RoutingKey = Union[str, List[str]]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
# Example: /inference, /safety
@json_schema_type
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
routing_table_api: Api
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
@ -179,10 +120,6 @@ class RemoteProviderConfig(BaseModel):
return f"http://{self.host}:{self.port}"
def remote_provider_type(adapter_type: str) -> str:
return f"remote::{adapter_type}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
@ -226,7 +163,7 @@ def remote_provider_spec(
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_type = remote_provider_type(adapter.adapter_type) if adapter else "remote"
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
return RemoteProviderSpec(
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter