mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Add an introspection "Api.inspect" API
This commit is contained in:
parent
01d93be948
commit
8d049000e3
14 changed files with 619 additions and 174 deletions
|
@ -46,6 +46,7 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
|
@ -63,6 +64,7 @@ class LlamaStack(
|
|||
Evaluations,
|
||||
Models,
|
||||
Shields,
|
||||
Inspect,
|
||||
):
|
||||
pass
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
"info": {
|
||||
"title": "[DRAFT] Llama Stack Specification",
|
||||
"version": "0.0.1",
|
||||
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
|
||||
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
|
||||
},
|
||||
"servers": [
|
||||
{
|
||||
|
@ -1542,6 +1542,36 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/health": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HealthInfo"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Inspect"
|
||||
],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "X-LlamaStack-ProviderData",
|
||||
"in": "header",
|
||||
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/memory/insert": {
|
||||
"post": {
|
||||
"responses": {
|
||||
|
@ -1665,6 +1695,75 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/providers/list": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"$ref": "#/components/schemas/ProviderInfo"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Inspect"
|
||||
],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "X-LlamaStack-ProviderData",
|
||||
"in": "header",
|
||||
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/routes/list": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/RouteInfo"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Inspect"
|
||||
],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "X-LlamaStack-ProviderData",
|
||||
"in": "header",
|
||||
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/shields/list": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
@ -5086,6 +5185,18 @@
|
|||
"job_uuid"
|
||||
]
|
||||
},
|
||||
"HealthInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"status"
|
||||
]
|
||||
},
|
||||
"InsertDocumentsRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -5108,6 +5219,45 @@
|
|||
"documents"
|
||||
]
|
||||
},
|
||||
"ProviderInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider_type": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"provider_type",
|
||||
"description"
|
||||
]
|
||||
},
|
||||
"RouteInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"route": {
|
||||
"type": "string"
|
||||
},
|
||||
"method": {
|
||||
"type": "string"
|
||||
},
|
||||
"providers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"route",
|
||||
"method",
|
||||
"providers"
|
||||
]
|
||||
},
|
||||
"LogSeverity": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
@ -6220,19 +6370,34 @@
|
|||
],
|
||||
"tags": [
|
||||
{
|
||||
"name": "Shields"
|
||||
"name": "Datasets"
|
||||
},
|
||||
{
|
||||
"name": "Inspect"
|
||||
},
|
||||
{
|
||||
"name": "Memory"
|
||||
},
|
||||
{
|
||||
"name": "BatchInference"
|
||||
},
|
||||
{
|
||||
"name": "RewardScoring"
|
||||
"name": "Agents"
|
||||
},
|
||||
{
|
||||
"name": "Inference"
|
||||
},
|
||||
{
|
||||
"name": "Shields"
|
||||
},
|
||||
{
|
||||
"name": "SyntheticDataGeneration"
|
||||
},
|
||||
{
|
||||
"name": "Agents"
|
||||
"name": "Models"
|
||||
},
|
||||
{
|
||||
"name": "RewardScoring"
|
||||
},
|
||||
{
|
||||
"name": "MemoryBanks"
|
||||
|
@ -6241,13 +6406,7 @@
|
|||
"name": "Safety"
|
||||
},
|
||||
{
|
||||
"name": "Models"
|
||||
},
|
||||
{
|
||||
"name": "Inference"
|
||||
},
|
||||
{
|
||||
"name": "Memory"
|
||||
"name": "Evaluations"
|
||||
},
|
||||
{
|
||||
"name": "Telemetry"
|
||||
|
@ -6255,12 +6414,6 @@
|
|||
{
|
||||
"name": "PostTraining"
|
||||
},
|
||||
{
|
||||
"name": "Datasets"
|
||||
},
|
||||
{
|
||||
"name": "Evaluations"
|
||||
},
|
||||
{
|
||||
"name": "BuiltinTool",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
|
||||
|
@ -6653,10 +6806,22 @@
|
|||
"name": "PostTrainingJob",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/PostTrainingJob\" />"
|
||||
},
|
||||
{
|
||||
"name": "HealthInfo",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
|
||||
},
|
||||
{
|
||||
"name": "InsertDocumentsRequest",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertDocumentsRequest\" />"
|
||||
},
|
||||
{
|
||||
"name": "ProviderInfo",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ProviderInfo\" />"
|
||||
},
|
||||
{
|
||||
"name": "RouteInfo",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/RouteInfo\" />"
|
||||
},
|
||||
{
|
||||
"name": "LogSeverity",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LogSeverity\" />"
|
||||
|
@ -6787,6 +6952,7 @@
|
|||
"Datasets",
|
||||
"Evaluations",
|
||||
"Inference",
|
||||
"Inspect",
|
||||
"Memory",
|
||||
"MemoryBanks",
|
||||
"Models",
|
||||
|
@ -6857,6 +7023,7 @@
|
|||
"FunctionCallToolDefinition",
|
||||
"GetAgentsSessionRequest",
|
||||
"GetDocumentsRequest",
|
||||
"HealthInfo",
|
||||
"ImageMedia",
|
||||
"InferenceStep",
|
||||
"InsertDocumentsRequest",
|
||||
|
@ -6880,6 +7047,7 @@
|
|||
"PostTrainingJobStatus",
|
||||
"PostTrainingJobStatusResponse",
|
||||
"PreferenceOptimizeRequest",
|
||||
"ProviderInfo",
|
||||
"QLoraFinetuningConfig",
|
||||
"QueryDocumentsRequest",
|
||||
"QueryDocumentsResponse",
|
||||
|
@ -6888,6 +7056,7 @@
|
|||
"RestAPIMethod",
|
||||
"RewardScoreRequest",
|
||||
"RewardScoringResponse",
|
||||
"RouteInfo",
|
||||
"RunShieldRequest",
|
||||
"RunShieldResponse",
|
||||
"SafetyViolation",
|
||||
|
|
|
@ -908,6 +908,14 @@ components:
|
|||
required:
|
||||
- document_ids
|
||||
type: object
|
||||
HealthInfo:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
status:
|
||||
type: string
|
||||
required:
|
||||
- status
|
||||
type: object
|
||||
ImageMedia:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -1543,6 +1551,17 @@ components:
|
|||
- hyperparam_search_config
|
||||
- logger_config
|
||||
type: object
|
||||
ProviderInfo:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
description:
|
||||
type: string
|
||||
provider_type:
|
||||
type: string
|
||||
required:
|
||||
- provider_type
|
||||
- description
|
||||
type: object
|
||||
QLoraFinetuningConfig:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -1704,6 +1723,22 @@ components:
|
|||
title: Response from the reward scoring. Batch of (prompt, response, score)
|
||||
tuples that pass the threshold.
|
||||
type: object
|
||||
RouteInfo:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
method:
|
||||
type: string
|
||||
providers:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
route:
|
||||
type: string
|
||||
required:
|
||||
- route
|
||||
- method
|
||||
- providers
|
||||
type: object
|
||||
RunShieldRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -2569,7 +2604,7 @@ info:
|
|||
description: "This is the specification of the llama stack that provides\n \
|
||||
\ a set of endpoints and their corresponding interfaces that are tailored\
|
||||
\ to\n best leverage Llama Models. The specification is still in\
|
||||
\ draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
|
||||
\ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
|
||||
title: '[DRAFT] Llama Stack Specification'
|
||||
version: 0.0.1
|
||||
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
||||
|
@ -3093,6 +3128,25 @@ paths:
|
|||
description: OK
|
||||
tags:
|
||||
- Evaluations
|
||||
/health:
|
||||
get:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HealthInfo'
|
||||
description: OK
|
||||
tags:
|
||||
- Inspect
|
||||
/inference/chat_completion:
|
||||
post:
|
||||
parameters:
|
||||
|
@ -3637,6 +3691,27 @@ paths:
|
|||
description: OK
|
||||
tags:
|
||||
- PostTraining
|
||||
/providers/list:
|
||||
get:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
additionalProperties:
|
||||
$ref: '#/components/schemas/ProviderInfo'
|
||||
type: object
|
||||
description: OK
|
||||
tags:
|
||||
- Inspect
|
||||
/reward_scoring/score:
|
||||
post:
|
||||
parameters:
|
||||
|
@ -3662,6 +3737,29 @@ paths:
|
|||
description: OK
|
||||
tags:
|
||||
- RewardScoring
|
||||
/routes/list:
|
||||
get:
|
||||
parameters:
|
||||
- description: JSON-encoded provider data which will be made available to the
|
||||
adapter servicing the API
|
||||
in: header
|
||||
name: X-LlamaStack-ProviderData
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
additionalProperties:
|
||||
items:
|
||||
$ref: '#/components/schemas/RouteInfo'
|
||||
type: array
|
||||
type: object
|
||||
description: OK
|
||||
tags:
|
||||
- Inspect
|
||||
/safety/run_shield:
|
||||
post:
|
||||
parameters:
|
||||
|
@ -3807,20 +3905,21 @@ security:
|
|||
servers:
|
||||
- url: http://any-hosted-llama-stack.com
|
||||
tags:
|
||||
- name: Shields
|
||||
- name: Datasets
|
||||
- name: Inspect
|
||||
- name: Memory
|
||||
- name: BatchInference
|
||||
- name: RewardScoring
|
||||
- name: SyntheticDataGeneration
|
||||
- name: Agents
|
||||
- name: Inference
|
||||
- name: Shields
|
||||
- name: SyntheticDataGeneration
|
||||
- name: Models
|
||||
- name: RewardScoring
|
||||
- name: MemoryBanks
|
||||
- name: Safety
|
||||
- name: Models
|
||||
- name: Inference
|
||||
- name: Memory
|
||||
- name: Evaluations
|
||||
- name: Telemetry
|
||||
- name: PostTraining
|
||||
- name: Datasets
|
||||
- name: Evaluations
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
||||
name: BuiltinTool
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
||||
|
@ -4135,9 +4234,15 @@ tags:
|
|||
- description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob"
|
||||
/>
|
||||
name: PostTrainingJob
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
|
||||
name: HealthInfo
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest"
|
||||
/>
|
||||
name: InsertDocumentsRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ProviderInfo" />
|
||||
name: ProviderInfo
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
|
||||
name: RouteInfo
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
|
||||
name: LogSeverity
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
|
||||
|
@ -4236,6 +4341,7 @@ x-tagGroups:
|
|||
- Datasets
|
||||
- Evaluations
|
||||
- Inference
|
||||
- Inspect
|
||||
- Memory
|
||||
- MemoryBanks
|
||||
- Models
|
||||
|
@ -4303,6 +4409,7 @@ x-tagGroups:
|
|||
- FunctionCallToolDefinition
|
||||
- GetAgentsSessionRequest
|
||||
- GetDocumentsRequest
|
||||
- HealthInfo
|
||||
- ImageMedia
|
||||
- InferenceStep
|
||||
- InsertDocumentsRequest
|
||||
|
@ -4326,6 +4433,7 @@ x-tagGroups:
|
|||
- PostTrainingJobStatus
|
||||
- PostTrainingJobStatusResponse
|
||||
- PreferenceOptimizeRequest
|
||||
- ProviderInfo
|
||||
- QLoraFinetuningConfig
|
||||
- QueryDocumentsRequest
|
||||
- QueryDocumentsResponse
|
||||
|
@ -4334,6 +4442,7 @@ x-tagGroups:
|
|||
- RestAPIMethod
|
||||
- RewardScoreRequest
|
||||
- RewardScoringResponse
|
||||
- RouteInfo
|
||||
- RunShieldRequest
|
||||
- RunShieldResponse
|
||||
- SafetyViolation
|
||||
|
|
7
llama_stack/apis/inspect/__init__.py
Normal file
7
llama_stack/apis/inspect/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .inspect import * # noqa: F401 F403
|
82
llama_stack/apis/inspect/client.py
Normal file
82
llama_stack/apis/inspect/client.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import List
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from .inspect import * # noqa: F403
|
||||
|
||||
|
||||
class InspectClient(Inspect):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> Dict[str, ProviderInfo]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/providers/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
print(response.json())
|
||||
return {
|
||||
k: [ProviderInfo(**vi) for vi in v] for k, v in response.json().items()
|
||||
}
|
||||
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/routes/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return {
|
||||
k: [RouteInfo(**vi) for vi in v] for k, v in response.json().items()
|
||||
}
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/health",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
return HealthInfo(**j)
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
client = InspectClient(f"http://{host}:{port}")
|
||||
|
||||
response = await client.list_providers()
|
||||
cprint(f"list_providers response={response}", "green")
|
||||
|
||||
response = await client.list_routes()
|
||||
cprint(f"list_routes response={response}", "blue")
|
||||
|
||||
response = await client.health()
|
||||
cprint(f"health response={response}", "yellow")
|
||||
|
||||
|
||||
def main(host: str, port: int):
|
||||
asyncio.run(run_main(host, port))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
40
llama_stack/apis/inspect/inspect.py
Normal file
40
llama_stack/apis/inspect/inspect.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List, Protocol
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
provider_type: str
|
||||
description: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
route: str
|
||||
method: str
|
||||
providers: List[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class HealthInfo(BaseModel):
|
||||
status: str
|
||||
# TODO: add a provider level status
|
||||
|
||||
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/providers/list", method="GET")
|
||||
async def list_providers(self) -> Dict[str, ProviderInfo]: ...
|
||||
|
||||
@webmethod(route="/routes/list", method="GET")
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...
|
||||
|
||||
@webmethod(route="/health", method="GET")
|
||||
async def health(self) -> HealthInfo: ...
|
|
@ -17,6 +17,53 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
|
|||
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
|
||||
|
||||
|
||||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class RoutableProviderConfig(GenericProviderConfig):
|
||||
routing_key: RoutingKey
|
||||
|
||||
|
||||
class PlaceholderProviderConfig(BaseModel):
|
||||
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
||||
|
||||
providers: List[str]
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
class AutoRoutedProviderSpec(ProviderSpec):
|
||||
provider_type: str = "router"
|
||||
config_class: str = ""
|
||||
|
||||
docker_image: Optional[str] = None
|
||||
routing_table_api: Api
|
||||
module: str
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||
|
||||
|
||||
# Example: /models, /shields
|
||||
@json_schema_type
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_type: str = "routing_table"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
module: str
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DistributionSpec(BaseModel):
|
||||
description: Optional[str] = Field(
|
||||
|
|
|
@ -46,6 +46,8 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|||
for api in stack_apis():
|
||||
if api in routing_table_apis:
|
||||
continue
|
||||
if api == Api.inspect:
|
||||
continue
|
||||
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
|
|
54
llama_stack/distribution/inspect.py
Normal file
54
llama_stack/distribution/inspect.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
|
||||
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
||||
|
||||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
|
||||
ret = {}
|
||||
all_providers = get_provider_registry()
|
||||
for api, providers in all_providers.items():
|
||||
ret[api.value] = [
|
||||
ProviderInfo(
|
||||
provider_type=p.provider_type,
|
||||
description="Passthrough" if is_passthrough(p) else "",
|
||||
)
|
||||
for p in providers.values()
|
||||
]
|
||||
|
||||
return ret
|
||||
|
||||
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
|
||||
ret = {}
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
for api, endpoints in all_endpoints.items():
|
||||
ret[api.value] = [
|
||||
RouteInfo(
|
||||
route=e.route,
|
||||
method=e.method,
|
||||
providers=[],
|
||||
)
|
||||
for e in endpoints
|
||||
]
|
||||
return ret
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
return HealthInfo(status="OK")
|
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import importlib
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
|
@ -11,7 +12,8 @@ from llama_stack.distribution.distribution import (
|
|||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
from llama_stack.distribution.inspect import DistributionInspectImpl
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
|
||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
|
@ -57,7 +59,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
print("router_api", info.router_api)
|
||||
if info.router_api.value not in run_config.routing_table:
|
||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||
|
||||
|
@ -104,6 +105,14 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
|
||||
impls[api] = impl
|
||||
|
||||
impls[Api.inspect] = DistributionInspectImpl()
|
||||
specs[Api.inspect] = InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__distribution_builtin__",
|
||||
config_class="",
|
||||
module="",
|
||||
)
|
||||
|
||||
return impls, specs
|
||||
|
||||
|
||||
|
@ -127,3 +136,60 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
dfs(a, visited, stack)
|
||||
|
||||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
if provider_spec.adapter:
|
||||
method = "get_adapter_impl"
|
||||
else:
|
||||
method = "get_client_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
|
||||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
assert isinstance(provider_config, List)
|
||||
routing_table = provider_config
|
||||
|
||||
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in routing_table:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_type],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
||||
|
|
|
@ -11,12 +11,14 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
|
@ -38,6 +40,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
Api.models: Models,
|
||||
Api.shields: Shields,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
Api.inspect: Inspect,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
|
|
|
@ -15,7 +15,6 @@ from collections.abc import (
|
|||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from ssl import SSLError
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
|
||||
|
@ -26,7 +25,6 @@ import yaml
|
|||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
@ -287,15 +285,6 @@ def main(
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
# Health check is added to enable deploying the docker container image on Kubernetes which require
|
||||
# a health check that can return 200 for readiness and liveness check
|
||||
class HealthCheck(BaseModel):
|
||||
status: str = "OK"
|
||||
|
||||
@app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
|
||||
async def healthcheck():
|
||||
return HealthCheck(status="OK")
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
@ -307,6 +296,7 @@ def main(
|
|||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
apis_to_serve.add(Api.inspect)
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
@ -340,14 +330,11 @@ def main(
|
|||
)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
cprint(
|
||||
f"Serving {next(iter(route.methods))} {route.path}",
|
||||
"white",
|
||||
attrs=["bold"],
|
||||
)
|
||||
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
|
||||
for endpoint in endpoints:
|
||||
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
|
||||
|
||||
print("")
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
|
|
@ -5,69 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import importlib
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
module_name, class_name = fully_qualified_name.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
if provider_spec.adapter:
|
||||
method = "get_adapter_impl"
|
||||
else:
|
||||
method = "get_client_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
|
||||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
assert isinstance(provider_config, List)
|
||||
routing_table = provider_config
|
||||
|
||||
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in routing_table:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_type],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
assert isinstance(provider_config, GenericProviderConfig)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
from typing import Any, List, Optional, Protocol
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -24,6 +24,9 @@ class Api(Enum):
|
|||
shields = "shields"
|
||||
memory_banks = "memory_banks"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderSpec(BaseModel):
|
||||
|
@ -55,68 +58,6 @@ class RoutableProvider(Protocol):
|
|||
async def validate_routing_keys(self, keys: List[str]) -> None: ...
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class PlaceholderProviderConfig(BaseModel):
|
||||
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
||||
|
||||
providers: List[str]
|
||||
|
||||
|
||||
RoutingKey = Union[str, List[str]]
|
||||
|
||||
|
||||
class RoutableProviderConfig(GenericProviderConfig):
|
||||
routing_key: RoutingKey
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
@json_schema_type
|
||||
class AutoRoutedProviderSpec(ProviderSpec):
|
||||
provider_type: str = "router"
|
||||
config_class: str = ""
|
||||
|
||||
docker_image: Optional[str] = None
|
||||
routing_table_api: Api
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||
|
||||
|
||||
# Example: /models, /shields
|
||||
@json_schema_type
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_type: str = "routing_table"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_type: str = Field(
|
||||
|
@ -179,10 +120,6 @@ class RemoteProviderConfig(BaseModel):
|
|||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
|
||||
def remote_provider_type(adapter_type: str) -> str:
|
||||
return f"remote::{adapter_type}"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
adapter: Optional[AdapterSpec] = Field(
|
||||
|
@ -226,7 +163,7 @@ def remote_provider_spec(
|
|||
if adapter and adapter.config_class
|
||||
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
||||
)
|
||||
provider_type = remote_provider_type(adapter.adapter_type) if adapter else "remote"
|
||||
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
|
||||
|
||||
return RemoteProviderSpec(
|
||||
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue