llama-stack-mirror/llama_stack/distribution/server/server.py
Charlie Doern d994305f0a
fix: remove disabled providers from model dump (#2784)
# What does this PR do?

Currently, when running `llama stack run --template starter ...`, the
`__disabled__` providers, their models, etc. are printed alongside the
enabled ones, which makes the output very confusing.

In server.py, add a utility `remove_disabled_providers` that
post-processes the `model_dump` output to remove any dict with
`provider_id: __disabled__`.

We already have `debug` logs that print the disabled providers, so I think
it's safe to say that is the only indicator we need when using the starter template.

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan

Before (output truncated because it was huge):


```
...
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-3.2-11B-Vision-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-3.2-90B-Vision-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-4-Scout-17B-16E-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-4-Maverick-17B-128E-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-Guard-3-8B
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Meta-Llama-Guard-3-8B
         - metadata: {}
           model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-Guard-3-8B
           model_type: llm
           provider_id: __disabled__
           provider_model_id: sambanova/Meta-Llama-Guard-3-8B
         - metadata:
             embedding_dimension: 384
           model_id: all-MiniLM-L6-v2
           model_type: embedding
           provider_id: sentence-transformers
           provider_model_id: null
         providers:
           agents:
           - config:
               persistence_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/agents_store.db
                 type: sqlite
               responses_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/responses_store.db
                 type: sqlite
             provider_id: meta-reference
             provider_type: inline::meta-reference
           datasetio:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/huggingface_datasetio.db
                 type: sqlite
             provider_id: huggingface
             provider_type: remote::huggingface
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/localfs_datasetio.db
                 type: sqlite
             provider_id: localfs
             provider_type: inline::localfs
           eval:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/meta_reference_eval.db
                 type: sqlite
             provider_id: meta-reference
             provider_type: inline::meta-reference
           files:
           - config:
               metadata_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/files_metadata.db
                 type: sqlite
               storage_dir: /Users/charliedoern/.llama/distributions/starter/files
             provider_id: meta-reference-files
             provider_type: inline::localfs
           inference:
           - config:
               api_key: '********'
               base_url: https://api.cerebras.ai
             provider_id: __disabled__
             provider_type: remote::cerebras
           - config:
               url: http://localhost:11434
             provider_id: ollama
             provider_type: remote::ollama
           - config:
               api_token: '********'
               max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
               tls_verify: ${env.VLLM_TLS_VERIFY:=true}
               url: ${env.VLLM_URL}
             provider_id: __disabled__
             provider_type: remote::vllm
           - config:
               url: ${env.TGI_URL}
             provider_id: __disabled__
             provider_type: remote::tgi
           - config:
               api_token: '********'
               huggingface_repo: ${env.INFERENCE_MODEL}
             provider_id: __disabled__
             provider_type: remote::hf::serverless
           - config:
               api_token: '********'
               endpoint_name: ${env.INFERENCE_ENDPOINT_NAME}
             provider_id: __disabled__
             provider_type: remote::hf::endpoint
           - config:
               api_key: '********'
               url: https://api.fireworks.ai/inference/v1
             provider_id: __disabled__
             provider_type: remote::fireworks
           - config:
               api_key: '********'
               url: https://api.together.xyz/v1
             provider_id: __disabled__
             provider_type: remote::together
           - config: {}
             provider_id: __disabled__
             provider_type: remote::bedrock
           - config:
               api_token: '********'
               url: ${env.DATABRICKS_URL}
             provider_id: __disabled__
             provider_type: remote::databricks
           - config:
               api_key: '********'
               append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
               url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
             provider_id: __disabled__
             provider_type: remote::nvidia
           - config:
               api_token: '********'
               url: ${env.RUNPOD_URL:=}
             provider_id: __disabled__
             provider_type: remote::runpod
           - config:
               api_key: '********'
             provider_id: __disabled__
             provider_type: remote::openai
           - config:
               api_key: '********'
             provider_id: __disabled__
             provider_type: remote::anthropic
           - config:
               api_key: '********'
             provider_id: __disabled__
             provider_type: remote::gemini
           - config:
               api_key: '********'
               url: https://api.groq.com
             provider_id: __disabled__
             provider_type: remote::groq
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.fireworks.ai/inference/v1
             provider_id: __disabled__
             provider_type: remote::fireworks-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.llama.com/compat/v1/
             provider_id: __disabled__
             provider_type: remote::llama-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.together.xyz/v1
             provider_id: __disabled__
             provider_type: remote::together-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.groq.com/openai/v1
             provider_id: __disabled__
             provider_type: remote::groq-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.sambanova.ai/v1
             provider_id: __disabled__
             provider_type: remote::sambanova-openai-compat
           - config:
               api_key: '********'
               openai_compat_api_base: https://api.cerebras.ai/v1
             provider_id: __disabled__
             provider_type: remote::cerebras-openai-compat
           - config:
               api_key: '********'
               url: https://api.sambanova.ai/v1
             provider_id: __disabled__
             provider_type: remote::sambanova
           - config:
               api_key: '********'
               url: ${env.PASSTHROUGH_URL}
             provider_id: __disabled__
             provider_type: remote::passthrough
           - config: {}
             provider_id: sentence-transformers
             provider_type: inline::sentence-transformers
           post_training:
           - config:
               checkpoint_format: huggingface
               device: cpu
               distributed_backend: null
             provider_id: huggingface
             provider_type: inline::huggingface
           safety:
           - config:
               excluded_categories: []
             provider_id: llama-guard
             provider_type: inline::llama-guard
           scoring:
           - config: {}
             provider_id: basic
             provider_type: inline::basic
           - config: {}
             provider_id: llm-as-judge
             provider_type: inline::llm-as-judge
           - config:
               openai_api_key: '********'
             provider_id: braintrust
             provider_type: inline::braintrust
           telemetry:
           - config:
               otel_exporter_otlp_endpoint: null
               service_name: "\u200B"
               sinks: console,sqlite
               sqlite_db_path: /Users/charliedoern/.llama/distributions/starter/trace_store.db
             provider_id: meta-reference
             provider_type: inline::meta-reference
           tool_runtime:
           - config:
               api_key: '********'
               max_results: 3
             provider_id: brave-search
             provider_type: remote::brave-search
           - config:
               api_key: '********'
               max_results: 3
             provider_id: tavily-search
             provider_type: remote::tavily-search
           - config: {}
             provider_id: rag-runtime
             provider_type: inline::rag-runtime
           - config: {}
             provider_id: model-context-protocol
             provider_type: remote::model-context-protocol
           vector_io:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/faiss_store.db
                 type: sqlite
             provider_id: faiss
             provider_type: inline::faiss
           - config:
               db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
               kvstore:
                 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
                 type: sqlite
             provider_id: __disabled__
             provider_type: inline::sqlite-vec
           - config:
               db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
               kvstore:
                 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
                 type: sqlite
             provider_id: __disabled__
             provider_type: inline::milvus
           - config:
               url: ${env.CHROMADB_URL:=}
             provider_id: __disabled__
             provider_type: remote::chromadb
           - config:
               db: ${env.PGVECTOR_DB:=}
               host: ${env.PGVECTOR_HOST:=localhost}
               kvstore:
                 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
                 type: sqlite
               password: '********'
               port: ${env.PGVECTOR_PORT:=5432}
               user: ${env.PGVECTOR_USER:=}
             provider_id: __disabled__
             provider_type: remote::pgvector
         scoring_fns: []
         server:
           auth: null
           host: null
           port: 8321
           quota: null
           tls_cafile: null
           tls_certfile: null
           tls_keyfile: null
         shields:
         - params: null
           provider_id: null
           provider_shield_id: ollama/__disabled__
           shield_id: __disabled__
         tool_groups:
         - args: null
           mcp_endpoint: null
           provider_id: tavily-search
           toolgroup_id: builtin::websearch
         - args: null
           mcp_endpoint: null
           provider_id: rag-runtime
           toolgroup_id: builtin::rag
         vector_dbs: []
         version: 2

```

after:

```
INFO     2025-07-16 13:00:32,604 __main__:448 server: Run configuration:
INFO     2025-07-16 13:00:32,606 __main__:450 server: apis:
         - agents
         - datasetio
         - eval
         - files
         - inference
         - post_training
         - safety
         - scoring
         - telemetry
         - tool_runtime
         - vector_io
         benchmarks: []
         datasets: []
         image_name: starter
         inference_store:
           db_path: /Users/charliedoern/.llama/distributions/starter/inference_store.db
           type: sqlite
         metadata_store:
           db_path: /Users/charliedoern/.llama/distributions/starter/registry.db
           type: sqlite
         models:
         - metadata: {}
           model_id: ollama/llama3.2:3b
           model_type: llm
           provider_id: ollama
           provider_model_id: llama3.2:3b
         - metadata:
             embedding_dimension: 384
           model_id: all-MiniLM-L6-v2
           model_type: embedding
           provider_id: sentence-transformers
         providers:
           agents:
           - config:
               persistence_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/agents_store.db
                 type: sqlite
               responses_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/responses_store.db
                 type: sqlite
             provider_id: meta-reference
             provider_type: inline::meta-reference
           datasetio:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/huggingface_datasetio.db
                 type: sqlite
             provider_id: huggingface
             provider_type: remote::huggingface
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/localfs_datasetio.db
                 type: sqlite
             provider_id: localfs
             provider_type: inline::localfs
           eval:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/meta_reference_eval.db
                 type: sqlite
             provider_id: meta-reference
             provider_type: inline::meta-reference
           files:
           - config:
               metadata_store:
                 db_path: /Users/charliedoern/.llama/distributions/starter/files_metadata.db
                 type: sqlite
               storage_dir: /Users/charliedoern/.llama/distributions/starter/files
             provider_id: meta-reference-files
             provider_type: inline::localfs
           inference:
           - config:
               url: http://localhost:11434
             provider_id: ollama
             provider_type: remote::ollama
           - config: {}
             provider_id: sentence-transformers
             provider_type: inline::sentence-transformers
           post_training:
           - config:
               checkpoint_format: huggingface
               device: cpu
             provider_id: huggingface
             provider_type: inline::huggingface
           safety:
           - config:
               excluded_categories: []
             provider_id: llama-guard
             provider_type: inline::llama-guard
           scoring:
           - config: {}
             provider_id: basic
             provider_type: inline::basic
           - config: {}
             provider_id: llm-as-judge
             provider_type: inline::llm-as-judge
           - config:
               openai_api_key: '********'
             provider_id: braintrust
             provider_type: inline::braintrust
           telemetry:
           - config:
               service_name: "\u200B"
               sinks: console,sqlite
               sqlite_db_path: /Users/charliedoern/.llama/distributions/starter/trace_store.db
             provider_id: meta-reference
             provider_type: inline::meta-reference
           tool_runtime:
           - config:
               api_key: '********'
               max_results: 3
             provider_id: brave-search
             provider_type: remote::brave-search
           - config:
               api_key: '********'
               max_results: 3
             provider_id: tavily-search
             provider_type: remote::tavily-search
           - config: {}
             provider_id: rag-runtime
             provider_type: inline::rag-runtime
           - config: {}
             provider_id: model-context-protocol
             provider_type: remote::model-context-protocol
           vector_io:
           - config:
               kvstore:
                 db_path: /Users/charliedoern/.llama/distributions/starter/faiss_store.db
                 type: sqlite
             provider_id: faiss
             provider_type: inline::faiss
         scoring_fns: []
         server:
           port: 8321
         shields: []
         tool_groups:
         - provider_id: tavily-search
           toolgroup_id: builtin::websearch
         - provider_id: rag-runtime
           toolgroup_id: builtin::rag
         vector_dbs: []
         version: 2
```

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-07-18 10:44:35 -07:00

635 lines
24 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import functools
import inspect
import json
import logging
import os
import ssl
import sys
import traceback
import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any, get_origin
import rich.pretty
import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import (
find_matching_route,
get_all_api_routes,
initialize_route_impls,
)
from llama_stack.distribution.stack import (
cast_image_name_to_string,
construct_stack,
replace_env_vars,
validate_env_pair,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
)
from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
# Repository root: four directory levels above this file. Used by main() to
# locate llama_stack/templates/<template>/run.yaml.
REPO_ROOT = Path(__file__).parent.parent.parent.parent
# Module-level logger in the "server" logging category.
logger = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    """Drop-in replacement for ``warnings.showwarning`` that also prints the stack.

    The current call stack is written first, followed by the standard formatted
    warning. Output goes to ``file`` when it has a ``write`` method, otherwise
    to ``sys.stderr``.
    """
    stream = sys.stderr
    if hasattr(file, "write"):
        stream = file
    traceback.print_stack(file=stream)
    formatted = warnings.formatwarning(message, category, filename, lineno, line)
    stream.write(formatted)


# Opt-in: set LLAMA_STACK_TRACE_WARNINGS to see where every warning originates.
if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
    warnings.showwarning = warn_with_traceback
def create_sse_event(data: Any) -> str:
    """Encode *data* as one Server-Sent Events ``data:`` frame.

    Pydantic models are serialized with ``model_dump_json``; anything else goes
    through ``json.dumps``. The frame is terminated by a blank line.
    """
    payload = data.model_dump_json() if isinstance(data, BaseModel) else json.dumps(data)
    return f"data: {payload}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
    """Fallback exception handler: print the traceback and return a JSON error.

    The exception is mapped through ``translate_exception``; the response body
    has the shape ``{"error": {"detail": ...}}``.
    """
    traceback.print_exception(exc)
    translated = translate_exception(exc)
    status = translated.status_code
    body = {"error": {"detail": translated.detail}}
    return JSONResponse(status_code=status, content=body)
def translate_exception(exc: Exception) -> HTTPException:
    """Convert an exception raised while serving a request into an ``HTTPException``.

    Checks are applied in order and the first match wins:

    - pydantic ``ValidationError`` (first wrapped in ``RequestValidationError``)
      and ``RequestValidationError``: 400, ``detail`` is ``{"errors": [...]}``
      with ``loc``/``msg``/``type`` for each validation error.
    - ``ValueError``: 400 ``"Invalid value: ..."``.
    - openai ``BadRequestError``: 400 with the error text.
    - ``PermissionError`` / ``AccessDeniedError``: 403.
    - ``asyncio.TimeoutError`` / ``TimeoutError``: 504.
    - ``NotImplementedError``: 501.
    - ``AuthenticationRequiredError``: 401.
    - anything else: 500 with a generic message, so internal details are not
      exposed to the client.

    Args:
        exc: The exception to translate.

    Returns:
        An ``HTTPException``. Callers either raise it or read its
        ``status_code`` and ``detail``.
    """
    if isinstance(exc, ValidationError):
        exc = RequestValidationError(exc.errors())

    if isinstance(exc, RequestValidationError):
        return HTTPException(
            status_code=400,
            detail={
                "errors": [
                    {
                        "loc": list(error["loc"]),
                        "msg": error["msg"],
                        "type": error["type"],
                    }
                    for error in exc.errors()
                ]
            },
        )
    elif isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
    elif isinstance(exc, BadRequestError):
        return HTTPException(status_code=400, detail=str(exc))
    elif isinstance(exc, PermissionError | AccessDeniedError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
    elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
    elif isinstance(exc, NotImplementedError):
        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
    elif isinstance(exc, AuthenticationRequiredError):
        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
    else:
        # Deliberately generic: do not leak internal error details to clients.
        return HTTPException(
            status_code=500,
            detail="Internal server error: An unexpected error occurred.",
        )
async def shutdown(app):
    """Initiate a graceful shutdown of the application.

    Handled by the lifespan context manager. The shutdown process involves
    shutting down all implementations registered in the application.

    Every implementation in ``app.__llama_stack_impls__`` that has a
    ``shutdown`` coroutine gets up to 5 seconds to finish. Timeouts, errors and
    cancellations are logged and do not prevent the remaining implementations
    from being shut down.

    Args:
        app: The application object; it must expose ``__llama_stack_impls__``,
            a mapping whose values are the implementations to shut down.
    """
    for impl in app.__llama_stack_impls__.values():
        impl_name = impl.__class__.__name__
        logger.info("Shutting down %s", impl_name)
        try:
            if hasattr(impl, "shutdown"):
                await asyncio.wait_for(impl.shutdown(), timeout=5)
            else:
                logger.warning("No shutdown method for %s", impl_name)
        except (asyncio.TimeoutError, TimeoutError):
            # asyncio.wait_for raises asyncio.TimeoutError, which is only an
            # alias of the builtin TimeoutError on Python 3.11+.
            logger.exception("Shutdown timeout for %s ", impl_name)
        except (Exception, asyncio.CancelledError) as e:
            # Best effort: keep shutting down the remaining implementations.
            logger.exception("Failed to shutdown %s: %s", impl_name, e)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Logs startup, hands control to the running application, then logs and
    runs ``shutdown(app)`` once the lifespan context exits normally.
    """
    logger.info("Starting up")
    # The application serves requests while suspended here.
    yield
    logger.info("Shutting down")
    await shutdown(app)
def is_streaming_request(func_name: str, request: Request, **kwargs):
    """Report whether a call should be served as a stream.

    Returns the ``stream`` keyword argument exactly as passed, or ``False``
    when it is absent. ``func_name`` and ``request`` are currently unused.
    """
    # TODO: pass the api method and punt it to the Protocol definition directly
    stream_flag = kwargs.get("stream", False)
    return stream_flag
async def maybe_await(value):
    """Await *value* if it is a coroutine object; otherwise return it unchanged."""
    if not inspect.iscoroutine(value):
        return value
    return await value
async def sse_generator(event_gen_coroutine):
    """Re-emit the items of an async event stream as Server-Sent Events frames.

    ``event_gen_coroutine`` is awaited to obtain an async iterator; each item it
    produces is serialized with ``create_sse_event`` and yielded. After each
    frame a short sleep gives control back to the event loop.

    On cancellation the inner generator is closed (when it supports
    ``aclose``), and the ``CancelledError`` is re-raised so the cancellation
    propagates to the caller. Any other exception is logged and reported as a
    final SSE frame of the form ``{"error": {"message": ...}}``.
    """
    event_gen = None
    try:
        event_gen = await event_gen_coroutine
        async for item in event_gen:
            yield create_sse_event(item)
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        logger.info("Generator cancelled")
        if event_gen is not None and hasattr(event_gen, "aclose"):
            await event_gen.aclose()
        # asyncio requires CancelledError to be propagated; swallowing it makes
        # the cancelled task look as if it completed normally.
        raise
    except Exception as e:
        logger.exception("Error in sse_generator")
        yield create_sse_event(
            {
                "error": {
                    "message": str(translate_exception(e)),
                },
            }
        )
async def log_request_pre_validation(request: Request):
    """Log the raw body of POST/PUT/PATCH requests at DEBUG level.

    Returns immediately unless the module logger is enabled for DEBUG, so the
    body is not read, decoded, or pretty-printed when the output would be
    discarded anyway. Bodies that parse as JSON are pretty-printed with
    ``rich.pretty.pretty_repr``; others are logged via ``repr`` of the raw
    bytes. Any failure while reading the body is logged as a warning and never
    propagates to the caller.
    """
    if request.method not in ("POST", "PUT", "PATCH"):
        return
    # Avoid the cost of parsing and pretty-printing every body when the
    # resulting debug messages would be dropped.
    if not logger.isEnabledFor(logging.DEBUG):
        return
    try:
        body_bytes = await request.body()
        if body_bytes:
            try:
                parsed_body = json.loads(body_bytes.decode())
                log_output = rich.pretty.pretty_repr(parsed_body)
            except (json.JSONDecodeError, UnicodeDecodeError):
                log_output = repr(body_bytes)
            logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
        else:
            logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
    except Exception as e:
        # Best effort: logging must never break request handling.
        logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
    """Wrap an API implementation callable as a FastAPI route handler.

    The returned coroutine function:

    - builds a ``User`` from the ``principal`` and ``user_attributes`` entries
      of the request scope,
    - logs the raw request body via ``log_request_pre_validation``,
    - calls ``func`` inside ``request_provider_data_context`` so provider data
      and the user are available to the implementation,
    - for streaming calls, returns a ``text/event-stream`` ``StreamingResponse``
      that preserves the trace and provider-data context variables,
    - otherwise returns the (awaited) result, filling in ``url`` on a
      ``PaginatedResponse`` that has none,
    - converts any exception to an ``HTTPException`` via ``translate_exception``.

    The handler's advertised signature is ``func``'s, with a leading
    ``request: Request`` parameter. For ``method == "post"``, parameters named
    in the route path are annotated with ``Path(...)``; all others get
    ``Body(..., embed=True)`` unless they are already ``Annotated`` (which keeps
    ``File()``/``Form()`` annotations for multipart form data).
    """

    @functools.wraps(func)
    async def route_handler(request: Request, **kwargs):
        scope = request.scope
        attributes = scope.get("user_attributes", {})
        user = User(principal=scope.get("principal", ""), attributes=attributes)

        await log_request_pre_validation(request)

        # Provider data and auth attributes are both made available to the call.
        with request_provider_data_context(request.headers, user):
            streaming = is_streaming_request(func.__name__, request, **kwargs)
            try:
                if streaming:
                    events = sse_generator(func(**kwargs))
                    wrapped = preserve_contexts_async_generator(events, [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
                    return StreamingResponse(wrapped, media_type="text/event-stream")

                result = await maybe_await(func(**kwargs))
                if isinstance(result, PaginatedResponse) and result.url is None:
                    result.url = route
                return result
            except Exception as exc:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.exception(f"Error executing endpoint {route=} {method=}")
                else:
                    logger.error(f"Error executing endpoint {route=} {method=}: {str(exc)}")
                raise translate_exception(exc) from exc

    original_sig = inspect.signature(func)
    request_param = inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)
    path_params = extract_path_params(route)

    def _annotate_for_post(param: inspect.Parameter) -> inspect.Parameter:
        # Path segment parameters -> Path(...); already-Annotated params are left
        # alone (e.g. File()/Form()); everything else -> embedded Body(...).
        if param.name in path_params:
            return param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
        if get_origin(param.annotation) is Annotated:
            return param
        return param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])

    handler_params = list(original_sig.parameters.values())
    if method == "post":
        handler_params = [_annotate_for_post(p) for p in handler_params]

    route_handler.__signature__ = original_sig.replace(parameters=[request_param, *handler_params])
    return route_handler
class TracingMiddleware:
    """ASGI middleware that wraps each routed API request in a telemetry trace.

    The following are passed straight to the wrapped app without tracing:
    lifespan events, FastAPI's built-in paths (docs, OpenAPI schema, favicon,
    static files), and paths with no matching API route. For a matched route,
    a trace is started under the route's trace path, incoming W3C
    ``traceparent``/``tracestate`` headers are recorded as trace attributes,
    the trace id is added to the response as ``x-trace-id``, and the trace is
    ended once the app returns or raises.
    """

    def __init__(self, app, impls):
        self.app = app
        self.impls = impls
        # FastAPI built-in paths that should bypass custom routing
        self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")

    async def __call__(self, scope, receive, send):
        if scope.get("type") == "lifespan":
            return await self.app(scope, receive, send)

        path = scope.get("path", "")
        if path.startswith(self.fastapi_paths):
            logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
            return await self.app(scope, receive, send)

        # The route table is built lazily, on the first request that needs it.
        if not hasattr(self, "route_impls"):
            self.route_impls = initialize_route_impls(self.impls)

        http_method = scope.get("method", hdrs.METH_GET)
        try:
            _, _, trace_path = find_matching_route(http_method, path, self.route_impls)
        except ValueError:
            logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
            return await self.app(scope, receive, send)

        trace_attributes = {"__location__": "server", "raw_path": path}
        request_headers = dict(scope.get("headers", []))
        # Propagate W3C trace context headers, when present, as trace attributes.
        for header_name, attr_name in ((b"traceparent", "traceparent"), (b"tracestate", "tracestate")):
            header_value = request_headers.get(header_name, b"").decode()
            if header_value:
                trace_attributes[attr_name] = header_value

        trace_context = await start_trace(trace_path, trace_attributes)

        async def send_with_trace_id(message):
            if message["type"] == "http.response.start":
                response_headers = message.get("headers", [])
                response_headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
                message["headers"] = response_headers
            await send(message)

        try:
            return await self.app(scope, receive, send_with_trace_id)
        finally:
            await end_trace()
class ClientVersionMiddleware:
    """ASGI middleware that rejects clients whose major.minor version differs from the server's.

    The client announces its version in the ``x-llamastack-client-version``
    request header. If the first two dot-separated components of the client
    and server versions differ, the request is answered with HTTP 426 and a
    JSON error body, and is not forwarded. Requests without the header, with
    a version that cannot be decoded or parsed, and non-HTTP scopes are
    passed through to the wrapped app unchanged.
    """

    def __init__(self, app):
        self.app = app
        # Version of the installed "llama-stack" distribution (importlib.metadata).
        self.server_version = parse_version("llama-stack")

    @staticmethod
    def _major_minor(version: str) -> tuple[int, ...]:
        """Return the first two dot-separated components of *version* as ints.

        Raises:
            ValueError: if either component is not an integer.
        """
        return tuple(map(int, version.split(".")[:2]))

    async def _send_version_error(self, send, client_version: str):
        """Send a complete HTTP 426 response with a JSON error message."""
        await send(
            {
                "type": "http.response.start",
                "status": 426,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps(
            {
                "error": {
                    "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
                }
            }
        ).encode()
        await send({"type": "http.response.body", "body": error_msg})

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            headers = dict(scope.get("headers", []))
            raw_client_version = headers.get(b"x-llamastack-client-version", b"")
            if raw_client_version:
                # Only decoding/parsing is guarded: errors raised while sending
                # the 426 response must not be mistaken for a parse failure
                # (which would forward the request after a response started).
                try:
                    client_version = raw_client_version.decode()
                    compatible = self._major_minor(client_version) == self._major_minor(self.server_version)
                except ValueError:
                    # If version parsing fails, let the request through
                    # (UnicodeDecodeError is a ValueError subclass).
                    compatible = True
                if not compatible:
                    return await self._send_version_error(send, client_version)
        return await self.app(scope, receive, send)
def main(args: argparse.Namespace | None = None):
    """Start the LlamaStack server.

    Resolves the run configuration from ``--config`` (preferred) or ``--template``,
    applies any ``--env KEY=value`` overrides, builds the FastAPI app with its
    middleware and API routes, constructs the provider stack in a dedicated event
    loop, and then serves the app with uvicorn on that same loop (blocking).

    Args:
        args: Pre-parsed arguments, e.g. handed over by the ``run`` CLI command.
            When None, arguments are parsed from the process command line.

    Raises:
        ValueError: if neither --config nor --template is given, if the chosen
            config file / template does not exist, or if a served API route has
            no matching method (or no non-HEAD HTTP method) on its implementation.

    Exits the process with status 1 on an invalid ``--env`` pair or an
    ``InvalidProviderError`` during stack construction.
    """
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
    parser.add_argument(
        "--yaml-config",
        dest="config",
        help="(Deprecated) Path to YAML configuration file - use --config instead",
    )
    parser.add_argument(
        "--config",
        dest="config",
        help="Path to YAML configuration file",
    )
    parser.add_argument(
        "--template",
        help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
        help="Port to listen on",
    )
    parser.add_argument(
        "--env",
        action="append",
        help="Environment variables in KEY=value format. Can be specified multiple times.",
    )

    # Determine whether the server args are being passed by the "run" command, if this is the case
    # the args will be passed as a Namespace object to the main function, otherwise they will be
    # parsed from the command line
    if args is None:
        args = parser.parse_args()

    # The message is only logged later, once the config-driven logger exists.
    log_line = ""
    if hasattr(args, "config") and args.config:
        # if the user provided a config file, use it, even if template was specified
        config_file = Path(args.config)
        if not config_file.exists():
            raise ValueError(f"Config file {config_file} does not exist")
        log_line = f"Using config file: {config_file}"
    elif hasattr(args, "template") and args.template:
        config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
        if not config_file.exists():
            raise ValueError(f"Template {args.template} does not exist")
        log_line = f"Using template {args.template} config file: {config_file}"
    else:
        raise ValueError("Either --config or --template must be provided")

    logger_config = None
    with open(config_file) as fp:
        config_contents = yaml.safe_load(fp)
        # An optional top-level `logging_config` section configures the server logger.
        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
            logger_config = LoggingConfig(**cfg)
        # Local logger for the rest of main(); shadows any module-level `logger`.
        logger = get_logger(name=__name__, category="server", config=logger_config)
        # --env overrides are exported before replace_env_vars() so they take part
        # in the substitution of the config contents below.
        # NOTE(review): the value of every --env pair is logged at INFO; values such
        # as API keys would end up in the logs in plaintext — consider redacting.
        if args.env:
            for env_pair in args.env:
                try:
                    key, value = validate_env_pair(env_pair)
                    logger.info(f"Setting CLI environment variable {key} => {value}")
                    os.environ[key] = value
                except ValueError as e:
                    logger.error(f"Error: {str(e)}")
                    sys.exit(1)
        config = replace_env_vars(config_contents)
        config = StackRunConfig(**cast_image_name_to_string(config))

    # now that the logger is initialized, print the line about which type of config we are using.
    logger.info(log_line)

    _log_run_config(run_config=config)

    app = FastAPI(
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json",
    )

    # The client/server version check can be switched off via this env var.
    if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
        app.add_middleware(ClientVersionMiddleware)

    # NOTE(review): Starlette applies middleware in reverse registration order, so
    # the relative order of the add_middleware() calls in this function presumably
    # matters — confirm before reordering.

    # Add authentication middleware if configured
    if config.server.auth:
        logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
        app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
    else:
        # Without auth every request is anonymous, so the authenticated quota is ignored.
        if config.server.quota:
            quota = config.server.quota
            logger.warning(
                "Configured authenticated_max_requests (%d) but no auth is enabled; "
                "falling back to anonymous_max_requests (%d) for all the requests",
                quota.authenticated_max_requests,
                quota.anonymous_max_requests,
            )

    if config.server.quota:
        logger.info("Enabling quota middleware for authenticated and anonymous clients")

        quota = config.server.quota
        anonymous_max_requests = quota.anonymous_max_requests
        # if auth is disabled, use the anonymous max requests
        authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests

        kv_config = quota.kvstore
        # Only a daily window is mapped; any other period value raises KeyError here.
        window_map = {"day": 86400}
        window_seconds = window_map[quota.period.value]

        app.add_middleware(
            QuotaMiddleware,
            kv_config=kv_config,
            anonymous_max_requests=anonymous_max_requests,
            authenticated_max_requests=authenticated_max_requests,
            window_seconds=window_seconds,
        )

    try:
        # Create and set the event loop that will be used for both construction and server runtime
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Construct the stack in the persistent event loop
        impls = loop.run_until_complete(construct_stack(config))
    except InvalidProviderError as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)

    # Route logging through the configured telemetry provider, or a default adapter.
    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
    else:
        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

    all_routes = get_all_api_routes()

    # Serve the APIs listed in the config, or every constructed API when none are listed.
    if config.apis:
        apis_to_serve = set(config.apis)
    else:
        apis_to_serve = set(impls.keys())

    for inf in builtin_automatically_routed_apis():
        # if we do not serve the corresponding router API, we should not serve the routing table API
        if inf.router_api.value not in apis_to_serve:
            continue
        apis_to_serve.add(inf.routing_table_api.value)

    # These APIs are always served.
    apis_to_serve.add("inspect")
    apis_to_serve.add("providers")
    for api_str in apis_to_serve:
        api = Api(api_str)

        routes = all_routes[api]
        impl = impls[api]

        for route in routes:
            if not hasattr(impl, route.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {route.name} on {impl}!")

            impl_method = getattr(impl, route.name)
            # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
            available_methods = [m for m in route.methods if m != "HEAD"]
            if not available_methods:
                raise ValueError(f"No methods found for {route.name} on {impl}")
            # Only the first remaining HTTP method is registered for the route.
            method = available_methods[0]
            logger.debug(f"{method} {route.path}")

            # Suppress pydantic field UserWarnings emitted while FastAPI builds the route.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
                getattr(app, method.lower())(route.path, response_model=None)(
                    create_dynamic_typed_route(
                        impl_method,
                        method.lower(),
                        route.path,
                    )
                )

    logger.debug(f"serving APIs: {apis_to_serve}")

    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)

    app.__llama_stack_impls__ = impls
    app.add_middleware(TracingMiddleware, impls=impls)

    import uvicorn

    # Configure SSL if certificates are provided
    # Because --port always has a default, config.server.port is only used when
    # args.port is falsy (e.g. None in a Namespace supplied by the caller).
    port = args.port or config.server.port

    ssl_config = None
    keyfile = config.server.tls_keyfile
    certfile = config.server.tls_certfile

    # TLS is enabled only when both a key and a certificate are configured.
    if keyfile and certfile:
        ssl_config = {
            "ssl_keyfile": keyfile,
            "ssl_certfile": certfile,
        }
        # A CA file additionally makes client certificates mandatory (mutual TLS).
        if config.server.tls_cafile:
            ssl_config["ssl_ca_certs"] = config.server.tls_cafile
            ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
            logger.info(
                f"HTTPS enabled with certificates:\n  Key: {keyfile}\n  Cert: {certfile}\n  CA: {config.server.tls_cafile}"
            )
        else:
            logger.info(f"HTTPS enabled with certificates:\n  Key: {keyfile}\n  Cert: {certfile}")

    # Default to listening on all IPv6 and IPv4 interfaces.
    listen_host = config.server.host or ["::", "0.0.0.0"]
    logger.info(f"Listening on {listen_host}:{port}")

    uvicorn_config = {
        "app": app,
        "host": listen_host,
        "port": port,
        "lifespan": "on",
        "log_level": logger.getEffectiveLevel(),
    }
    if ssl_config:
        uvicorn_config.update(ssl_config)

    # Run uvicorn in the existing event loop to preserve background tasks
    loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
def _log_run_config(run_config: StackRunConfig):
    """Write the effective run configuration to the server log as YAML.

    Sensitive fields are redacted and entries that belong to disabled providers
    are stripped before dumping, so the output reflects only what is active.
    """
    logger.info("Run configuration:")
    dumped = run_config.model_dump(mode="json")
    printable = remove_disabled_providers(redact_sensitive_fields(dumped))
    logger.info(yaml.dump(printable, indent=2))
def extract_path_params(route: str) -> list[str]:
    """Return the names of the ``{...}`` path parameters in *route*, in order.

    A segment counts as a parameter only when the whole segment is wrapped in
    braces. Converter suffixes such as ``{param:path}`` are stripped, leaving
    just ``param``.
    """
    names: list[str] = []
    for segment in route.split("/"):
        if not (segment.startswith("{") and segment.endswith("}")):
            continue
        # Drop the braces, then any ":converter" suffix.
        name, _, _ = segment[1:-1].partition(":")
        names.append(name)
    return names
def remove_disabled_providers(obj):
    """Return a copy of *obj* with every disabled-provider entry pruned.

    Walks nested dicts and lists and drops any dict whose ``provider_id``,
    ``shield_id`` or ``provider_model_id`` equals ``"__disabled__"``. Everything
    else is preserved, including keys and list items whose value is ``None``.
    The input is not mutated.

    Args:
        obj: A JSON-like structure (dicts, lists and scalars), e.g. the output
            of ``model_dump(mode="json")``.

    Returns:
        The pruned structure, or ``None`` if *obj* itself is a disabled entry.
    """
    disabled = "__disabled__"
    marker_keys = ("provider_id", "shield_id", "provider_model_id")
    # Private sentinel meaning "drop this node". Using None for this purpose
    # would also discard keys/items whose genuine value is None (e.g. an unset
    # `tls_keyfile`), silently hiding them from the logged configuration.
    drop = object()

    def prune(node):
        if isinstance(node, dict):
            if any(node.get(key) == disabled for key in marker_keys):
                return drop
            pruned_items = ((k, prune(v)) for k, v in node.items())
            return {k: v for k, v in pruned_items if v is not drop}
        if isinstance(node, list):
            pruned_list = (prune(item) for item in node)
            return [item for item in pruned_list if item is not drop]
        return node

    result = prune(obj)
    # Preserve the historical contract: a disabled top-level object becomes None.
    return None if result is drop else result
# Script entry point: with no explicit args, main() parses them from the command line.
if __name__ == "__main__":
    main()