diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c00ea3040..1c85436c4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -51,3 +51,9 @@ repos:
# hooks:
# - id: pydoclint
# args: [--config=pyproject.toml]
+
+# - repo: https://github.com/tcort/markdown-link-check
+# rev: v3.11.2
+# hooks:
+# - id: markdown-link-check
+# args: ['--quiet']
diff --git a/README.md b/README.md
index 936876708..a5172ce5c 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,8 @@
# Llama Stack
+[](https://pypi.org/project/llama_stack/)
[](https://pypi.org/project/llama-stack/)
-[](https://discord.gg/TZAAYNVtrU)
+[](https://discord.gg/llama-stack)
This repository contains the Llama Stack API specifications as well as API Providers and Llama Stack Distributions.
diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index 28874641f..3541d0b4e 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -5,7 +5,7 @@ The `llama` CLI tool helps you setup and use the Llama toolchain & agentic syste
### Subcommands
1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
2. `model`: Lists available models and their properties.
-3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](/docs/cli_reference.md#step-3-building-configuring-and-running-llama-stack-servers).
+3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](cli_reference.md#step-3-building-and-configuring-llama-stack-distributions).
### Sample Usage
diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py
index c5ba23b14..c5b156bb8 100644
--- a/docs/openapi_generator/generate.py
+++ b/docs/openapi_generator/generate.py
@@ -46,6 +46,7 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
+from llama_stack.apis.inspect import * # noqa: F403
class LlamaStack(
@@ -63,6 +64,7 @@ class LlamaStack(
Evaluations,
Models,
Shields,
+ Inspect,
):
pass
diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index c77ebe2a7..0d06ce03d 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
},
"servers": [
{
@@ -1542,6 +1542,36 @@
]
}
},
+ "/health": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HealthInfo"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "Inspect"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ }
+ },
"/memory/insert": {
"post": {
"responses": {
@@ -1665,6 +1695,75 @@
]
}
},
+ "/providers/list": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "$ref": "#/components/schemas/ProviderInfo"
+ }
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "Inspect"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ }
+ },
+ "/routes/list": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/RouteInfo"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "Inspect"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ]
+ }
+ },
"/shields/list": {
"get": {
"responses": {
@@ -4783,7 +4882,7 @@
"provider_config": {
"type": "object",
"properties": {
- "provider_id": {
+ "provider_type": {
"type": "string"
},
"config": {
@@ -4814,7 +4913,7 @@
},
"additionalProperties": false,
"required": [
- "provider_id",
+ "provider_type",
"config"
]
}
@@ -4843,7 +4942,7 @@
"provider_config": {
"type": "object",
"properties": {
- "provider_id": {
+ "provider_type": {
"type": "string"
},
"config": {
@@ -4874,7 +4973,7 @@
},
"additionalProperties": false,
"required": [
- "provider_id",
+ "provider_type",
"config"
]
}
@@ -4894,7 +4993,7 @@
"provider_config": {
"type": "object",
"properties": {
- "provider_id": {
+ "provider_type": {
"type": "string"
},
"config": {
@@ -4925,7 +5024,7 @@
},
"additionalProperties": false,
"required": [
- "provider_id",
+ "provider_type",
"config"
]
}
@@ -5086,6 +5185,18 @@
"job_uuid"
]
},
+ "HealthInfo": {
+ "type": "object",
+ "properties": {
+ "status": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "status"
+ ]
+ },
"InsertDocumentsRequest": {
"type": "object",
"properties": {
@@ -5108,6 +5219,45 @@
"documents"
]
},
+ "ProviderInfo": {
+ "type": "object",
+ "properties": {
+ "provider_type": {
+ "type": "string"
+ },
+ "description": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "provider_type",
+ "description"
+ ]
+ },
+ "RouteInfo": {
+ "type": "object",
+ "properties": {
+ "route": {
+ "type": "string"
+ },
+ "method": {
+ "type": "string"
+ },
+ "providers": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "route",
+ "method",
+ "providers"
+ ]
+ },
"LogSeverity": {
"type": "string",
"enum": [
@@ -6220,19 +6370,34 @@
],
"tags": [
{
- "name": "Shields"
+ "name": "Datasets"
+ },
+ {
+ "name": "Inspect"
+ },
+ {
+ "name": "Memory"
},
{
"name": "BatchInference"
},
{
- "name": "RewardScoring"
+ "name": "Agents"
+ },
+ {
+ "name": "Inference"
+ },
+ {
+ "name": "Shields"
},
{
"name": "SyntheticDataGeneration"
},
{
- "name": "Agents"
+ "name": "Models"
+ },
+ {
+ "name": "RewardScoring"
},
{
"name": "MemoryBanks"
@@ -6241,13 +6406,7 @@
"name": "Safety"
},
{
- "name": "Models"
- },
- {
- "name": "Inference"
- },
- {
- "name": "Memory"
+ "name": "Evaluations"
},
{
"name": "Telemetry"
@@ -6255,12 +6414,6 @@
{
"name": "PostTraining"
},
- {
- "name": "Datasets"
- },
- {
- "name": "Evaluations"
- },
{
"name": "BuiltinTool",
"description": ""
@@ -6653,10 +6806,22 @@
"name": "PostTrainingJob",
"description": ""
},
+ {
+ "name": "HealthInfo",
+ "description": ""
+ },
{
"name": "InsertDocumentsRequest",
"description": ""
},
+ {
+ "name": "ProviderInfo",
+ "description": ""
+ },
+ {
+ "name": "RouteInfo",
+ "description": ""
+ },
{
"name": "LogSeverity",
"description": ""
@@ -6787,6 +6952,7 @@
"Datasets",
"Evaluations",
"Inference",
+ "Inspect",
"Memory",
"MemoryBanks",
"Models",
@@ -6857,6 +7023,7 @@
"FunctionCallToolDefinition",
"GetAgentsSessionRequest",
"GetDocumentsRequest",
+ "HealthInfo",
"ImageMedia",
"InferenceStep",
"InsertDocumentsRequest",
@@ -6880,6 +7047,7 @@
"PostTrainingJobStatus",
"PostTrainingJobStatusResponse",
"PreferenceOptimizeRequest",
+ "ProviderInfo",
"QLoraFinetuningConfig",
"QueryDocumentsRequest",
"QueryDocumentsResponse",
@@ -6888,6 +7056,7 @@
"RestAPIMethod",
"RewardScoreRequest",
"RewardScoringResponse",
+ "RouteInfo",
"RunShieldRequest",
"RunShieldResponse",
"SafetyViolation",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 83b415649..317d1ee33 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -908,6 +908,14 @@ components:
required:
- document_ids
type: object
+ HealthInfo:
+ additionalProperties: false
+ properties:
+ status:
+ type: string
+ required:
+ - status
+ type: object
ImageMedia:
additionalProperties: false
properties:
@@ -1117,10 +1125,10 @@ components:
- type: array
- type: object
type: object
- provider_id:
+ provider_type:
type: string
required:
- - provider_id
+ - provider_type
- config
type: object
required:
@@ -1362,10 +1370,10 @@ components:
- type: array
- type: object
type: object
- provider_id:
+ provider_type:
type: string
required:
- - provider_id
+ - provider_type
- config
type: object
required:
@@ -1543,6 +1551,17 @@ components:
- hyperparam_search_config
- logger_config
type: object
+ ProviderInfo:
+ additionalProperties: false
+ properties:
+ description:
+ type: string
+ provider_type:
+ type: string
+ required:
+ - provider_type
+ - description
+ type: object
QLoraFinetuningConfig:
additionalProperties: false
properties:
@@ -1704,6 +1723,22 @@ components:
title: Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold.
type: object
+ RouteInfo:
+ additionalProperties: false
+ properties:
+ method:
+ type: string
+ providers:
+ items:
+ type: string
+ type: array
+ route:
+ type: string
+ required:
+ - route
+ - method
+ - providers
+ type: object
RunShieldRequest:
additionalProperties: false
properties:
@@ -1916,10 +1951,10 @@ components:
- type: array
- type: object
type: object
- provider_id:
+ provider_type:
type: string
required:
- - provider_id
+ - provider_type
- config
type: object
shield_type:
@@ -2569,7 +2604,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
+ \ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3093,6 +3128,25 @@ paths:
description: OK
tags:
- Evaluations
+ /health:
+ get:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HealthInfo'
+ description: OK
+ tags:
+ - Inspect
/inference/chat_completion:
post:
parameters:
@@ -3637,6 +3691,27 @@ paths:
description: OK
tags:
- PostTraining
+ /providers/list:
+ get:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties:
+ items:
+ $ref: '#/components/schemas/ProviderInfo'
+ type: array
+ type: object
+ description: OK
+ tags:
+ - Inspect
/reward_scoring/score:
post:
parameters:
@@ -3662,6 +3737,29 @@ paths:
description: OK
tags:
- RewardScoring
+ /routes/list:
+ get:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ additionalProperties:
+ items:
+ $ref: '#/components/schemas/RouteInfo'
+ type: array
+ type: object
+ description: OK
+ tags:
+ - Inspect
/safety/run_shield:
post:
parameters:
@@ -3807,20 +3905,21 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
-- name: Shields
+- name: Datasets
+- name: Inspect
+- name: Memory
- name: BatchInference
-- name: RewardScoring
-- name: SyntheticDataGeneration
- name: Agents
+- name: Inference
+- name: Shields
+- name: SyntheticDataGeneration
+- name: Models
+- name: RewardScoring
- name: MemoryBanks
- name: Safety
-- name: Models
-- name: Inference
-- name: Memory
+- name: Evaluations
- name: Telemetry
- name: PostTraining
-- name: Datasets
-- name: Evaluations
- description:
name: BuiltinTool
- description:
name: PostTrainingJob
+- description:
+ name: HealthInfo
- description:
name: InsertDocumentsRequest
+- description:
+ name: ProviderInfo
+- description:
+ name: RouteInfo
- description:
name: LogSeverity
- description:
@@ -4236,6 +4341,7 @@ x-tagGroups:
- Datasets
- Evaluations
- Inference
+ - Inspect
- Memory
- MemoryBanks
- Models
@@ -4303,6 +4409,7 @@ x-tagGroups:
- FunctionCallToolDefinition
- GetAgentsSessionRequest
- GetDocumentsRequest
+ - HealthInfo
- ImageMedia
- InferenceStep
- InsertDocumentsRequest
@@ -4326,6 +4433,7 @@ x-tagGroups:
- PostTrainingJobStatus
- PostTrainingJobStatusResponse
- PreferenceOptimizeRequest
+ - ProviderInfo
- QLoraFinetuningConfig
- QueryDocumentsRequest
- QueryDocumentsResponse
@@ -4334,6 +4442,7 @@ x-tagGroups:
- RestAPIMethod
- RewardScoreRequest
- RewardScoringResponse
+ - RouteInfo
- RunShieldRequest
- RunShieldResponse
- SafetyViolation
diff --git a/llama_stack/apis/inspect/__init__.py b/llama_stack/apis/inspect/__init__.py
new file mode 100644
index 000000000..88ba8e908
--- /dev/null
+++ b/llama_stack/apis/inspect/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .inspect import * # noqa: F401 F403
diff --git a/llama_stack/apis/inspect/client.py b/llama_stack/apis/inspect/client.py
new file mode 100644
index 000000000..65d8b83ed
--- /dev/null
+++ b/llama_stack/apis/inspect/client.py
@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+from typing import Dict, List
+
+import fire
+import httpx
+from termcolor import cprint
+
+from .inspect import * # noqa: F403
+
+
+class InspectClient(Inspect):
+ def __init__(self, base_url: str):
+ self.base_url = base_url
+
+ async def initialize(self) -> None:
+ pass
+
+ async def shutdown(self) -> None:
+ pass
+
+ async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ f"{self.base_url}/providers/list",
+ headers={"Content-Type": "application/json"},
+ )
+ response.raise_for_status()
+ return {
+ k: [ProviderInfo(**vi) for vi in v] for k, v in response.json().items()
+ }
+
+ async def list_routes(self) -> Dict[str, List[RouteInfo]]:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ f"{self.base_url}/routes/list",
+ headers={"Content-Type": "application/json"},
+ )
+ response.raise_for_status()
+ return {
+ k: [RouteInfo(**vi) for vi in v] for k, v in response.json().items()
+ }
+
+ async def health(self) -> HealthInfo:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(
+ f"{self.base_url}/health",
+ headers={"Content-Type": "application/json"},
+ )
+ response.raise_for_status()
+ j = response.json()
+ if j is None:
+ return None
+ return HealthInfo(**j)
+
+
+async def run_main(host: str, port: int):
+ client = InspectClient(f"http://{host}:{port}")
+
+ response = await client.list_providers()
+ cprint(f"list_providers response={response}", "green")
+
+ response = await client.list_routes()
+ cprint(f"list_routes response={response}", "blue")
+
+ response = await client.health()
+ cprint(f"health response={response}", "yellow")
+
+
+def main(host: str, port: int):
+ asyncio.run(run_main(host, port))
+
+
+if __name__ == "__main__":
+ fire.Fire(main)
diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py
new file mode 100644
index 000000000..ca444098c
--- /dev/null
+++ b/llama_stack/apis/inspect/inspect.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict, List, Protocol
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel
+
+
+@json_schema_type
+class ProviderInfo(BaseModel):
+ provider_type: str
+ description: str
+
+
+@json_schema_type
+class RouteInfo(BaseModel):
+ route: str
+ method: str
+ providers: List[str]
+
+
+@json_schema_type
+class HealthInfo(BaseModel):
+ status: str
+ # TODO: add a provider level status
+
+
+class Inspect(Protocol):
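+ """Introspection API exposed by a running distribution: registered providers, served routes, and health."""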
+ @webmethod(route="/providers/list", method="GET")
+ async def list_providers(self) -> Dict[str, List[ProviderInfo]]: ...
+
+ @webmethod(route="/routes/list", method="GET")
+ async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...
+
+ @webmethod(route="/health", method="GET")
+ async def health(self) -> HealthInfo: ...
diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index b4e35fb0c..53ca83e84 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -18,7 +18,7 @@ from llama_stack.distribution.datatypes import GenericProviderConfig
class MemoryBankSpec(BaseModel):
bank_type: MemoryBankType
provider_config: GenericProviderConfig = Field(
- description="Provider config for the model, including provider_id, and corresponding config. ",
+ description="Provider config for the model, including provider_type, and corresponding config. ",
)
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index d542517ba..2952a8dee 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -20,7 +20,7 @@ class ModelServingSpec(BaseModel):
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
)
provider_config: GenericProviderConfig = Field(
- description="Provider config for the model, including provider_id, and corresponding config. ",
+ description="Provider config for the model, including provider_type, and corresponding config. ",
)
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index 006178b5d..2b8242263 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -16,7 +16,7 @@ from llama_stack.distribution.datatypes import GenericProviderConfig
class ShieldSpec(BaseModel):
shield_type: str
provider_config: GenericProviderConfig = Field(
- description="Provider config for the model, including provider_id, and corresponding config. ",
+ description="Provider config for the model, including provider_type, and corresponding config. ",
)
diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py
index e6fd8aac7..67f456175 100644
--- a/llama_stack/cli/model/prompt_format.py
+++ b/llama_stack/cli/model/prompt_format.py
@@ -5,7 +5,6 @@
# the root directory of this source tree.
import argparse
-import subprocess
import textwrap
from io import StringIO
@@ -110,7 +109,4 @@ def render_markdown_to_pager(markdown_content: str):
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
console.print(md)
rendered_content = output.getvalue()
-
- # Pipe to pager
- pager = subprocess.Popen(["less", "-R"], stdin=subprocess.PIPE)
- pager.communicate(input=rendered_content.encode())
+ print(rendered_content)
diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py
index 0324b068a..cfb296f2f 100644
--- a/llama_stack/cli/stack/build.py
+++ b/llama_stack/cli/stack/build.py
@@ -179,12 +179,7 @@ class StackBuild(Subcommand):
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import yaml
- from llama_stack.distribution.distribution import (
- Api,
- api_providers,
- builtin_automatically_routed_apis,
- )
- from llama_stack.distribution.utils.dynamic import instantiate_class_type
+ from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
@@ -249,22 +244,12 @@ class StackBuild(Subcommand):
)
cprint(
- f"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
+ "\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
color="green",
)
providers = dict()
- all_providers = api_providers()
- routing_table_apis = set(
- x.routing_table_api for x in builtin_automatically_routed_apis()
- )
-
- for api in Api:
- if api in routing_table_apis:
- continue
-
- providers_for_api = all_providers[api]
-
+ for api, providers_for_api in get_provider_registry().items():
api_provider = prompt(
"> Enter provider for the {} API: (default=meta-reference): ".format(
api.value
diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py
index 18c4de201..96e978826 100644
--- a/llama_stack/cli/stack/list_providers.py
+++ b/llama_stack/cli/stack/list_providers.py
@@ -34,9 +34,9 @@ class StackListProviders(Subcommand):
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_stack.cli.table import print_table
- from llama_stack.distribution.distribution import Api, api_providers
+ from llama_stack.distribution.distribution import Api, get_provider_registry
- all_providers = api_providers()
+ all_providers = get_provider_registry()
providers_for_api = all_providers[Api(args.api)]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
@@ -47,11 +47,11 @@ class StackListProviders(Subcommand):
rows = []
for spec in providers_for_api.values():
- if spec.provider_id == "sample":
+ if spec.provider_type == "sample":
continue
rows.append(
[
- spec.provider_id,
+ spec.provider_type,
",".join(spec.pip_packages),
]
)
diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py
index b616fcf6d..56186a5aa 100644
--- a/llama_stack/distribution/build.py
+++ b/llama_stack/distribution/build.py
@@ -19,6 +19,17 @@ from pathlib import Path
-from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
+from llama_stack.distribution.distribution import get_provider_registry
+
+
+# These are the dependencies needed by the distribution server.
+# `llama-stack` is automatically installed by the installation script.
+SERVER_DEPENDENCIES = [
+ "fastapi",
+ "fire",
+ "httpx",
+ "uvicorn",
+]
class ImageType(Enum):
@@ -43,7 +54,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
)
# extend package dependencies based on providers spec
- all_providers = api_providers()
+ all_providers = get_provider_registry()
for (
api_str,
provider_or_providers,
diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py
index d3b807d4a..e03b201ec 100644
--- a/llama_stack/distribution/configure.py
+++ b/llama_stack/distribution/configure.py
@@ -15,8 +15,8 @@ from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import (
- api_providers,
builtin_automatically_routed_apis,
+ get_provider_registry,
stack_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -62,7 +62,7 @@ def configure_api_providers(
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
apis = [v.value for v in stack_apis()]
- all_providers = api_providers()
+ all_providers = get_provider_registry()
# configure simple case for with non-routing providers to api_providers
for api_str in spec.providers.keys():
@@ -109,7 +109,7 @@ def configure_api_providers(
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
- provider_id=p,
+ provider_type=p,
config=cfg.dict(),
)
)
@@ -120,7 +120,7 @@ def configure_api_providers(
routing_entries.append(
RoutableProviderConfig(
routing_key=[s.value for s in MetaReferenceShieldType],
- provider_id=p,
+ provider_type=p,
config=cfg.dict(),
)
)
@@ -133,7 +133,7 @@ def configure_api_providers(
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
- provider_id=p,
+ provider_type=p,
config=cfg.dict(),
)
)
@@ -153,7 +153,7 @@ def configure_api_providers(
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
- provider_id=p,
+ provider_type=p,
config=cfg.dict(),
)
)
@@ -164,7 +164,7 @@ def configure_api_providers(
)
else:
config.api_providers[api_str] = GenericProviderConfig(
- provider_id=p,
+ provider_type=p,
config=cfg.dict(),
)
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index fa88ad5cf..2be6ede26 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -17,6 +17,53 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
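+# a routing key maps one or more keys (e.g. model names, shield types) to a provider config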
+RoutingKey = Union[str, List[str]]
+
+
+class GenericProviderConfig(BaseModel):
+ provider_type: str
+ config: Dict[str, Any]
+
+
+class RoutableProviderConfig(GenericProviderConfig):
+ routing_key: RoutingKey
+
+
+class PlaceholderProviderConfig(BaseModel):
+ """Placeholder provider config for API whose provider are defined in routing_table"""
+
+ providers: List[str]
+
+
+# Example: /inference, /safety
+class AutoRoutedProviderSpec(ProviderSpec):
+ provider_type: str = "router"
+ config_class: str = ""
+
+ docker_image: Optional[str] = None
+ routing_table_api: Api
+ module: str
+ provider_data_validator: Optional[str] = Field(
+ default=None,
+ )
+
+ @property
+ def pip_packages(self) -> List[str]:
+ raise AssertionError("Should not be called on AutoRoutedProviderSpec")
+
+
+# Example: /models, /shields
+@json_schema_type
+class RoutingTableProviderSpec(ProviderSpec):
+ provider_type: str = "routing_table"
+ config_class: str = ""
+ docker_image: Optional[str] = None
+
+ inner_specs: List[ProviderSpec]
+ module: str
+ pip_packages: List[str] = Field(default_factory=list)
+
+
@json_schema_type
class DistributionSpec(BaseModel):
description: Optional[str] = Field(
@@ -71,7 +118,7 @@ Provider configurations for each of the APIs provided by this package.
E.g. The following is a ProviderRoutingEntry for models:
- routing_key: Meta-Llama3.1-8B-Instruct
- provider_id: meta-reference
+ provider_type: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 035febb80..999646cc0 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -5,30 +5,11 @@
# the root directory of this source tree.
import importlib
-import inspect
from typing import Dict, List
from pydantic import BaseModel
-from llama_stack.apis.agents import Agents
-from llama_stack.apis.inference import Inference
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
-from llama_stack.apis.models import Models
-from llama_stack.apis.safety import Safety
-from llama_stack.apis.shields import Shields
-from llama_stack.apis.telemetry import Telemetry
-
-from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
-
-# These are the dependencies needed by the distribution server.
-# `llama-stack` is automatically installed by the installation script.
-SERVER_DEPENDENCIES = [
- "fastapi",
- "fire",
- "httpx",
- "uvicorn",
-]
+from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
def stack_apis() -> List[Api]:
@@ -57,58 +38,21 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
]
-def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
- apis = {}
-
- protocols = {
- Api.inference: Inference,
- Api.safety: Safety,
- Api.agents: Agents,
- Api.memory: Memory,
- Api.telemetry: Telemetry,
- Api.models: Models,
- Api.shields: Shields,
- Api.memory_banks: MemoryBanks,
- }
-
- for api, protocol in protocols.items():
- endpoints = []
- protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
-
- for name, method in protocol_methods:
- if not hasattr(method, "__webmethod__"):
- continue
-
- webmethod = method.__webmethod__
- route = webmethod.route
-
- if webmethod.method == "GET":
- method = "get"
- elif webmethod.method == "DELETE":
- method = "delete"
- else:
- method = "post"
- endpoints.append(ApiEndpoint(route=route, method=method, name=name))
-
- apis[api] = endpoints
-
- return apis
-
-
-def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
- ret = {}
+def providable_apis() -> List[Api]:
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
- for api in stack_apis():
- if api in routing_table_apis:
- continue
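+ # exclude APIs whose providers live in the routing table, plus the builtin inspect API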
+ return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
+
+
+def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
+ ret = {}
+ for api in providable_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
- **{a.provider_id: a for a in module.available_providers()},
+ **{a.provider_type: a for a in module.available_providers()},
}
return ret
diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py
new file mode 100644
index 000000000..acd7ab7f8
--- /dev/null
+++ b/llama_stack/distribution/inspect.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict, List
+from llama_stack.apis.inspect import * # noqa: F403
+
+
+from llama_stack.distribution.distribution import get_provider_registry
+from llama_stack.distribution.server.endpoints import get_all_api_endpoints
+from llama_stack.providers.datatypes import * # noqa: F403
+
+
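+# a remote provider spec with no adapter is served as a passthrough client to another stack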
+def is_passthrough(spec: ProviderSpec) -> bool:
+ return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
+
+
+class DistributionInspectImpl(Inspect):
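+ """Builtin Inspect implementation; the resolver registers it for every distribution."""
+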
+ def __init__(self):
+ pass
+
+ async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
+ ret = {}
+ all_providers = get_provider_registry()
+ for api, providers in all_providers.items():
+ ret[api.value] = [
+ ProviderInfo(
+ provider_type=p.provider_type,
+ description="Passthrough" if is_passthrough(p) else "",
+ )
+ for p in providers.values()
+ ]
+
+ return ret
+
+ async def list_routes(self) -> Dict[str, List[RouteInfo]]:
+ ret = {}
+ all_endpoints = get_all_api_endpoints()
+
+ for api, endpoints in all_endpoints.items():
+ ret[api.value] = [
+ RouteInfo(
+ route=e.route,
+ method=e.method,
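+ # per-route provider attribution is not tracked yet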
+ providers=[],
+ )
+ for e in endpoints
+ ]
+ return ret
+
+ async def health(self) -> HealthInfo:
+ return HealthInfo(status="OK")
diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py
index 990fa66d5..bbb1fff9d 100644
--- a/llama_stack/distribution/request_headers.py
+++ b/llama_stack/distribution/request_headers.py
@@ -18,10 +18,10 @@ class NeedsRequestProviderData:
spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}"
- provider_id = spec.provider_id
+ provider_type = spec.provider_type
validator_class = spec.provider_data_validator
if not validator_class:
- raise ValueError(f"Provider {provider_id} does not have a validator")
+ raise ValueError(f"Provider {provider_type} does not have a validator")
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
if not val:
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index f7d51c64a..ae7d9ab40 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -3,15 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import importlib
from typing import Any, Dict, List, Set
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import (
- api_providers,
builtin_automatically_routed_apis,
+ get_provider_registry,
)
-from llama_stack.distribution.utils.dynamic import instantiate_provider
+from llama_stack.distribution.inspect import DistributionInspectImpl
+from llama_stack.distribution.utils.dynamic import instantiate_class_type
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
@@ -20,7 +22,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
- all_providers = api_providers()
+ all_providers = get_provider_registry()
specs = {}
configs = {}
@@ -34,11 +36,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if isinstance(config, PlaceholderProviderConfig):
continue
- if config.provider_id not in providers:
+ if config.provider_type not in providers:
raise ValueError(
- f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
+ f"Provider `{config.provider_type}` is not available for API `{api}`"
)
- specs[api] = providers[config.provider_id]
+ specs[api] = providers[config.provider_type]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
@@ -57,7 +59,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if info.router_api.value not in apis_to_serve:
continue
- print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
@@ -68,12 +69,12 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
- if rt_entry.provider_id not in providers:
+ if rt_entry.provider_type not in providers:
raise ValueError(
- f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
+ f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
)
- inner_specs.append(providers[rt_entry.provider_id])
- inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
+ inner_specs.append(providers[rt_entry.provider_type])
+ inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
@@ -94,7 +95,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
- print(f" {spec.api}: {spec.provider_id}")
+ print(f" {spec.api}: {spec.provider_type}")
print("")
impls = {}
for spec in sorted_specs:
@@ -104,6 +105,14 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
impls[api] = impl
+ impls[Api.inspect] = DistributionInspectImpl()
+ specs[Api.inspect] = InlineProviderSpec(
+ api=Api.inspect,
+ provider_type="__distribution_builtin__",
+ config_class="",
+ module="",
+ )
+
return impls, specs
@@ -127,3 +136,60 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
+
+
+# returns a class implementing the protocol corresponding to the Api
+async def instantiate_provider(
+ provider_spec: ProviderSpec,
+ deps: Dict[str, Any],
+ provider_config: Union[GenericProviderConfig, RoutingTable],
+):
+ module = importlib.import_module(provider_spec.module)
+
+ args = []
+ if isinstance(provider_spec, RemoteProviderSpec):
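+ # remote providers either wrap an adapter or act as a passthrough client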
+ if provider_spec.adapter:
+ method = "get_adapter_impl"
+ else:
+ method = "get_client_impl"
+
+ assert isinstance(provider_config, GenericProviderConfig)
+ config_type = instantiate_class_type(provider_spec.config_class)
+ config = config_type(**provider_config.config)
+ args = [config, deps]
+ elif isinstance(provider_spec, AutoRoutedProviderSpec):
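+ # auto-routed APIs resolve their concrete provider at request time via the routing table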
+ method = "get_auto_router_impl"
+
+ config = None
+ args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
+ elif isinstance(provider_spec, RoutingTableProviderSpec):
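+ # routing tables recursively instantiate the inner provider behind each routing entry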
+ method = "get_routing_table_impl"
+
+ assert isinstance(provider_config, List)
+ routing_table = provider_config
+
+ inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
+ inner_impls = []
+ for routing_entry in routing_table:
+ impl = await instantiate_provider(
+ inner_specs[routing_entry.provider_type],
+ deps,
+ routing_entry,
+ )
+ inner_impls.append((routing_entry.routing_key, impl))
+
+ config = None
+ args = [provider_spec.api, inner_impls, routing_table, deps]
+ else:
+ method = "get_provider_impl"
+
+ assert isinstance(provider_config, GenericProviderConfig)
+ config_type = instantiate_class_type(provider_spec.config_class)
+ config = config_type(**provider_config.config)
+ args = [config, deps]
+
+ fn = getattr(module, method)
+ impl = await fn(*args)
+ impl.__provider_spec__ = provider_spec
+ impl.__provider_config__ = config
+ return impl
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 02dc942e8..e5db17edc 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -94,12 +94,21 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config:
- specs.append(
- ShieldSpec(
- shield_type=entry.routing_key,
- provider_config=entry,
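+ # a routing_key may name a single shield type or a list of shield types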
+ if isinstance(entry.routing_key, list):
+ for k in entry.routing_key:
+ specs.append(
+ ShieldSpec(
+ shield_type=k,
+ provider_config=entry,
+ )
+ )
+ else:
+ specs.append(
+ ShieldSpec(
+ shield_type=entry.routing_key,
+ provider_config=entry,
+ )
)
- )
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py
new file mode 100644
index 000000000..601e80e5d
--- /dev/null
+++ b/llama_stack/distribution/server/endpoints.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import inspect
+from typing import Dict, List
+
+from pydantic import BaseModel
+
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.inspect import Inspect
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.models import Models
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.shields import Shields
+from llama_stack.apis.telemetry import Telemetry
+
+from llama_stack.providers.datatypes import Api
+
+
+class ApiEndpoint(BaseModel):
+ route: str
+ method: str
+ name: str
+
+
+def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
+ apis = {}
+
+ protocols = {
+ Api.inference: Inference,
+ Api.safety: Safety,
+ Api.agents: Agents,
+ Api.memory: Memory,
+ Api.telemetry: Telemetry,
+ Api.models: Models,
+ Api.shields: Shields,
+ Api.memory_banks: MemoryBanks,
+ Api.inspect: Inspect,
+ }
+
+ for api, protocol in protocols.items():
+ endpoints = []
+ protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+
+ for name, method in protocol_methods:
+ if not hasattr(method, "__webmethod__"):
+ continue
+
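+ # the @webmethod decorator stores route and HTTP method metadata on the function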
+ webmethod = method.__webmethod__
+ route = webmethod.route
+
+ if webmethod.method == "GET":
+ method = "get"
+ elif webmethod.method == "DELETE":
+ method = "delete"
+ else:
+ method = "post"
+ endpoints.append(ApiEndpoint(route=route, method=method, name=name))
+
+ apis[api] = endpoints
+
+ return apis
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 16b1fb619..4013264df 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -15,7 +15,6 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
-from http import HTTPStatus
from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
@@ -26,7 +25,6 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
-from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
@@ -39,10 +37,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
-from llama_stack.distribution.distribution import api_endpoints
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
+from .endpoints import get_all_api_endpoints
+
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
@@ -286,26 +285,18 @@ def main(
app = FastAPI()
- # Health check is added to enable deploying the docker container image on Kubernetes which require
- # a health check that can return 200 for readiness and liveness check
- class HealthCheck(BaseModel):
- status: str = "OK"
-
- @app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
- async def healthcheck():
- return HealthCheck(status="OK")
-
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
- all_endpoints = api_endpoints()
+ all_endpoints = get_all_api_endpoints()
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
else:
apis_to_serve = set(impls.keys())
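+ # always serve the builtin Inspect API so clients can introspect the running stack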
+ apis_to_serve.add(Api.inspect)
for api_str in apis_to_serve:
api = Api(api_str)
@@ -339,14 +330,11 @@ def main(
)
)
- for route in app.routes:
- if isinstance(route, APIRoute):
- cprint(
- f"Serving {next(iter(route.methods))} {route.path}",
- "white",
- attrs=["bold"],
- )
+ cprint(f"Serving API {api_str}", "white", attrs=["bold"])
+ for endpoint in endpoints:
+ cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
+ print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
diff --git a/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml b/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml
index 0a845582c..aa5bb916f 100644
--- a/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml
+++ b/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml
@@ -18,7 +18,7 @@ api_providers:
providers:
- meta-reference
agents:
- provider_id: meta-reference
+ provider_type: meta-reference
config:
persistence_store:
namespace: null
@@ -28,22 +28,22 @@ api_providers:
providers:
- meta-reference
telemetry:
- provider_id: meta-reference
+ provider_type: meta-reference
config: {}
routing_table:
inference:
- - provider_id: remote::ollama
+ - provider_type: remote::ollama
config:
host: localhost
port: 6000
routing_key: Meta-Llama3.1-8B-Instruct
safety:
- - provider_id: meta-reference
+ - provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- - provider_id: meta-reference
+ - provider_type: meta-reference
config: {}
routing_key: vector
diff --git a/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml b/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml
index 66f6cfcef..bb7a2cc0d 100644
--- a/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml
+++ b/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml
@@ -18,7 +18,7 @@ api_providers:
providers:
- meta-reference
agents:
- provider_id: meta-reference
+ provider_type: meta-reference
config:
persistence_store:
namespace: null
@@ -28,11 +28,11 @@ api_providers:
providers:
- meta-reference
telemetry:
- provider_id: meta-reference
+ provider_type: meta-reference
config: {}
routing_table:
inference:
- - provider_id: meta-reference
+ - provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
quantization: null
@@ -41,12 +41,12 @@ routing_table:
max_batch_size: 1
routing_key: Llama3.1-8B-Instruct
safety:
- - provider_id: meta-reference
+ - provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- - provider_id: meta-reference
+ - provider_type: meta-reference
config: {}
routing_key: vector
diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py
index 7c2ac2e6a..53b861fe4 100644
--- a/llama_stack/distribution/utils/dynamic.py
+++ b/llama_stack/distribution/utils/dynamic.py
@@ -5,69 +5,9 @@
# the root directory of this source tree.
import importlib
-from typing import Any, Dict
-
-from llama_stack.distribution.datatypes import * # noqa: F403
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
-
-
-# returns a class implementing the protocol corresponding to the Api
-async def instantiate_provider(
- provider_spec: ProviderSpec,
- deps: Dict[str, Any],
- provider_config: Union[GenericProviderConfig, RoutingTable],
-):
- module = importlib.import_module(provider_spec.module)
-
- args = []
- if isinstance(provider_spec, RemoteProviderSpec):
- if provider_spec.adapter:
- method = "get_adapter_impl"
- else:
- method = "get_client_impl"
-
- assert isinstance(provider_config, GenericProviderConfig)
- config_type = instantiate_class_type(provider_spec.config_class)
- config = config_type(**provider_config.config)
- args = [config, deps]
- elif isinstance(provider_spec, AutoRoutedProviderSpec):
- method = "get_auto_router_impl"
-
- config = None
- args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
- elif isinstance(provider_spec, RoutingTableProviderSpec):
- method = "get_routing_table_impl"
-
- assert isinstance(provider_config, List)
- routing_table = provider_config
-
- inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
- inner_impls = []
- for routing_entry in routing_table:
- impl = await instantiate_provider(
- inner_specs[routing_entry.provider_id],
- deps,
- routing_entry,
- )
- inner_impls.append((routing_entry.routing_key, impl))
-
- config = None
- args = [provider_spec.api, inner_impls, routing_table, deps]
- else:
- method = "get_provider_impl"
-
- assert isinstance(provider_config, GenericProviderConfig)
- config_type = instantiate_class_type(provider_spec.config_class)
- config = config_type(**provider_config.config)
- args = [config, deps]
-
- fn = getattr(module, method)
- impl = await fn(*args)
- impl.__provider_spec__ = provider_spec
- impl.__provider_config__ = config
- return impl
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index a9a3d86e9..a2e8851a2 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@@ -24,18 +24,14 @@ class Api(Enum):
shields = "shields"
memory_banks = "memory_banks"
-
-@json_schema_type
-class ApiEndpoint(BaseModel):
- route: str
- method: str
- name: str
+ # built-in API
+ inspect = "inspect"
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
- provider_id: str
+ provider_type: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
@@ -62,71 +58,9 @@ class RoutableProvider(Protocol):
async def validate_routing_keys(self, keys: List[str]) -> None: ...
-class GenericProviderConfig(BaseModel):
- provider_id: str
- config: Dict[str, Any]
-
-
-class PlaceholderProviderConfig(BaseModel):
- """Placeholder provider config for API whose provider are defined in routing_table"""
-
- providers: List[str]
-
-
-RoutingKey = Union[str, List[str]]
-
-
-class RoutableProviderConfig(GenericProviderConfig):
- routing_key: RoutingKey
-
-
-# Example: /inference, /safety
-@json_schema_type
-class AutoRoutedProviderSpec(ProviderSpec):
- provider_id: str = "router"
- config_class: str = ""
-
- docker_image: Optional[str] = None
- routing_table_api: Api
- module: str = Field(
- ...,
- description="""
- Fully-qualified name of the module to import. The module is expected to have:
-
- - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
- """,
- )
- provider_data_validator: Optional[str] = Field(
- default=None,
- )
-
- @property
- def pip_packages(self) -> List[str]:
- raise AssertionError("Should not be called on AutoRoutedProviderSpec")
-
-
-# Example: /models, /shields
-@json_schema_type
-class RoutingTableProviderSpec(ProviderSpec):
- provider_id: str = "routing_table"
- config_class: str = ""
- docker_image: Optional[str] = None
-
- inner_specs: List[ProviderSpec]
- module: str = Field(
- ...,
- description="""
- Fully-qualified name of the module to import. The module is expected to have:
-
- - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
- """,
- )
- pip_packages: List[str] = Field(default_factory=list)
-
-
@json_schema_type
class AdapterSpec(BaseModel):
- adapter_id: str = Field(
+ adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
@@ -186,10 +120,6 @@ class RemoteProviderConfig(BaseModel):
return f"http://{self.host}:{self.port}"
-def remote_provider_id(adapter_id: str) -> str:
- return f"remote::{adapter_id}"
-
-
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
@@ -233,8 +163,8 @@ def remote_provider_spec(
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
- provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
+ provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
return RemoteProviderSpec(
- api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
+ api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
)
diff --git a/llama_stack/providers/impls/meta_reference/safety/config.py b/llama_stack/providers/impls/meta_reference/safety/config.py
index 734103412..36428078d 100644
--- a/llama_stack/providers/impls/meta_reference/safety/config.py
+++ b/llama_stack/providers/impls/meta_reference/safety/config.py
@@ -50,20 +50,6 @@ class LlamaGuardShieldConfig(BaseModel):
class PromptGuardShieldConfig(BaseModel):
model: str = "Prompt-Guard-86M"
- @validator("model")
- @classmethod
- def validate_model(cls, model: str) -> str:
- permitted_models = [
- m.descriptor()
- for m in safety_models()
- if m.core_model_id == CoreModelId.prompt_guard_86m
- ]
- if model not in permitted_models:
- raise ValueError(
- f"Invalid model: {model}. Must be one of {permitted_models}"
- )
- return model
-
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py
index 16a872572..2603b5faf 100644
--- a/llama_stack/providers/registry/agents.py
+++ b/llama_stack/providers/registry/agents.py
@@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.agents,
- provider_id="meta-reference",
+ provider_type="meta-reference",
pip_packages=[
"matplotlib",
"pillow",
@@ -33,7 +33,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.agents,
adapter=AdapterSpec(
- adapter_id="sample",
+ adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.agents.sample",
config_class="llama_stack.providers.adapters.agents.sample.SampleConfig",
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 8f9786a95..47e142201 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.inference,
- provider_id="meta-reference",
+ provider_type="meta-reference",
pip_packages=[
"accelerate",
"blobfile",
@@ -30,7 +30,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
- adapter_id="sample",
+ adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.inference.sample",
config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
@@ -39,7 +39,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
- adapter_id="ollama",
+ adapter_type="ollama",
pip_packages=["ollama"],
module="llama_stack.providers.adapters.inference.ollama",
),
@@ -47,7 +47,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
- adapter_id="tgi",
+ adapter_type="tgi",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.adapters.inference.tgi",
config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
@@ -56,7 +56,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
- adapter_id="hf::serverless",
+ adapter_type="hf::serverless",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.adapters.inference.tgi",
config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
@@ -65,7 +65,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
- adapter_id="hf::endpoint",
+ adapter_type="hf::endpoint",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.adapters.inference.tgi",
config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
@@ -74,7 +74,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
- adapter_id="fireworks",
+ adapter_type="fireworks",
pip_packages=[
"fireworks-ai",
],
@@ -85,7 +85,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
- adapter_id="together",
+ adapter_type="together",
pip_packages=[
"together",
],
@@ -97,10 +97,8 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
- adapter_id="bedrock",
- pip_packages=[
- "boto3"
- ],
+ adapter_type="bedrock",
+ pip_packages=["boto3"],
module="llama_stack.providers.adapters.inference.bedrock",
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),
diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py
index d6776ff69..4687e262c 100644
--- a/llama_stack/providers/registry/memory.py
+++ b/llama_stack/providers/registry/memory.py
@@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
- provider_id="meta-reference",
+ provider_type="meta-reference",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.impls.meta_reference.memory",
config_class="llama_stack.providers.impls.meta_reference.memory.FaissImplConfig",
@@ -42,7 +42,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
Api.memory,
AdapterSpec(
- adapter_id="chromadb",
+ adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.providers.adapters.memory.chroma",
),
@@ -50,7 +50,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
Api.memory,
AdapterSpec(
- adapter_id="pgvector",
+ adapter_type="pgvector",
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
module="llama_stack.providers.adapters.memory.pgvector",
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
@@ -59,7 +59,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(
- adapter_id="sample",
+ adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.memory.sample",
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",
diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py
index e0022f02b..58307be11 100644
--- a/llama_stack/providers/registry/safety.py
+++ b/llama_stack/providers/registry/safety.py
@@ -19,7 +19,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.safety,
- provider_id="meta-reference",
+ provider_type="meta-reference",
pip_packages=[
"codeshield",
"transformers",
@@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
- adapter_id="sample",
+ adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.safety.sample",
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
@@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
- adapter_id="bedrock",
+ adapter_type="bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.adapters.safety.bedrock",
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
@@ -52,7 +52,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
- adapter_id="together",
+ adapter_type="together",
pip_packages=[
"together",
],
diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py
index 02b71077e..39bcb75d8 100644
--- a/llama_stack/providers/registry/telemetry.py
+++ b/llama_stack/providers/registry/telemetry.py
@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.telemetry,
- provider_id="meta-reference",
+ provider_type="meta-reference",
pip_packages=[],
module="llama_stack.providers.impls.meta_reference.telemetry",
config_class="llama_stack.providers.impls.meta_reference.telemetry.ConsoleConfig",
@@ -21,7 +21,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
- adapter_id="sample",
+ adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.telemetry.sample",
config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
@@ -30,7 +30,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
- adapter_id="opentelemetry-jaeger",
+ adapter_type="opentelemetry-jaeger",
pip_packages=[
"opentelemetry-api",
"opentelemetry-sdk",
diff --git a/llama_stack/providers/utils/inference/augment_messages.py b/llama_stack/providers/utils/inference/augment_messages.py
index 10375cf0e..613a39525 100644
--- a/llama_stack/providers/utils/inference/augment_messages.py
+++ b/llama_stack/providers/utils/inference/augment_messages.py
@@ -34,7 +34,8 @@ def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
- model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
+ model.model_family == ModelFamily.llama3_2
+ and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return augment_messages_for_tools_llama_3_1(request)
diff --git a/requirements.txt b/requirements.txt
index 327b2ee82..df3221371 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
-llama-models>=0.0.37
+llama-models>=0.0.38
prompt-toolkit
python-dotenv
pydantic>=2
diff --git a/setup.py b/setup.py
index 3c26c9a84..804c9ba3d 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_stack",
- version="0.0.37",
+ version="0.0.38",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama Stack",
diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml
index 98d105233..94340c4d1 100644
--- a/tests/examples/local-run.yaml
+++ b/tests/examples/local-run.yaml
@@ -18,7 +18,7 @@ api_providers:
providers:
- meta-reference
agents:
- provider_id: meta-reference
+ provider_type: meta-reference
config:
persistence_store:
namespace: null
@@ -28,11 +28,11 @@ api_providers:
providers:
- meta-reference
telemetry:
- provider_id: meta-reference
+ provider_type: meta-reference
config: {}
routing_table:
inference:
- - provider_id: meta-reference
+ - provider_type: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
@@ -41,7 +41,7 @@ routing_table:
max_batch_size: 1
routing_key: Meta-Llama3.1-8B-Instruct
safety:
- - provider_id: meta-reference
+ - provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
@@ -52,6 +52,6 @@ routing_table:
model: Prompt-Guard-86M
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- - provider_id: meta-reference
+ - provider_type: meta-reference
config: {}
routing_key: vector