Add an introspection "Api.inspect" API

This commit is contained in:
Ashwin Bharambe 2024-10-02 15:13:24 -07:00
parent 01d93be948
commit 8d049000e3
14 changed files with 619 additions and 174 deletions

View file

@ -46,6 +46,7 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
class LlamaStack(
@ -63,6 +64,7 @@ class LlamaStack(
Evaluations,
Models,
Shields,
Inspect,
):
pass

View file

@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
},
"servers": [
{
@ -1542,6 +1542,36 @@
]
}
},
"/health": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HealthInfo"
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/memory/insert": {
"post": {
"responses": {
@ -1665,6 +1695,75 @@
]
}
},
"/providers/list": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": {
"$ref": "#/components/schemas/ProviderInfo"
}
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/routes/list": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RouteInfo"
}
}
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/shields/list": {
"get": {
"responses": {
@ -5086,6 +5185,18 @@
"job_uuid"
]
},
"HealthInfo": {
"type": "object",
"properties": {
"status": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"status"
]
},
"InsertDocumentsRequest": {
"type": "object",
"properties": {
@ -5108,6 +5219,45 @@
"documents"
]
},
"ProviderInfo": {
"type": "object",
"properties": {
"provider_type": {
"type": "string"
},
"description": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"provider_type",
"description"
]
},
"RouteInfo": {
"type": "object",
"properties": {
"route": {
"type": "string"
},
"method": {
"type": "string"
},
"providers": {
"type": "array",
"items": {
"type": "string"
}
}
},
"additionalProperties": false,
"required": [
"route",
"method",
"providers"
]
},
"LogSeverity": {
"type": "string",
"enum": [
@ -6220,19 +6370,34 @@
],
"tags": [
{
"name": "Shields"
"name": "Datasets"
},
{
"name": "Inspect"
},
{
"name": "Memory"
},
{
"name": "BatchInference"
},
{
"name": "RewardScoring"
"name": "Agents"
},
{
"name": "Inference"
},
{
"name": "Shields"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "Agents"
"name": "Models"
},
{
"name": "RewardScoring"
},
{
"name": "MemoryBanks"
@ -6241,13 +6406,7 @@
"name": "Safety"
},
{
"name": "Models"
},
{
"name": "Inference"
},
{
"name": "Memory"
"name": "Evaluations"
},
{
"name": "Telemetry"
@ -6255,12 +6414,6 @@
{
"name": "PostTraining"
},
{
"name": "Datasets"
},
{
"name": "Evaluations"
},
{
"name": "BuiltinTool",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@ -6653,10 +6806,22 @@
"name": "PostTrainingJob",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/PostTrainingJob\" />"
},
{
"name": "HealthInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
},
{
"name": "InsertDocumentsRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertDocumentsRequest\" />"
},
{
"name": "ProviderInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ProviderInfo\" />"
},
{
"name": "RouteInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/RouteInfo\" />"
},
{
"name": "LogSeverity",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LogSeverity\" />"
@ -6787,6 +6952,7 @@
"Datasets",
"Evaluations",
"Inference",
"Inspect",
"Memory",
"MemoryBanks",
"Models",
@ -6857,6 +7023,7 @@
"FunctionCallToolDefinition",
"GetAgentsSessionRequest",
"GetDocumentsRequest",
"HealthInfo",
"ImageMedia",
"InferenceStep",
"InsertDocumentsRequest",
@ -6880,6 +7047,7 @@
"PostTrainingJobStatus",
"PostTrainingJobStatusResponse",
"PreferenceOptimizeRequest",
"ProviderInfo",
"QLoraFinetuningConfig",
"QueryDocumentsRequest",
"QueryDocumentsResponse",
@ -6888,6 +7056,7 @@
"RestAPIMethod",
"RewardScoreRequest",
"RewardScoringResponse",
"RouteInfo",
"RunShieldRequest",
"RunShieldResponse",
"SafetyViolation",

View file

@ -908,6 +908,14 @@ components:
required:
- document_ids
type: object
HealthInfo:
additionalProperties: false
properties:
status:
type: string
required:
- status
type: object
ImageMedia:
additionalProperties: false
properties:
@ -1543,6 +1551,17 @@ components:
- hyperparam_search_config
- logger_config
type: object
ProviderInfo:
additionalProperties: false
properties:
description:
type: string
provider_type:
type: string
required:
- provider_type
- description
type: object
QLoraFinetuningConfig:
additionalProperties: false
properties:
@ -1704,6 +1723,22 @@ components:
title: Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold.
type: object
RouteInfo:
additionalProperties: false
properties:
method:
type: string
providers:
items:
type: string
type: array
route:
type: string
required:
- route
- method
- providers
type: object
RunShieldRequest:
additionalProperties: false
properties:
@ -2569,7 +2604,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
\ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -3093,6 +3128,25 @@ paths:
description: OK
tags:
- Evaluations
/health:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/HealthInfo'
description: OK
tags:
- Inspect
/inference/chat_completion:
post:
parameters:
@ -3637,6 +3691,27 @@ paths:
description: OK
tags:
- PostTraining
/providers/list:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
additionalProperties:
$ref: '#/components/schemas/ProviderInfo'
type: object
description: OK
tags:
- Inspect
/reward_scoring/score:
post:
parameters:
@ -3662,6 +3737,29 @@ paths:
description: OK
tags:
- RewardScoring
/routes/list:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
additionalProperties:
items:
$ref: '#/components/schemas/RouteInfo'
type: array
type: object
description: OK
tags:
- Inspect
/safety/run_shield:
post:
parameters:
@ -3807,20 +3905,21 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: Shields
- name: Datasets
- name: Inspect
- name: Memory
- name: BatchInference
- name: RewardScoring
- name: SyntheticDataGeneration
- name: Agents
- name: Inference
- name: Shields
- name: SyntheticDataGeneration
- name: Models
- name: RewardScoring
- name: MemoryBanks
- name: Safety
- name: Models
- name: Inference
- name: Memory
- name: Evaluations
- name: Telemetry
- name: PostTraining
- name: Datasets
- name: Evaluations
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@ -4135,9 +4234,15 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob"
/>
name: PostTrainingJob
- description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
name: HealthInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest"
/>
name: InsertDocumentsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/ProviderInfo" />
name: ProviderInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
name: RouteInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
name: LogSeverity
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
@ -4236,6 +4341,7 @@ x-tagGroups:
- Datasets
- Evaluations
- Inference
- Inspect
- Memory
- MemoryBanks
- Models
@ -4303,6 +4409,7 @@ x-tagGroups:
- FunctionCallToolDefinition
- GetAgentsSessionRequest
- GetDocumentsRequest
- HealthInfo
- ImageMedia
- InferenceStep
- InsertDocumentsRequest
@ -4326,6 +4433,7 @@ x-tagGroups:
- PostTrainingJobStatus
- PostTrainingJobStatusResponse
- PreferenceOptimizeRequest
- ProviderInfo
- QLoraFinetuningConfig
- QueryDocumentsRequest
- QueryDocumentsResponse
@ -4334,6 +4442,7 @@ x-tagGroups:
- RestAPIMethod
- RewardScoreRequest
- RewardScoringResponse
- RouteInfo
- RunShieldRequest
- RunShieldResponse
- SafetyViolation

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inspect import * # noqa: F401 F403

View file

@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List
import fire
import httpx
from termcolor import cprint
from .inspect import * # noqa: F403
class InspectClient(Inspect):
    """HTTP client for the Inspect API served at ``base_url``.

    Every call opens a short-lived ``httpx.AsyncClient``, issues a GET request
    and converts the JSON payload into the Inspect data models.
    """

    def __init__(self, base_url: str):
        # Base URL of the server, e.g. "http://host:port" (no trailing slash,
        # since route paths such as "/health" are appended verbatim).
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def _get_json(self, path: str):
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises ``httpx.HTTPStatusError`` (via ``raise_for_status``) when the
        server answers with a 4xx/5xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}{path}",
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()

    async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
        """Return the providers reported by ``/providers/list``.

        The payload is parsed as a mapping whose values are lists of
        provider entries, so the result maps each key to a list of
        ``ProviderInfo``.
        """
        payload = await self._get_json("/providers/list")
        return {
            key: [ProviderInfo(**entry) for entry in entries]
            for key, entries in payload.items()
        }

    async def list_routes(self) -> Dict[str, List[RouteInfo]]:
        """Return the routes reported by ``/routes/list``, grouped by key."""
        payload = await self._get_json("/routes/list")
        return {
            key: [RouteInfo(**entry) for entry in entries]
            for key, entries in payload.items()
        }

    async def health(self) -> HealthInfo:
        """Return the server health reported by ``/health``.

        Returns ``None`` when the response body is JSON ``null``.
        """
        payload = await self._get_json("/health")
        if payload is None:
            return None
        return HealthInfo(**payload)
async def run_main(host: str, port: int):
    """Call every Inspect endpoint on ``host:port`` and print each result."""
    inspect_client = InspectClient(f"http://{host}:{port}")

    providers = await inspect_client.list_providers()
    cprint(f"list_providers response={providers}", "green")

    routes = await inspect_client.list_routes()
    cprint(f"list_routes response={routes}", "blue")

    health = await inspect_client.health()
    cprint(f"health response={health}", "yellow")
def main(host: str, port: int):
    """Synchronous entry point: run ``run_main`` to completion on a new event loop."""
    coroutine = run_main(host, port)
    asyncio.run(coroutine)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
# One provider entry as returned by the Inspect API's provider listing.
@json_schema_type
class ProviderInfo(BaseModel):
    # Provider type identifier.
    provider_type: str
    # Free-form description of the provider.
    description: str
# One HTTP route entry as returned by the Inspect API's route listing.
@json_schema_type
class RouteInfo(BaseModel):
    # Route path, e.g. "/health".
    route: str
    # HTTP method of the route, e.g. "GET".
    method: str
    # Provider names associated with the route.
    # NOTE(review): exact meaning is not established here — confirm with the implementation.
    providers: List[str]
# Health report returned by the Inspect API's "/health" endpoint.
@json_schema_type
class HealthInfo(BaseModel):
    # Overall status string for the whole stack.
    status: str
    # TODO: add a provider level status
class Inspect(Protocol):
    # Introspection API: lists providers and routes and reports health.

    # Maps a key to the LIST of providers under it. The declared type matches
    # what DistributionInspectImpl.list_providers builds and what
    # InspectClient.list_providers parses (one list of ProviderInfo per key).
    @webmethod(route="/providers/list", method="GET")
    async def list_providers(self) -> Dict[str, List[ProviderInfo]]: ...

    # Maps a key to the list of routes under it.
    @webmethod(route="/routes/list", method="GET")
    async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...

    # Reports the overall health of the stack.
    @webmethod(route="/health", method="GET")
    async def health(self) -> HealthInfo: ...

View file

@ -17,6 +17,53 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
RoutingKey = Union[str, List[str]]
# Configuration for a single provider: its type plus raw settings.
class GenericProviderConfig(BaseModel):
    # Provider type identifier.
    provider_type: str
    # Raw settings; instantiate_provider expands them as keyword arguments
    # into the class named by the provider spec's `config_class`.
    config: Dict[str, Any]
# A GenericProviderConfig that additionally carries a routing key.
class RoutableProviderConfig(GenericProviderConfig):
    # Either a single key or a list of keys (see RoutingKey).
    routing_key: RoutingKey
class PlaceholderProviderConfig(BaseModel):
    """Placeholder provider config for API whose provider are defined in routing_table"""

    # NOTE(review): presumably the provider names taken from the routing table
    # — confirm against the code that builds this config.
    providers: List[str]
# Example: /inference, /safety
# Spec for an API that is served through a router (provider_type "router").
class AutoRoutedProviderSpec(ProviderSpec):
    provider_type: str = "router"
    config_class: str = ""
    docker_image: Optional[str] = None
    # API whose routing table the router uses; instantiate_provider looks it
    # up in `deps` and passes it to the module's factory function.
    routing_table_api: Api
    # Module to import; instantiate_provider calls its `get_auto_router_impl`.
    module: str
    provider_data_validator: Optional[str] = Field(
        default=None,
    )

    # Deliberately fails: this property is not meant to be read on this spec.
    @property
    def pip_packages(self) -> List[str]:
        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
# Spec for an API backed by a routing table (provider_type "routing_table").
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
    provider_type: str = "routing_table"
    config_class: str = ""
    docker_image: Optional[str] = None
    # Specs of the underlying providers; instantiate_provider indexes them by
    # `provider_type` to build one implementation per routing-table entry.
    inner_specs: List[ProviderSpec]
    # Module to import; instantiate_provider calls its `get_routing_table_impl`.
    module: str
    pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class DistributionSpec(BaseModel):
description: Optional[str] = Field(

View file

@ -46,6 +46,8 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
for api in stack_apis():
if api in routing_table_apis:
continue
if api == Api.inspect:
continue
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
def is_passthrough(spec: ProviderSpec) -> bool:
    """Return True for a remote provider spec that has no adapter."""
    if not isinstance(spec, RemoteProviderSpec):
        return False
    return spec.adapter is None
class DistributionInspectImpl(Inspect):
    """Built-in Inspect implementation backed by the provider registry and endpoint table."""

    def __init__(self):
        pass

    async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
        """Return every registered provider, grouped by API value.

        The description is "Passthrough" for passthrough specs and an empty
        string for everything else.
        """

        def describe(spec) -> str:
            return "Passthrough" if is_passthrough(spec) else ""

        registry = get_provider_registry()
        return {
            api.value: [
                ProviderInfo(
                    provider_type=spec.provider_type,
                    description=describe(spec),
                )
                for spec in specs_by_type.values()
            ]
            for api, specs_by_type in registry.items()
        }

    async def list_routes(self) -> Dict[str, List[RouteInfo]]:
        """Return every known endpoint, grouped by API value.

        The ``providers`` list of each route is always empty here.
        """
        endpoint_table = get_all_api_endpoints()
        return {
            api.value: [
                RouteInfo(
                    route=endpoint.route,
                    method=endpoint.method,
                    providers=[],
                )
                for endpoint in api_endpoints
            ]
            for api, api_endpoints in endpoint_table.items()
        }

    async def health(self) -> HealthInfo:
        """Report a constant "OK" status."""
        return HealthInfo(status="OK")

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
from typing import Any, Dict, List, Set
@ -11,7 +12,8 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.distribution.inspect import DistributionInspectImpl
from llama_stack.distribution.utils.dynamic import instantiate_class_type
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
@ -57,7 +59,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if info.router_api.value not in apis_to_serve:
continue
print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
@ -104,6 +105,14 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
impls[api] = impl
impls[Api.inspect] = DistributionInspectImpl()
specs[Api.inspect] = InlineProviderSpec(
api=Api.inspect,
provider_type="__distribution_builtin__",
config_class="",
module="",
)
return impls, specs
@ -127,3 +136,60 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
    provider_spec: ProviderSpec,
    deps: Dict[str, Any],
    provider_config: Union[GenericProviderConfig, RoutingTable],
):
    """Import the provider's module and build its implementation.

    The factory function called on the module depends on the spec type:
    ``get_adapter_impl`` / ``get_client_impl`` for remote specs (with or
    without an adapter), ``get_auto_router_impl`` for auto-routed specs,
    ``get_routing_table_impl`` for routing-table specs, and
    ``get_provider_impl`` for everything else. The returned implementation is
    tagged with ``__provider_spec__`` and ``__provider_config__``.
    """
    module = importlib.import_module(provider_spec.module)

    def build_generic_config():
        # Shared by the remote and plain-provider branches: validate the
        # config kind, then instantiate the spec's config class from it.
        assert isinstance(provider_config, GenericProviderConfig)
        config_type = instantiate_class_type(provider_spec.config_class)
        return config_type(**provider_config.config)

    if isinstance(provider_spec, RemoteProviderSpec):
        method = "get_adapter_impl" if provider_spec.adapter else "get_client_impl"
        config = build_generic_config()
        args = [config, deps]
    elif isinstance(provider_spec, AutoRoutedProviderSpec):
        method = "get_auto_router_impl"
        config = None
        routing_table_impl = deps[provider_spec.routing_table_api]
        args = [provider_spec.api, routing_table_impl, deps]
    elif isinstance(provider_spec, RoutingTableProviderSpec):
        method = "get_routing_table_impl"
        assert isinstance(provider_config, List)
        routing_table = provider_config

        specs_by_type = {
            inner.provider_type: inner for inner in provider_spec.inner_specs
        }
        inner_impls = []
        # Entries are instantiated one at a time, in routing-table order.
        for entry in routing_table:
            inner_impl = await instantiate_provider(
                specs_by_type[entry.provider_type],
                deps,
                entry,
            )
            inner_impls.append((entry.routing_key, inner_impl))

        config = None
        args = [provider_spec.api, inner_impls, routing_table, deps]
    else:
        method = "get_provider_impl"
        config = build_generic_config()
        args = [config, deps]

    factory = getattr(module, method)
    impl = await factory(*args)
    impl.__provider_spec__ = provider_spec
    impl.__provider_config__ = config
    return impl

View file

@ -11,12 +11,14 @@ from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.providers.datatypes import Api
@ -38,6 +40,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
Api.inspect: Inspect,
}
for api, protocol in protocols.items():

View file

@ -15,7 +15,6 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from http import HTTPStatus
from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
@ -26,7 +25,6 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
@ -287,15 +285,6 @@ def main(
app = FastAPI()
# Health check is added to enable deploying the docker container image on Kubernetes which require
# a health check that can return 200 for readiness and liveness check
class HealthCheck(BaseModel):
status: str = "OK"
@app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
async def healthcheck():
return HealthCheck(status="OK")
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
@ -307,6 +296,7 @@ def main(
else:
apis_to_serve = set(impls.keys())
apis_to_serve.add(Api.inspect)
for api_str in apis_to_serve:
api = Api(api_str)
@ -340,14 +330,11 @@ def main(
)
)
for route in app.routes:
if isinstance(route, APIRoute):
cprint(
f"Serving {next(iter(route.methods))} {route.path}",
"white",
attrs=["bold"],
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
for endpoint in endpoints:
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)

View file

@ -5,69 +5,9 @@
# the root directory of this source tree.
import importlib
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
def instantiate_class_type(fully_qualified_name):
    """Resolve a dotted ``package.module.Name`` string to the named attribute.

    The text before the last dot is imported as a module and the text after
    it is looked up on that module.
    """
    parts = fully_qualified_name.rsplit(".", 1)
    module_name, class_name = parts
    return getattr(importlib.import_module(module_name), class_name)
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable],
):
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_type],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@ -24,6 +24,9 @@ class Api(Enum):
shields = "shields"
memory_banks = "memory_banks"
# built-in API
inspect = "inspect"
@json_schema_type
class ProviderSpec(BaseModel):
@ -55,68 +58,6 @@ class RoutableProvider(Protocol):
async def validate_routing_keys(self, keys: List[str]) -> None: ...
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
RoutingKey = Union[str, List[str]]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
# Example: /inference, /safety
@json_schema_type
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
routing_table_api: Api
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
@ -179,10 +120,6 @@ class RemoteProviderConfig(BaseModel):
return f"http://{self.host}:{self.port}"
def remote_provider_type(adapter_type: str) -> str:
return f"remote::{adapter_type}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
@ -226,7 +163,7 @@ def remote_provider_spec(
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_type = remote_provider_type(adapter.adapter_type) if adapter else "remote"
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
return RemoteProviderSpec(
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter