Add an introspection "Api.inspect" API

This commit is contained in:
Ashwin Bharambe 2024-10-02 15:13:24 -07:00
parent 01d93be948
commit 8d049000e3
14 changed files with 619 additions and 174 deletions

View file

@ -46,6 +46,7 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
class LlamaStack( class LlamaStack(
@ -63,6 +64,7 @@ class LlamaStack(
Evaluations, Evaluations,
Models, Models,
Shields, Shields,
Inspect,
): ):
pass pass

View file

@ -21,7 +21,7 @@
"info": { "info": {
"title": "[DRAFT] Llama Stack Specification", "title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1", "version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308" "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
}, },
"servers": [ "servers": [
{ {
@ -1542,6 +1542,36 @@
] ]
} }
}, },
"/health": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HealthInfo"
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/memory/insert": { "/memory/insert": {
"post": { "post": {
"responses": { "responses": {
@ -1665,6 +1695,75 @@
] ]
} }
}, },
"/providers/list": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": {
"$ref": "#/components/schemas/ProviderInfo"
}
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/routes/list": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RouteInfo"
}
}
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/shields/list": { "/shields/list": {
"get": { "get": {
"responses": { "responses": {
@ -5086,6 +5185,18 @@
"job_uuid" "job_uuid"
] ]
}, },
"HealthInfo": {
"type": "object",
"properties": {
"status": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"status"
]
},
"InsertDocumentsRequest": { "InsertDocumentsRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -5108,6 +5219,45 @@
"documents" "documents"
] ]
}, },
"ProviderInfo": {
"type": "object",
"properties": {
"provider_type": {
"type": "string"
},
"description": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"provider_type",
"description"
]
},
"RouteInfo": {
"type": "object",
"properties": {
"route": {
"type": "string"
},
"method": {
"type": "string"
},
"providers": {
"type": "array",
"items": {
"type": "string"
}
}
},
"additionalProperties": false,
"required": [
"route",
"method",
"providers"
]
},
"LogSeverity": { "LogSeverity": {
"type": "string", "type": "string",
"enum": [ "enum": [
@ -6220,19 +6370,34 @@
], ],
"tags": [ "tags": [
{ {
"name": "Shields" "name": "Datasets"
},
{
"name": "Inspect"
},
{
"name": "Memory"
}, },
{ {
"name": "BatchInference" "name": "BatchInference"
}, },
{ {
"name": "RewardScoring" "name": "Agents"
},
{
"name": "Inference"
},
{
"name": "Shields"
}, },
{ {
"name": "SyntheticDataGeneration" "name": "SyntheticDataGeneration"
}, },
{ {
"name": "Agents" "name": "Models"
},
{
"name": "RewardScoring"
}, },
{ {
"name": "MemoryBanks" "name": "MemoryBanks"
@ -6241,13 +6406,7 @@
"name": "Safety" "name": "Safety"
}, },
{ {
"name": "Models" "name": "Evaluations"
},
{
"name": "Inference"
},
{
"name": "Memory"
}, },
{ {
"name": "Telemetry" "name": "Telemetry"
@ -6255,12 +6414,6 @@
{ {
"name": "PostTraining" "name": "PostTraining"
}, },
{
"name": "Datasets"
},
{
"name": "Evaluations"
},
{ {
"name": "BuiltinTool", "name": "BuiltinTool",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@ -6653,10 +6806,22 @@
"name": "PostTrainingJob", "name": "PostTrainingJob",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/PostTrainingJob\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/PostTrainingJob\" />"
}, },
{
"name": "HealthInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
},
{ {
"name": "InsertDocumentsRequest", "name": "InsertDocumentsRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertDocumentsRequest\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertDocumentsRequest\" />"
}, },
{
"name": "ProviderInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ProviderInfo\" />"
},
{
"name": "RouteInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/RouteInfo\" />"
},
{ {
"name": "LogSeverity", "name": "LogSeverity",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LogSeverity\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/LogSeverity\" />"
@ -6787,6 +6952,7 @@
"Datasets", "Datasets",
"Evaluations", "Evaluations",
"Inference", "Inference",
"Inspect",
"Memory", "Memory",
"MemoryBanks", "MemoryBanks",
"Models", "Models",
@ -6857,6 +7023,7 @@
"FunctionCallToolDefinition", "FunctionCallToolDefinition",
"GetAgentsSessionRequest", "GetAgentsSessionRequest",
"GetDocumentsRequest", "GetDocumentsRequest",
"HealthInfo",
"ImageMedia", "ImageMedia",
"InferenceStep", "InferenceStep",
"InsertDocumentsRequest", "InsertDocumentsRequest",
@ -6880,6 +7047,7 @@
"PostTrainingJobStatus", "PostTrainingJobStatus",
"PostTrainingJobStatusResponse", "PostTrainingJobStatusResponse",
"PreferenceOptimizeRequest", "PreferenceOptimizeRequest",
"ProviderInfo",
"QLoraFinetuningConfig", "QLoraFinetuningConfig",
"QueryDocumentsRequest", "QueryDocumentsRequest",
"QueryDocumentsResponse", "QueryDocumentsResponse",
@ -6888,6 +7056,7 @@
"RestAPIMethod", "RestAPIMethod",
"RewardScoreRequest", "RewardScoreRequest",
"RewardScoringResponse", "RewardScoringResponse",
"RouteInfo",
"RunShieldRequest", "RunShieldRequest",
"RunShieldResponse", "RunShieldResponse",
"SafetyViolation", "SafetyViolation",

View file

@ -908,6 +908,14 @@ components:
required: required:
- document_ids - document_ids
type: object type: object
HealthInfo:
additionalProperties: false
properties:
status:
type: string
required:
- status
type: object
ImageMedia: ImageMedia:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1543,6 +1551,17 @@ components:
- hyperparam_search_config - hyperparam_search_config
- logger_config - logger_config
type: object type: object
ProviderInfo:
additionalProperties: false
properties:
description:
type: string
provider_type:
type: string
required:
- provider_type
- description
type: object
QLoraFinetuningConfig: QLoraFinetuningConfig:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1704,6 +1723,22 @@ components:
title: Response from the reward scoring. Batch of (prompt, response, score) title: Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold. tuples that pass the threshold.
type: object type: object
RouteInfo:
additionalProperties: false
properties:
method:
type: string
providers:
items:
type: string
type: array
route:
type: string
required:
- route
- method
- providers
type: object
RunShieldRequest: RunShieldRequest:
additionalProperties: false additionalProperties: false
properties: properties:
@ -2569,7 +2604,7 @@ info:
description: "This is the specification of the llama stack that provides\n \ description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\ \ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\ \ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308" \ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
title: '[DRAFT] Llama Stack Specification' title: '[DRAFT] Llama Stack Specification'
version: 0.0.1 version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -3093,6 +3128,25 @@ paths:
description: OK description: OK
tags: tags:
- Evaluations - Evaluations
/health:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/HealthInfo'
description: OK
tags:
- Inspect
/inference/chat_completion: /inference/chat_completion:
post: post:
parameters: parameters:
@ -3637,6 +3691,27 @@ paths:
description: OK description: OK
tags: tags:
- PostTraining - PostTraining
/providers/list:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
additionalProperties:
$ref: '#/components/schemas/ProviderInfo'
type: object
description: OK
tags:
- Inspect
/reward_scoring/score: /reward_scoring/score:
post: post:
parameters: parameters:
@ -3662,6 +3737,29 @@ paths:
description: OK description: OK
tags: tags:
- RewardScoring - RewardScoring
/routes/list:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
additionalProperties:
items:
$ref: '#/components/schemas/RouteInfo'
type: array
type: object
description: OK
tags:
- Inspect
/safety/run_shield: /safety/run_shield:
post: post:
parameters: parameters:
@ -3807,20 +3905,21 @@ security:
servers: servers:
- url: http://any-hosted-llama-stack.com - url: http://any-hosted-llama-stack.com
tags: tags:
- name: Shields - name: Datasets
- name: Inspect
- name: Memory
- name: BatchInference - name: BatchInference
- name: RewardScoring
- name: SyntheticDataGeneration
- name: Agents - name: Agents
- name: Inference
- name: Shields
- name: SyntheticDataGeneration
- name: Models
- name: RewardScoring
- name: MemoryBanks - name: MemoryBanks
- name: Safety - name: Safety
- name: Models - name: Evaluations
- name: Inference
- name: Memory
- name: Telemetry - name: Telemetry
- name: PostTraining - name: PostTraining
- name: Datasets
- name: Evaluations
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" /> - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage" - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@ -4135,9 +4234,15 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob" - description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob"
/> />
name: PostTrainingJob name: PostTrainingJob
- description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
name: HealthInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest"
/> />
name: InsertDocumentsRequest name: InsertDocumentsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/ProviderInfo" />
name: ProviderInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
name: RouteInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" /> - description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
name: LogSeverity name: LogSeverity
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" /> - description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
@ -4236,6 +4341,7 @@ x-tagGroups:
- Datasets - Datasets
- Evaluations - Evaluations
- Inference - Inference
- Inspect
- Memory - Memory
- MemoryBanks - MemoryBanks
- Models - Models
@ -4303,6 +4409,7 @@ x-tagGroups:
- FunctionCallToolDefinition - FunctionCallToolDefinition
- GetAgentsSessionRequest - GetAgentsSessionRequest
- GetDocumentsRequest - GetDocumentsRequest
- HealthInfo
- ImageMedia - ImageMedia
- InferenceStep - InferenceStep
- InsertDocumentsRequest - InsertDocumentsRequest
@ -4326,6 +4433,7 @@ x-tagGroups:
- PostTrainingJobStatus - PostTrainingJobStatus
- PostTrainingJobStatusResponse - PostTrainingJobStatusResponse
- PreferenceOptimizeRequest - PreferenceOptimizeRequest
- ProviderInfo
- QLoraFinetuningConfig - QLoraFinetuningConfig
- QueryDocumentsRequest - QueryDocumentsRequest
- QueryDocumentsResponse - QueryDocumentsResponse
@ -4334,6 +4442,7 @@ x-tagGroups:
- RestAPIMethod - RestAPIMethod
- RewardScoreRequest - RewardScoreRequest
- RewardScoringResponse - RewardScoringResponse
- RouteInfo
- RunShieldRequest - RunShieldRequest
- RunShieldResponse - RunShieldResponse
- SafetyViolation - SafetyViolation

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inspect import * # noqa: F401 F403

View file

@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List
import fire
import httpx
from termcolor import cprint
from .inspect import * # noqa: F403
class InspectClient(Inspect):
    """HTTP client for the Inspect API of a running Llama Stack server.

    Every call opens a short-lived ``httpx.AsyncClient``, issues a GET
    request against ``base_url`` and converts the JSON payload into the
    Inspect API models.
    """

    def __init__(self, base_url: str):
        # Server root without a trailing slash, e.g. "http://localhost:5000".
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def _get_json(self, path: str):
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises ``httpx.HTTPStatusError`` when the server answers with a
        4xx/5xx status code.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}{path}",
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()

    async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
        """Return, for each API name, the list of providers the server reports."""
        # The server sends a list of provider objects per API, so the result
        # is a mapping to lists (the annotation previously claimed a single
        # ProviderInfo per key).
        payload = await self._get_json("/providers/list")
        return {
            api: [ProviderInfo(**item) for item in items]
            for api, items in payload.items()
        }

    async def list_routes(self) -> Dict[str, List[RouteInfo]]:
        """Return, for each API name, the list of routes the server reports."""
        payload = await self._get_json("/routes/list")
        return {
            api: [RouteInfo(**item) for item in items]
            for api, items in payload.items()
        }

    async def health(self) -> HealthInfo:
        """Return the server's health status."""
        payload = await self._get_json("/health")
        # Kept from the original behavior: a JSON `null` body yields None
        # instead of raising.
        if payload is None:
            return None
        return HealthInfo(**payload)
async def run_main(host: str, port: int):
    """Call every Inspect endpoint on ``host:port`` and print each result in color."""
    base_url = f"http://{host}:{port}"
    client = InspectClient(base_url)

    # Providers, then routes, then health: each result is printed as soon as
    # it arrives.
    response = await client.list_providers()
    cprint(f"list_providers response={response}", "green")

    response = await client.list_routes()
    cprint(f"list_routes response={response}", "blue")

    response = await client.health()
    cprint(f"health response={response}", "yellow")
def main(host: str, port: int):
    """Synchronous wrapper that runs ``run_main`` to completion via ``asyncio.run``."""
    asyncio.run(run_main(host, port))
# Script entry point: hand `main` to python-fire so its parameters (host, port)
# are taken from the command line.
if __name__ == "__main__":
    fire.Fire(main)

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
# Summary of one provider as returned by the Inspect API's provider listing.
@json_schema_type
class ProviderInfo(BaseModel):
    # Provider type identifier, taken from the provider spec's `provider_type`.
    provider_type: str
    # Free-form description. The built-in implementation sets it to
    # "Passthrough" for remote providers without an adapter and to "" otherwise.
    description: str
# Describes a single HTTP route served for an API.
@json_schema_type
class RouteInfo(BaseModel):
    # URL path of the route, e.g. "/health".
    route: str
    # HTTP method of the route, e.g. "GET".
    method: str
    # Providers associated with the route.
    # NOTE(review): the built-in implementation currently always passes an
    # empty list here.
    providers: List[str]
# Response payload of the health endpoint.
@json_schema_type
class HealthInfo(BaseModel):
    # Overall status string; the built-in implementation reports "OK".
    status: str
    # TODO: add a provider level status
# Built-in introspection API: lets a client discover which providers and routes
# a stack knows about and check that the server is responding.
class Inspect(Protocol):
    # Maps each API name to the list of providers available for it. The value
    # type is a list: the distribution implementation returns one ProviderInfo
    # per registered provider, and the HTTP client parses a list per key.
    @webmethod(route="/providers/list", method="GET")
    async def list_providers(self) -> Dict[str, List[ProviderInfo]]: ...

    # Maps each API name to the list of routes it serves.
    @webmethod(route="/routes/list", method="GET")
    async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...

    # Basic health probe.
    @webmethod(route="/health", method="GET")
    async def health(self) -> HealthInfo: ...

View file

@ -17,6 +17,53 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1" LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
# A routing key is either a single string or a list of strings.
RoutingKey = Union[str, List[str]]
# Configuration entry that selects a provider by type and carries its raw settings.
class GenericProviderConfig(BaseModel):
    # Provider type identifier.
    provider_type: str
    # Raw settings; expanded as keyword arguments into the provider's config
    # class when the provider is instantiated.
    config: Dict[str, Any]
# Provider configuration that additionally carries a routing key
# (a single string or a list of strings, see RoutingKey).
class RoutableProviderConfig(GenericProviderConfig):
    routing_key: RoutingKey
class PlaceholderProviderConfig(BaseModel):
    """Placeholder provider config for APIs whose providers are defined in routing_table"""

    # NOTE(review): presumably provider type names — confirm against how the
    # routing table is consumed.
    providers: List[str]
# Example: /inference, /safety
# Provider spec for an automatically routed API; its provider type is fixed to
# "router" and it has no config class of its own.
class AutoRoutedProviderSpec(ProviderSpec):
    provider_type: str = "router"
    config_class: str = ""

    docker_image: Optional[str] = None
    # API whose implementation is looked up in `deps` and handed to the router
    # when it is instantiated.
    routing_table_api: Api
    # Name of the module to import; instantiate_provider calls its
    # `get_auto_router_impl` function for this kind of spec.
    module: str
    provider_data_validator: Optional[str] = Field(
        default=None,
    )

    # Deliberately unusable: asking a router spec for pip packages is treated
    # as a programming error.
    @property
    def pip_packages(self) -> List[str]:
        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
# Provider spec for a routing-table API; its provider type is fixed to
# "routing_table" and it wraps the specs of the providers it routes to.
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
    provider_type: str = "routing_table"
    config_class: str = ""
    docker_image: Optional[str] = None

    # Specs of the wrapped providers; instantiate_provider indexes them by
    # `provider_type` to resolve each routing-table entry.
    inner_specs: List[ProviderSpec]
    # Name of the module to import; instantiate_provider calls its
    # `get_routing_table_impl` function for this kind of spec.
    module: str
    pip_packages: List[str] = Field(default_factory=list)
@json_schema_type @json_schema_type
class DistributionSpec(BaseModel): class DistributionSpec(BaseModel):
description: Optional[str] = Field( description: Optional[str] = Field(

View file

@ -46,6 +46,8 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
for api in stack_apis(): for api in stack_apis():
if api in routing_table_apis: if api in routing_table_apis:
continue continue
if api == Api.inspect:
continue
name = api.name.lower() name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}") module = importlib.import_module(f"llama_stack.providers.registry.{name}")

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
def is_passthrough(spec: ProviderSpec) -> bool:
    """Return True when ``spec`` is a remote provider spec that has no adapter."""
    if not isinstance(spec, RemoteProviderSpec):
        return False
    return spec.adapter is None
class DistributionInspectImpl(Inspect):
    """Inspect API implementation backed by the provider registry and endpoint table."""

    def __init__(self):
        pass

    async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
        """Return the registered providers, grouped by API name."""

        def describe(spec) -> ProviderInfo:
            # Adapter-less remote providers are labelled "Passthrough"; every
            # other provider gets an empty description.
            label = "Passthrough" if is_passthrough(spec) else ""
            return ProviderInfo(provider_type=spec.provider_type, description=label)

        registry = get_provider_registry()
        return {
            api.value: [describe(spec) for spec in specs.values()]
            for api, specs in registry.items()
        }

    async def list_routes(self) -> Dict[str, List[RouteInfo]]:
        """Return the known routes, grouped by API name."""
        endpoint_table = get_all_api_endpoints()
        return {
            api.value: [
                # The providers list is always empty here.
                RouteInfo(route=endpoint.route, method=endpoint.method, providers=[])
                for endpoint in endpoints
            ]
            for api, endpoints in endpoint_table.items()
        }

    async def health(self) -> HealthInfo:
        """Report a constant "OK" status."""
        return HealthInfo(status="OK")

View file

@ -3,6 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import importlib
from typing import Any, Dict, List, Set from typing import Any, Dict, List, Set
@ -11,7 +12,8 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry, get_provider_registry,
) )
from llama_stack.distribution.utils.dynamic import instantiate_provider from llama_stack.distribution.inspect import DistributionInspectImpl
from llama_stack.distribution.utils.dynamic import instantiate_class_type
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]: async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
@ -57,7 +59,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if info.router_api.value not in apis_to_serve: if info.router_api.value not in apis_to_serve:
continue continue
print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table: if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?") raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
@ -104,6 +105,14 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
impls[api] = impl impls[api] = impl
impls[Api.inspect] = DistributionInspectImpl()
specs[Api.inspect] = InlineProviderSpec(
api=Api.inspect,
provider_type="__distribution_builtin__",
config_class="",
module="",
)
return impls, specs return impls, specs
@ -127,3 +136,60 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
dfs(a, visited, stack) dfs(a, visited, stack)
return [by_id[x] for x in stack] return [by_id[x] for x in stack]
# returns an object implementing the protocol corresponding to the Api
async def instantiate_provider(
    provider_spec: ProviderSpec,
    deps: Dict[str, Any],
    provider_config: Union[GenericProviderConfig, RoutingTable],
):
    """Import the provider's module and build its implementation object.

    The factory function called on the module depends on the spec type:

    - ``RemoteProviderSpec``: ``get_adapter_impl`` when an adapter is set,
      otherwise ``get_client_impl``; called with ``(config, deps)``.
    - ``AutoRoutedProviderSpec``: ``get_auto_router_impl``; called with the
      API, the dependency registered for ``routing_table_api`` and ``deps``.
    - ``RoutingTableProviderSpec``: ``get_routing_table_impl``; each routing
      entry's provider is instantiated first by a recursive call.
    - any other spec: ``get_provider_impl``; called with ``(config, deps)``.

    The returned implementation has ``__provider_spec__`` and
    ``__provider_config__`` attached (the latter is None for the two routing
    spec types).
    """
    module = importlib.import_module(provider_spec.module)

    args = []
    if isinstance(provider_spec, RemoteProviderSpec):
        if provider_spec.adapter:
            method = "get_adapter_impl"
        else:
            method = "get_client_impl"

        assert isinstance(provider_config, GenericProviderConfig)
        # Resolve the config class by its dotted name and build it from the raw
        # settings mapping.
        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider_config.config)
        args = [config, deps]
    elif isinstance(provider_spec, AutoRoutedProviderSpec):
        method = "get_auto_router_impl"

        config = None
        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
    elif isinstance(provider_spec, RoutingTableProviderSpec):
        method = "get_routing_table_impl"

        assert isinstance(provider_config, List)
        routing_table = provider_config

        # Index the wrapped specs by provider type so each routing entry can
        # find the spec it refers to.
        inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
        inner_impls = []
        for routing_entry in routing_table:
            impl = await instantiate_provider(
                inner_specs[routing_entry.provider_type],
                deps,
                routing_entry,
            )
            # Keep each implementation paired with the routing key it serves.
            inner_impls.append((routing_entry.routing_key, impl))

        config = None
        args = [provider_spec.api, inner_impls, routing_table, deps]
    else:
        method = "get_provider_impl"

        assert isinstance(provider_config, GenericProviderConfig)
        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider_config.config)
        args = [config, deps]

    fn = getattr(module, method)
    impl = await fn(*args)
    # Record where the implementation came from so it can be inspected later.
    impl.__provider_spec__ = provider_spec
    impl.__provider_config__ = config
    return impl

View file

@ -11,12 +11,14 @@ from pydantic import BaseModel
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.telemetry import Telemetry
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -38,6 +40,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.models: Models, Api.models: Models,
Api.shields: Shields, Api.shields: Shields,
Api.memory_banks: MemoryBanks, Api.memory_banks: MemoryBanks,
Api.inspect: Inspect,
} }
for api, protocol in protocols.items(): for api, protocol in protocols.items():

View file

@ -15,7 +15,6 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC, AsyncIterator as AsyncIteratorABC,
) )
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from http import HTTPStatus
from ssl import SSLError from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
@ -26,7 +25,6 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from typing_extensions import Annotated from typing_extensions import Annotated
@ -287,15 +285,6 @@ def main(
app = FastAPI() app = FastAPI()
# Health check is added to enable deploying the docker container image on Kubernetes which require
# a health check that can return 200 for readiness and liveness check
class HealthCheck(BaseModel):
status: str = "OK"
@app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
async def healthcheck():
return HealthCheck(status="OK")
impls, specs = asyncio.run(resolve_impls_with_routing(config)) impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
@ -307,6 +296,7 @@ def main(
else: else:
apis_to_serve = set(impls.keys()) apis_to_serve = set(impls.keys())
apis_to_serve.add(Api.inspect)
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
@ -340,14 +330,11 @@ def main(
) )
) )
for route in app.routes: cprint(f"Serving API {api_str}", "white", attrs=["bold"])
if isinstance(route, APIRoute): for endpoint in endpoints:
cprint( cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
f"Serving {next(iter(route.methods))} {route.path}",
"white",
attrs=["bold"],
)
print("")
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint) signal.signal(signal.SIGINT, handle_sigint)

View file

@ -5,69 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
import importlib import importlib
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
def instantiate_class_type(fully_qualified_name): def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1) module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
return getattr(module, class_name) return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable],
):
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_type],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -24,6 +24,9 @@ class Api(Enum):
shields = "shields" shields = "shields"
memory_banks = "memory_banks" memory_banks = "memory_banks"
# built-in API
inspect = "inspect"
@json_schema_type @json_schema_type
class ProviderSpec(BaseModel): class ProviderSpec(BaseModel):
@ -55,68 +58,6 @@ class RoutableProvider(Protocol):
async def validate_routing_keys(self, keys: List[str]) -> None: ... async def validate_routing_keys(self, keys: List[str]) -> None: ...
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
RoutingKey = Union[str, List[str]]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
# Example: /inference, /safety
@json_schema_type
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
routing_table_api: Api
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type @json_schema_type
class AdapterSpec(BaseModel): class AdapterSpec(BaseModel):
adapter_type: str = Field( adapter_type: str = Field(
@ -179,10 +120,6 @@ class RemoteProviderConfig(BaseModel):
return f"http://{self.host}:{self.port}" return f"http://{self.host}:{self.port}"
def remote_provider_type(adapter_type: str) -> str:
return f"remote::{adapter_type}"
@json_schema_type @json_schema_type
class RemoteProviderSpec(ProviderSpec): class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field( adapter: Optional[AdapterSpec] = Field(
@ -226,7 +163,7 @@ def remote_provider_spec(
if adapter and adapter.config_class if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig" else "llama_stack.distribution.datatypes.RemoteProviderConfig"
) )
provider_type = remote_provider_type(adapter.adapter_type) if adapter else "remote" provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
return RemoteProviderSpec( return RemoteProviderSpec(
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter api=api, provider_type=provider_type, config_class=config_class, adapter=adapter