Merge branch 'main' into fix_configure

Xi Yan 2024-10-03 10:38:20 -07:00 committed by GitHub
commit a7250a1e33
39 changed files with 814 additions and 376 deletions


@@ -51,3 +51,9 @@ repos:
# hooks:
# - id: pydoclint
# args: [--config=pyproject.toml]
# - repo: https://github.com/tcort/markdown-link-check
# rev: v3.11.2
# hooks:
# - id: markdown-link-check
# args: ['--quiet']


@@ -1,7 +1,8 @@
# Llama Stack
[![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU)
[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack)
This repository contains the Llama Stack API specifications as well as API Providers and Llama Stack Distributions.


@@ -5,7 +5,7 @@ The `llama` CLI tool helps you setup and use the Llama toolchain & agentic syste
### Subcommands
1. `download`: the `llama` CLI supports downloading models from Meta or Hugging Face.
2. `model`: Lists available models and their properties.
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](/docs/cli_reference.md#step-3-building-configuring-and-running-llama-stack-servers).
3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](cli_reference.md#step-3-building-and-configuring-llama-stack-distributions).
### Sample Usage


@@ -46,6 +46,7 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
class LlamaStack(
@@ -63,6 +64,7 @@ class LlamaStack(
Evaluations,
Models,
Shields,
Inspect,
):
pass


@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
},
"servers": [
{
@@ -1542,6 +1542,36 @@
]
}
},
"/health": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HealthInfo"
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/memory/insert": {
"post": {
"responses": {
@@ -1665,6 +1695,75 @@
]
}
},
"/providers/list": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": {
"$ref": "#/components/schemas/ProviderInfo"
}
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/routes/list": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": {
"type": "array",
"items": {
"$ref": "#/components/schemas/RouteInfo"
}
}
}
}
}
}
},
"tags": [
"Inspect"
],
"parameters": [
{
"name": "X-LlamaStack-ProviderData",
"in": "header",
"description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
"required": false,
"schema": {
"type": "string"
}
}
]
}
},
"/shields/list": {
"get": {
"responses": {
@@ -4783,7 +4882,7 @@
"provider_config": {
"type": "object",
"properties": {
"provider_id": {
"provider_type": {
"type": "string"
},
"config": {
@@ -4814,7 +4913,7 @@
},
"additionalProperties": false,
"required": [
"provider_id",
"provider_type",
"config"
]
}
@@ -4843,7 +4942,7 @@
"provider_config": {
"type": "object",
"properties": {
"provider_id": {
"provider_type": {
"type": "string"
},
"config": {
@@ -4874,7 +4973,7 @@
},
"additionalProperties": false,
"required": [
"provider_id",
"provider_type",
"config"
]
}
@@ -4894,7 +4993,7 @@
"provider_config": {
"type": "object",
"properties": {
"provider_id": {
"provider_type": {
"type": "string"
},
"config": {
@@ -4925,7 +5024,7 @@
},
"additionalProperties": false,
"required": [
"provider_id",
"provider_type",
"config"
]
}
@@ -5086,6 +5185,18 @@
"job_uuid"
]
},
"HealthInfo": {
"type": "object",
"properties": {
"status": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"status"
]
},
"InsertDocumentsRequest": {
"type": "object",
"properties": {
@@ -5108,6 +5219,45 @@
"documents"
]
},
"ProviderInfo": {
"type": "object",
"properties": {
"provider_type": {
"type": "string"
},
"description": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"provider_type",
"description"
]
},
"RouteInfo": {
"type": "object",
"properties": {
"route": {
"type": "string"
},
"method": {
"type": "string"
},
"providers": {
"type": "array",
"items": {
"type": "string"
}
}
},
"additionalProperties": false,
"required": [
"route",
"method",
"providers"
]
},
"LogSeverity": {
"type": "string",
"enum": [
@@ -6220,19 +6370,34 @@
],
"tags": [
{
"name": "Shields"
"name": "Datasets"
},
{
"name": "Inspect"
},
{
"name": "Memory"
},
{
"name": "BatchInference"
},
{
"name": "RewardScoring"
"name": "Agents"
},
{
"name": "Inference"
},
{
"name": "Shields"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "Agents"
"name": "Models"
},
{
"name": "RewardScoring"
},
{
"name": "MemoryBanks"
@@ -6241,13 +6406,7 @@
"name": "Safety"
},
{
"name": "Models"
},
{
"name": "Inference"
},
{
"name": "Memory"
"name": "Evaluations"
},
{
"name": "Telemetry"
@@ -6255,12 +6414,6 @@
{
"name": "PostTraining"
},
{
"name": "Datasets"
},
{
"name": "Evaluations"
},
{
"name": "BuiltinTool",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@@ -6653,10 +6806,22 @@
"name": "PostTrainingJob",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/PostTrainingJob\" />"
},
{
"name": "HealthInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/HealthInfo\" />"
},
{
"name": "InsertDocumentsRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/InsertDocumentsRequest\" />"
},
{
"name": "ProviderInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ProviderInfo\" />"
},
{
"name": "RouteInfo",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/RouteInfo\" />"
},
{
"name": "LogSeverity",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/LogSeverity\" />"
@@ -6787,6 +6952,7 @@
"Datasets",
"Evaluations",
"Inference",
"Inspect",
"Memory",
"MemoryBanks",
"Models",
@@ -6857,6 +7023,7 @@
"FunctionCallToolDefinition",
"GetAgentsSessionRequest",
"GetDocumentsRequest",
"HealthInfo",
"ImageMedia",
"InferenceStep",
"InsertDocumentsRequest",
@@ -6880,6 +7047,7 @@
"PostTrainingJobStatus",
"PostTrainingJobStatusResponse",
"PreferenceOptimizeRequest",
"ProviderInfo",
"QLoraFinetuningConfig",
"QueryDocumentsRequest",
"QueryDocumentsResponse",
@@ -6888,6 +7056,7 @@
"RestAPIMethod",
"RewardScoreRequest",
"RewardScoringResponse",
"RouteInfo",
"RunShieldRequest",
"RunShieldResponse",
"SafetyViolation",


@@ -908,6 +908,14 @@ components:
required:
- document_ids
type: object
HealthInfo:
additionalProperties: false
properties:
status:
type: string
required:
- status
type: object
ImageMedia:
additionalProperties: false
properties:
@@ -1117,10 +1125,10 @@ components:
- type: array
- type: object
type: object
provider_id:
provider_type:
type: string
required:
- provider_id
- provider_type
- config
type: object
required:
@@ -1362,10 +1370,10 @@ components:
- type: array
- type: object
type: object
provider_id:
provider_type:
type: string
required:
- provider_id
- provider_type
- config
type: object
required:
@@ -1543,6 +1551,17 @@ components:
- hyperparam_search_config
- logger_config
type: object
ProviderInfo:
additionalProperties: false
properties:
description:
type: string
provider_type:
type: string
required:
- provider_type
- description
type: object
QLoraFinetuningConfig:
additionalProperties: false
properties:
@@ -1704,6 +1723,22 @@ components:
title: Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold.
type: object
RouteInfo:
additionalProperties: false
properties:
method:
type: string
providers:
items:
type: string
type: array
route:
type: string
required:
- route
- method
- providers
type: object
RunShieldRequest:
additionalProperties: false
properties:
@@ -1916,10 +1951,10 @@ components:
- type: array
- type: object
type: object
provider_id:
provider_type:
type: string
required:
- provider_id
- provider_type
- config
type: object
shield_type:
@@ -2569,7 +2604,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-09-23 16:58:41.469308"
\ draft and subject to change.\n Generated at 2024-10-02 15:40:53.008257"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3093,6 +3128,25 @@ paths:
description: OK
tags:
- Evaluations
/health:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/HealthInfo'
description: OK
tags:
- Inspect
/inference/chat_completion:
post:
parameters:
@@ -3637,6 +3691,27 @@ paths:
description: OK
tags:
- PostTraining
/providers/list:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
additionalProperties:
$ref: '#/components/schemas/ProviderInfo'
type: object
description: OK
tags:
- Inspect
/reward_scoring/score:
post:
parameters:
@@ -3662,6 +3737,29 @@ paths:
description: OK
tags:
- RewardScoring
/routes/list:
get:
parameters:
- description: JSON-encoded provider data which will be made available to the
adapter servicing the API
in: header
name: X-LlamaStack-ProviderData
required: false
schema:
type: string
responses:
'200':
content:
application/json:
schema:
additionalProperties:
items:
$ref: '#/components/schemas/RouteInfo'
type: array
type: object
description: OK
tags:
- Inspect
/safety/run_shield:
post:
parameters:
@@ -3807,20 +3905,21 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: Shields
- name: Datasets
- name: Inspect
- name: Memory
- name: BatchInference
- name: RewardScoring
- name: SyntheticDataGeneration
- name: Agents
- name: Inference
- name: Shields
- name: SyntheticDataGeneration
- name: Models
- name: RewardScoring
- name: MemoryBanks
- name: Safety
- name: Models
- name: Inference
- name: Memory
- name: Evaluations
- name: Telemetry
- name: PostTraining
- name: Datasets
- name: Evaluations
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@@ -4135,9 +4234,15 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob"
/>
name: PostTrainingJob
- description: <SchemaDefinition schemaRef="#/components/schemas/HealthInfo" />
name: HealthInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/InsertDocumentsRequest"
/>
name: InsertDocumentsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/ProviderInfo" />
name: ProviderInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
name: RouteInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/LogSeverity" />
name: LogSeverity
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
@@ -4236,6 +4341,7 @@ x-tagGroups:
- Datasets
- Evaluations
- Inference
- Inspect
- Memory
- MemoryBanks
- Models
@@ -4303,6 +4409,7 @@ x-tagGroups:
- FunctionCallToolDefinition
- GetAgentsSessionRequest
- GetDocumentsRequest
- HealthInfo
- ImageMedia
- InferenceStep
- InsertDocumentsRequest
@@ -4326,6 +4433,7 @@ x-tagGroups:
- PostTrainingJobStatus
- PostTrainingJobStatusResponse
- PreferenceOptimizeRequest
- ProviderInfo
- QLoraFinetuningConfig
- QueryDocumentsRequest
- QueryDocumentsResponse
@@ -4334,6 +4442,7 @@ x-tagGroups:
- RestAPIMethod
- RewardScoreRequest
- RewardScoringResponse
- RouteInfo
- RunShieldRequest
- RunShieldResponse
- SafetyViolation
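
The YAML rendering carries the same additions. A quick sanity check is to load the document and list its paths; a minimal sketch, assuming the spec is saved as llama-stack-spec.yaml (the filename is an assumption):

import yaml

with open("llama-stack-spec.yaml") as f:  # assumed filename
    spec = yaml.safe_load(f)

# the new Inspect routes should appear alongside the existing ones
for path in sorted(spec["paths"]):
    print(path)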


@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inspect import * # noqa: F401 F403


@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Dict, List
import fire
import httpx
from termcolor import cprint
from .inspect import * # noqa: F403
class InspectClient(Inspect):
def __init__(self, base_url: str):
self.base_url = base_url
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
    async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/providers/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
print(response.json())
return {
k: [ProviderInfo(**vi) for vi in v] for k, v in response.json().items()
}
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/routes/list",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
return {
k: [RouteInfo(**vi) for vi in v] for k, v in response.json().items()
}
async def health(self) -> HealthInfo:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/health",
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
j = response.json()
if j is None:
return None
return HealthInfo(**j)
async def run_main(host: str, port: int):
client = InspectClient(f"http://{host}:{port}")
response = await client.list_providers()
cprint(f"list_providers response={response}", "green")
response = await client.list_routes()
cprint(f"list_routes response={response}", "blue")
response = await client.health()
cprint(f"health response={response}", "yellow")
def main(host: str, port: int):
asyncio.run(run_main(host, port))
if __name__ == "__main__":
fire.Fire(main)
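
The module is runnable directly through fire (host and port as positional arguments); it can also be driven programmatically. A minimal sketch, assuming a server on localhost:5000:

import asyncio

client = InspectClient("http://localhost:5000")  # placeholder address
print(asyncio.run(client.health()))       # e.g. HealthInfo(status='OK')
print(asyncio.run(client.list_routes()))  # RouteInfo entries grouped by API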


@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@json_schema_type
class ProviderInfo(BaseModel):
provider_type: str
description: str
@json_schema_type
class RouteInfo(BaseModel):
route: str
method: str
providers: List[str]
@json_schema_type
class HealthInfo(BaseModel):
status: str
# TODO: add a provider level status
class Inspect(Protocol):
@webmethod(route="/providers/list", method="GET")
    async def list_providers(self) -> Dict[str, List[ProviderInfo]]: ...
@webmethod(route="/routes/list", method="GET")
async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...
@webmethod(route="/health", method="GET")
async def health(self) -> HealthInfo: ...
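
Each @webmethod call attaches its route and HTTP method to the function, and the endpoint discovery added later in this commit reads that metadata back. A minimal sketch of the same introspection, assuming the decorator stores its arguments on a __webmethod__ attribute as the server-side code below does:

import inspect as _inspect  # aliased to avoid clashing with this module's name

for name, fn in _inspect.getmembers(Inspect, predicate=_inspect.isfunction):
    wm = getattr(fn, "__webmethod__", None)
    if wm is not None:
        print(wm.method, wm.route, "->", name)  # e.g. GET /health -> health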


@@ -18,7 +18,7 @@ from llama_stack.distribution.datatypes import GenericProviderConfig
class MemoryBankSpec(BaseModel):
bank_type: MemoryBankType
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_id, and corresponding config. ",
description="Provider config for the model, including provider_type, and corresponding config. ",
)


@@ -20,7 +20,7 @@ class ModelServingSpec(BaseModel):
description="All metadatas associated with llama model (defined in llama_models.models.sku_list).",
)
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_id, and corresponding config. ",
description="Provider config for the model, including provider_type, and corresponding config. ",
)


@@ -16,7 +16,7 @@ from llama_stack.distribution.datatypes import GenericProviderConfig
class ShieldSpec(BaseModel):
shield_type: str
provider_config: GenericProviderConfig = Field(
description="Provider config for the model, including provider_id, and corresponding config. ",
description="Provider config for the model, including provider_type, and corresponding config. ",
)


@@ -5,7 +5,6 @@
# the root directory of this source tree.
import argparse
import subprocess
import textwrap
from io import StringIO
@@ -110,7 +109,4 @@ def render_markdown_to_pager(markdown_content: str):
console = Console(file=output, force_terminal=True, width=100) # Set a fixed width
console.print(md)
rendered_content = output.getvalue()
# Pipe to pager
pager = subprocess.Popen(["less", "-R"], stdin=subprocess.PIPE)
pager.communicate(input=rendered_content.encode())
print(rendered_content)


@@ -179,12 +179,7 @@ class StackBuild(Subcommand):
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import yaml
from llama_stack.distribution.distribution import (
Api,
api_providers,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from termcolor import cprint
@@ -249,22 +244,12 @@
)
cprint(
f"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.",
color="green",
)
providers = dict()
all_providers = api_providers()
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
for api in Api:
if api in routing_table_apis:
continue
providers_for_api = all_providers[api]
for api, providers_for_api in get_provider_registry().items():
api_provider = prompt(
"> Enter provider for the {} API: (default=meta-reference): ".format(
api.value


@@ -34,9 +34,9 @@ class StackListProviders(Subcommand):
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_stack.cli.table import print_table
from llama_stack.distribution.distribution import Api, api_providers
from llama_stack.distribution.distribution import Api, get_provider_registry
all_providers = api_providers()
all_providers = get_provider_registry()
providers_for_api = all_providers[Api(args.api)]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
@@ -47,11 +47,11 @@ class StackListProviders(Subcommand):
rows = []
for spec in providers_for_api.values():
if spec.provider_id == "sample":
if spec.provider_type == "sample":
continue
rows.append(
[
spec.provider_id,
spec.provider_type,
",".join(spec.pip_packages),
]
)


@@ -19,6 +19,17 @@ from pathlib import Path
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.distribution import get_provider_registry
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"fire",
"httpx",
"uvicorn",
]
class ImageType(Enum):
@@ -43,7 +54,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
)
# extend package dependencies based on providers spec
all_providers = api_providers()
all_providers = get_provider_registry()
for (
api_str,
provider_or_providers,


@@ -15,8 +15,8 @@ from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis,
get_provider_registry,
stack_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -62,7 +62,7 @@ def configure_api_providers(
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
apis = [v.value for v in stack_apis()]
all_providers = api_providers()
all_providers = get_provider_registry()
    # configure the simple case for non-routing providers in api_providers
for api_str in spec.providers.keys():
@@ -109,7 +109,7 @@ def configure_api_providers(
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_id=p,
provider_type=p,
config=cfg.dict(),
)
)
@@ -120,7 +120,7 @@ def configure_api_providers(
routing_entries.append(
RoutableProviderConfig(
routing_key=[s.value for s in MetaReferenceShieldType],
provider_id=p,
provider_type=p,
config=cfg.dict(),
)
)
@@ -133,7 +133,7 @@ def configure_api_providers(
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_id=p,
provider_type=p,
config=cfg.dict(),
)
)
@@ -153,7 +153,7 @@ def configure_api_providers(
routing_entries.append(
RoutableProviderConfig(
routing_key=routing_key,
provider_id=p,
provider_type=p,
config=cfg.dict(),
)
)
@@ -164,7 +164,7 @@ def configure_api_providers(
)
else:
config.api_providers[api_str] = GenericProviderConfig(
provider_id=p,
provider_type=p,
config=cfg.dict(),
)


@@ -17,6 +17,53 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
RoutingKey = Union[str, List[str]]
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
# Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
routing_table_api: Api
module: str
provider_data_validator: Optional[str] = Field(
default=None,
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class DistributionSpec(BaseModel):
description: Optional[str] = Field(
@@ -71,7 +118,7 @@ Provider configurations for each of the APIs provided by this package.
E.g. The following is a ProviderRoutingEntry for models:
- routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference
provider_type: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
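
For completeness, the same routing entry can be constructed directly with the renamed field; a minimal sketch, assuming RoutableProviderConfig is importable from llama_stack.distribution.datatypes:

from llama_stack.distribution.datatypes import RoutableProviderConfig

entry = RoutableProviderConfig(
    routing_key="Meta-Llama3.1-8B-Instruct",
    provider_type="meta-reference",
    config={
        "model": "Meta-Llama3.1-8B-Instruct",
        "quantization": None,
    },
)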


@@ -5,30 +5,11 @@
# the root directory of this source tree.
import importlib
import inspect
from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"fire",
"httpx",
"uvicorn",
]
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
def stack_apis() -> List[Api]:
@@ -57,58 +38,21 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = webmethod.route
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
def providable_apis() -> List[Api]:
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
for api in stack_apis():
if api in routing_table_apis:
continue
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
for api in providable_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_id: a for a in module.available_providers()},
**{a.provider_type: a for a in module.available_providers()},
}
return ret
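
A minimal sketch of consuming the registry helper that replaces api_providers() throughout this commit:

from llama_stack.distribution.distribution import get_provider_registry

# one entry per providable API, keyed by provider_type ("remote", "meta-reference", ...)
for api, providers in get_provider_registry().items():
    print(api.value, "->", sorted(providers))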


@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import * # noqa: F403
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
class DistributionInspectImpl(Inspect):
def __init__(self):
pass
async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
ret = {}
all_providers = get_provider_registry()
for api, providers in all_providers.items():
ret[api.value] = [
ProviderInfo(
provider_type=p.provider_type,
description="Passthrough" if is_passthrough(p) else "",
)
for p in providers.values()
]
return ret
async def list_routes(self) -> Dict[str, List[RouteInfo]]:
ret = {}
all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items():
ret[api.value] = [
RouteInfo(
route=e.route,
method=e.method,
providers=[],
)
for e in endpoints
]
return ret
async def health(self) -> HealthInfo:
return HealthInfo(status="OK")


@@ -18,10 +18,10 @@ class NeedsRequestProviderData:
spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}"
provider_id = spec.provider_id
provider_type = spec.provider_type
validator_class = spec.provider_data_validator
if not validator_class:
raise ValueError(f"Provider {provider_id} does not have a validator")
raise ValueError(f"Provider {provider_type} does not have a validator")
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
if not val:


@@ -3,15 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
from typing import Any, Dict, List, Set
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.distribution.inspect import DistributionInspectImpl
from llama_stack.distribution.utils.dynamic import instantiate_class_type
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
@@ -20,7 +22,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
all_providers = get_provider_registry()
specs = {}
configs = {}
@@ -34,11 +36,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if isinstance(config, PlaceholderProviderConfig):
continue
if config.provider_id not in providers:
if config.provider_type not in providers:
raise ValueError(
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
f"Provider `{config.provider_type}` is not available for API `{api}`"
)
specs[api] = providers[config.provider_id]
specs[api] = providers[config.provider_type]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
@@ -57,7 +59,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if info.router_api.value not in apis_to_serve:
continue
print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
@@ -68,12 +69,12 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
if rt_entry.provider_id not in providers:
if rt_entry.provider_type not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
inner_specs.append(providers[rt_entry.provider_type])
inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
@@ -94,7 +95,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_id}")
print(f" {spec.api}: {spec.provider_type}")
print("")
impls = {}
for spec in sorted_specs:
@@ -104,6 +105,14 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
impls[api] = impl
impls[Api.inspect] = DistributionInspectImpl()
specs[Api.inspect] = InlineProviderSpec(
api=Api.inspect,
provider_type="__distribution_builtin__",
config_class="",
module="",
)
return impls, specs
@@ -127,3 +136,60 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable],
):
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_type],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl
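
instantiate_provider dispatches to a factory function that each provider module is expected to export (get_provider_impl here, get_adapter_impl / get_client_impl for remote providers, plus the router factories above). A minimal sketch of an inline provider module satisfying that contract; FooConfig and FooImpl are hypothetical names:

from typing import Any, Dict

from pydantic import BaseModel


class FooConfig(BaseModel):  # hypothetical; referenced by ProviderSpec.config_class
    model: str


class FooImpl:  # hypothetical provider implementation
    def __init__(self, config: FooConfig, deps: Dict[str, Any]):
        self.config = config
        self.deps = deps


async def get_provider_impl(config: FooConfig, deps: Dict[str, Any]):
    # the resolver calls this with the parsed config and resolved API dependencies
    return FooImpl(config, deps)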


@@ -94,12 +94,21 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
if isinstance(entry.routing_key, list):
for k in entry.routing_key:
specs.append(
ShieldSpec(
shield_type=k,
provider_config=entry,
)
)
else:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
provider_config=entry,
)
)
)
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:


@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.providers.datatypes import Api
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
Api.inspect: Inspect,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = webmethod.route
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
return apis
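
A minimal sketch of dumping the endpoint table this helper produces:

from llama_stack.distribution.server.endpoints import get_all_api_endpoints

for api, endpoints in get_all_api_endpoints().items():
    for e in endpoints:
        print(f"{e.method.upper():6} {e.route}  ({api.value})")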


@@ -15,7 +15,6 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from http import HTTPStatus
from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
@@ -26,7 +25,6 @@
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
@@ -39,10 +37,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
from .endpoints import get_all_api_endpoints
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
@@ -286,26 +285,18 @@
app = FastAPI()
# Health check is added to enable deploying the docker container image on Kubernetes which require
# a health check that can return 200 for readiness and liveness check
class HealthCheck(BaseModel):
status: str = "OK"
@app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck)
async def healthcheck():
return HealthCheck(status="OK")
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
all_endpoints = api_endpoints()
all_endpoints = get_all_api_endpoints()
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
else:
apis_to_serve = set(impls.keys())
apis_to_serve.add(Api.inspect)
for api_str in apis_to_serve:
api = Api(api_str)
@@ -339,14 +330,11 @@
)
)
for route in app.routes:
if isinstance(route, APIRoute):
cprint(
f"Serving {next(iter(route.methods))} {route.path}",
"white",
attrs=["bold"],
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
for endpoint in endpoints:
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)


@@ -18,7 +18,7 @@ api_providers:
providers:
- meta-reference
agents:
provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
@@ -28,22 +28,22 @@ api_providers:
providers:
- meta-reference
telemetry:
provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_id: remote::ollama
- provider_type: remote::ollama
config:
host: localhost
port: 6000
routing_key: Meta-Llama3.1-8B-Instruct
safety:
- provider_id: meta-reference
- provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_id: meta-reference
- provider_type: meta-reference
config: {}
routing_key: vector
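
A minimal sketch of loading a run config like the one above, assuming StackRunConfig is exported from llama_stack.distribution.datatypes and the file is saved as run.yaml (both are assumptions):

import yaml

from llama_stack.distribution.datatypes import StackRunConfig  # assumed location

with open("run.yaml") as f:  # placeholder filename
    run_config = StackRunConfig(**yaml.safe_load(f))

for entry in run_config.routing_table["inference"]:
    print(entry.routing_key, "->", entry.provider_type)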


@@ -18,7 +18,7 @@ api_providers:
providers:
- meta-reference
agents:
provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
@@ -28,11 +28,11 @@ api_providers:
providers:
- meta-reference
telemetry:
provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_id: meta-reference
- provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
quantization: null
@@ -41,12 +41,12 @@ routing_table:
max_batch_size: 1
routing_key: Llama3.1-8B-Instruct
safety:
- provider_id: meta-reference
- provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_id: meta-reference
- provider_type: meta-reference
config: {}
routing_key: vector


@@ -5,69 +5,9 @@
# the root directory of this source tree.
import importlib
from typing import Any, Dict
from llama_stack.distribution.datatypes import * # noqa: F403
def instantiate_class_type(fully_qualified_name):
module_name, class_name = fully_qualified_name.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable],
):
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl


@@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@@ -24,18 +24,14 @@ class Api(Enum):
shields = "shields"
memory_banks = "memory_banks"
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
# built-in API
inspect = "inspect"
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_id: str
provider_type: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
@@ -62,71 +58,9 @@ class RoutableProvider(Protocol):
async def validate_routing_keys(self, keys: List[str]) -> None: ...
class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
RoutingKey = Union[str, List[str]]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
# Example: /inference, /safety
@json_schema_type
class AutoRoutedProviderSpec(ProviderSpec):
provider_id: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
routing_table_api: Api
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_id: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
@@ -186,10 +120,6 @@ class RemoteProviderConfig(BaseModel):
return f"http://{self.host}:{self.port}"
def remote_provider_id(adapter_id: str) -> str:
return f"remote::{adapter_id}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
@@ -233,8 +163,8 @@ def remote_provider_spec(
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
return RemoteProviderSpec(
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
)


@@ -50,20 +50,6 @@ class LlamaGuardShieldConfig(BaseModel):
class PromptGuardShieldConfig(BaseModel):
model: str = "Prompt-Guard-86M"
@validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in safety_models()
if m.core_model_id == CoreModelId.prompt_guard_86m
]
if model not in permitted_models:
raise ValueError(
f"Invalid model: {model}. Must be one of {permitted_models}"
)
return model
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None


@@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.agents,
provider_id="meta-reference",
provider_type="meta-reference",
pip_packages=[
"matplotlib",
"pillow",
@@ -33,7 +33,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.agents,
adapter=AdapterSpec(
adapter_id="sample",
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.agents.sample",
config_class="llama_stack.providers.adapters.agents.sample.SampleConfig",


@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.inference,
provider_id="meta-reference",
provider_type="meta-reference",
pip_packages=[
"accelerate",
"blobfile",
@@ -30,7 +30,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="sample",
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.inference.sample",
config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
@@ -39,7 +39,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="ollama",
adapter_type="ollama",
pip_packages=["ollama"],
module="llama_stack.providers.adapters.inference.ollama",
),
@@ -47,7 +47,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="tgi",
adapter_type="tgi",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.adapters.inference.tgi",
config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
@@ -56,7 +56,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="hf::serverless",
adapter_type="hf::serverless",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.adapters.inference.tgi",
config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
@@ -65,7 +65,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="hf::endpoint",
adapter_type="hf::endpoint",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.adapters.inference.tgi",
config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
@@ -74,7 +74,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="fireworks",
adapter_type="fireworks",
pip_packages=[
"fireworks-ai",
],
@@ -85,7 +85,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="together",
adapter_type="together",
pip_packages=[
"together",
],
@@ -97,10 +97,8 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="bedrock",
pip_packages=[
"boto3"
],
adapter_type="bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.adapters.inference.bedrock",
config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
),


@@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.memory,
provider_id="meta-reference",
provider_type="meta-reference",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.impls.meta_reference.memory",
config_class="llama_stack.providers.impls.meta_reference.memory.FaissImplConfig",
@@ -42,7 +42,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_id="chromadb",
adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.providers.adapters.memory.chroma",
),
@@ -50,7 +50,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_id="pgvector",
adapter_type="pgvector",
pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
module="llama_stack.providers.adapters.memory.pgvector",
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
@@ -59,7 +59,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(
adapter_id="sample",
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.memory.sample",
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",


@@ -19,7 +19,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.safety,
provider_id="meta-reference",
provider_type="meta-reference",
pip_packages=[
"codeshield",
"transformers",
@@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_id="sample",
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.safety.sample",
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
@@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_id="bedrock",
adapter_type="bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.adapters.safety.bedrock",
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
@@ -52,7 +52,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_id="together",
adapter_type="together",
pip_packages=[
"together",
],


@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.telemetry,
provider_id="meta-reference",
provider_type="meta-reference",
pip_packages=[],
module="llama_stack.providers.impls.meta_reference.telemetry",
config_class="llama_stack.providers.impls.meta_reference.telemetry.ConsoleConfig",
@@ -21,7 +21,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
adapter_id="sample",
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.adapters.telemetry.sample",
config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
@@ -30,7 +30,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
adapter_id="opentelemetry-jaeger",
adapter_type="opentelemetry-jaeger",
pip_packages=[
"opentelemetry-api",
"opentelemetry-sdk",


@@ -34,7 +34,8 @@ def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
model.model_family == ModelFamily.llama3_2
and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return augment_messages_for_tools_llama_3_1(request)


@@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
llama-models>=0.0.37
llama-models>=0.0.38
prompt-toolkit
python-dotenv
pydantic>=2


@@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_stack",
version="0.0.37",
version="0.0.38",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama Stack",


@@ -18,7 +18,7 @@ api_providers:
providers:
- meta-reference
agents:
provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
@@ -28,11 +28,11 @@ api_providers:
providers:
- meta-reference
telemetry:
provider_id: meta-reference
provider_type: meta-reference
config: {}
routing_table:
inference:
- provider_id: meta-reference
- provider_type: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
@@ -41,7 +41,7 @@ routing_table:
max_batch_size: 1
routing_key: Meta-Llama3.1-8B-Instruct
safety:
- provider_id: meta-reference
- provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
@@ -52,6 +52,6 @@ routing_table:
model: Prompt-Guard-86M
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_id: meta-reference
- provider_type: meta-reference
config: {}
routing_key: vector