llama-stack/llama_stack/distribution/server/server.py
Sébastien Han 418645696a
fix: improve signal handling and update dependencies (#1044)
# What does this PR do?
This commit enhances the signal handling mechanism in the server by
improving the `handle_signal` (previously handle_sigint) function. It
now properly retrieves the signal name, ensuring clearer logging when a
termination signal is received. Additionally, it cancels all running
tasks and waits for their completion before stopping the event loop,
allowing for a more graceful shutdown. Support for handling
SIGTERM has also been added alongside SIGINT.

Before the changes, handle_sigint used asyncio.run(run_shutdown()).
However, asyncio.run() is meant to start a new event loop, and calling
it inside an existing one (like when running Uvicorn) raises an error.
The fix replaces asyncio.run(run_shutdown()) with an async function
scheduled on the existing loop using loop.create_task(shutdown()). This
ensures that the shutdown coroutine runs within the current event loop
instead of trying to create a new one.

Furthermore, this commit updates the project dependencies. `fastapi` and
`uvicorn` have been added to the development dependencies in
`pyproject.toml` and `uv.lock`, ensuring that the necessary packages are
available for development and execution.

Closes: https://github.com/meta-llama/llama-stack/issues/1043
Signed-off-by: Sébastien Han <seb@redhat.com>

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Run a server and send SIGINT:

```
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml
Using config file: llama_stack/templates/ollama/run.yaml
Run configuration:
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
container_image: null
datasets: []
eval_tasks: []
image_name: ollama
metadata_store:
  db_path: /Users/leseb/.llama/distributions/ollama/registry.db
  namespace: null
  type: sqlite
models:
- metadata: {}
  model_id: meta-llama/Llama-3.2-3B-Instruct
  model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType
  - llm
  provider_id: ollama
  provider_model_id: null
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType
  - embedding
  provider_id: sentence-transformers
  provider_model_id: null
providers:
  agents:
  - config:
      persistence_store:
        db_path: /Users/leseb/.llama/distributions/ollama/agents_store.db
        namespace: null
        type: sqlite
    provider_id: meta-reference
    provider_type: inline::meta-reference
  datasetio:
  - config: {}
    provider_id: huggingface
    provider_type: remote::huggingface
  - config: {}
    provider_id: localfs
    provider_type: inline::localfs
  eval:
  - config: {}
    provider_id: meta-reference
    provider_type: inline::meta-reference
  inference:
  - config:
      url: http://localhost:11434
    provider_id: ollama
    provider_type: remote::ollama
  - config: {}
    provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
  safety:
  - config: {}
    provider_id: llama-guard
    provider_type: inline::llama-guard
  scoring:
  - config: {}
    provider_id: basic
    provider_type: inline::basic
  - config: {}
    provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
  - config:
      openai_api_key: '********'
    provider_id: braintrust
    provider_type: inline::braintrust
  telemetry:
  - config:
      service_name: llama-stack
      sinks: console,sqlite
      sqlite_db_path: /Users/leseb/.llama/distributions/ollama/trace_store.db
    provider_id: meta-reference
    provider_type: inline::meta-reference
  tool_runtime:
  - config:
      api_key: '********'
      max_results: 3
    provider_id: brave-search
    provider_type: remote::brave-search
  - config:
      api_key: '********'
      max_results: 3
    provider_id: tavily-search
    provider_type: remote::tavily-search
  - config: {}
    provider_id: code-interpreter
    provider_type: inline::code-interpreter
  - config: {}
    provider_id: rag-runtime
    provider_type: inline::rag-runtime
  vector_io:
  - config:
      kvstore:
        db_path: /Users/leseb/.llama/distributions/ollama/faiss_store.db
        namespace: null
        type: sqlite
    provider_id: faiss
    provider_type: inline::faiss
scoring_fns: []
server:
  port: 8321
  tls_certfile: null
  tls_keyfile: null
shields: []
tool_groups:
- args: null
  mcp_endpoint: null
  provider_id: tavily-search
  toolgroup_id: builtin::websearch
- args: null
  mcp_endpoint: null
  provider_id: rag-runtime
  toolgroup_id: builtin::rag
- args: null
  mcp_endpoint: null
  provider_id: code-interpreter
  toolgroup_id: builtin::code_interpreter
vector_dbs: []
version: '2'

INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:213: Resolved 31 providers
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-inference => ollama
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-inference => sentence-transformers
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  models => __routing_table__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inference => __autorouted__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-vector_io => faiss
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-safety => llama-guard
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  shields => __routing_table__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  safety => __autorouted__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  vector_dbs => __routing_table__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  vector_io => __autorouted__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-tool_runtime => brave-search
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-tool_runtime => tavily-search
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-tool_runtime => code-interpreter
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-tool_runtime => rag-runtime
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  tool_groups => __routing_table__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  tool_runtime => __autorouted__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  agents => meta-reference
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-datasetio => huggingface
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-datasetio => localfs
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  datasets => __routing_table__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  datasetio => __autorouted__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  telemetry => meta-reference
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-scoring => basic
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-scoring => llm-as-judge
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-scoring => braintrust
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  scoring_functions => __routing_table__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  scoring => __autorouted__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inner-eval => meta-reference
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  eval_tasks => __routing_table__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  eval => __autorouted__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215:  inspect => __builtin__
INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:216: 
INFO 2025-02-12 10:21:03,723 llama_stack.providers.remote.inference.ollama.ollama:148: checking connectivity to Ollama at `http://localhost:11434`...
INFO 2025-02-12 10:21:03,734 httpx:1740: HTTP Request: GET http://localhost:11434/api/ps "HTTP/1.1 200 OK"
INFO 2025-02-12 10:21:03,843 faiss.loader:148: Loading faiss.
INFO 2025-02-12 10:21:03,865 faiss.loader:150: Successfully loaded faiss.
INFO 2025-02-12 10:21:03,868 faiss:173: Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes.
Warning: `bwrap` is not available. Code interpreter tool will not work correctly.
INFO 2025-02-12 10:21:04,315 datasets:54: PyTorch version 2.6.0 available.
INFO 2025-02-12 10:21:04,556 httpx:1740: HTTP Request: GET http://localhost:11434/api/ps "HTTP/1.1 200 OK"
INFO 2025-02-12 10:21:04,557 llama_stack.providers.utils.inference.embedding_mixin:42: Loading sentence transformer for all-MiniLM-L6-v2...
INFO 2025-02-12 10:21:07,202 sentence_transformers.SentenceTransformer:210: Use pytorch device_name: mps
INFO 2025-02-12 10:21:07,202 sentence_transformers.SentenceTransformer:218: Load pretrained SentenceTransformer: all-MiniLM-L6-v2
INFO 2025-02-12 10:21:09,500 llama_stack.distribution.stack:102: Models: all-MiniLM-L6-v2 served by sentence-transformers
INFO 2025-02-12 10:21:09,500 llama_stack.distribution.stack:102: Models: meta-llama/Llama-3.2-3B-Instruct served by ollama
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::equality served by basic
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::regex_parser_multiple_choice_answer served by basic
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::subset_of served by basic
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-correctness served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-relevancy served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-similarity served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-entity-recall served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-precision served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-recall served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-relevancy served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::factuality served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::faithfulness served by braintrust
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: llm-as-judge::405b-simpleqa served by llm-as-judge
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: llm-as-judge::base served by llm-as-judge
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::code_interpreter served by code-interpreter
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::rag served by rag-runtime
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::websearch served by tavily-search
INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:106: 
Serving API eval
 POST /v1/eval/tasks/{task_id}/evaluations
 DELETE /v1/eval/tasks/{task_id}/jobs/{job_id}
 GET /v1/eval/tasks/{task_id}/jobs/{job_id}/result
 GET /v1/eval/tasks/{task_id}/jobs/{job_id}
 POST /v1/eval/tasks/{task_id}/jobs
Serving API agents
 POST /v1/agents
 POST /v1/agents/{agent_id}/session
 POST /v1/agents/{agent_id}/session/{session_id}/turn
 DELETE /v1/agents/{agent_id}
 DELETE /v1/agents/{agent_id}/session/{session_id}
 GET /v1/agents/{agent_id}/session/{session_id}
 GET /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}
 GET /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}
Serving API scoring_functions
 GET /v1/scoring-functions/{scoring_fn_id}
 GET /v1/scoring-functions
 POST /v1/scoring-functions
Serving API safety
 POST /v1/safety/run-shield
Serving API inspect
 GET /v1/health
 GET /v1/inspect/providers
 GET /v1/inspect/routes
 GET /v1/version
Serving API tool_runtime
 POST /v1/tool-runtime/invoke
 GET /v1/tool-runtime/list-tools
 POST /v1/tool-runtime/rag-tool/insert
 POST /v1/tool-runtime/rag-tool/query
Serving API datasetio
 POST /v1/datasetio/rows
 GET /v1/datasetio/rows
Serving API shields
 GET /v1/shields/{identifier}
 GET /v1/shields
 POST /v1/shields
Serving API eval_tasks
 GET /v1/eval-tasks/{eval_task_id}
 GET /v1/eval-tasks
 POST /v1/eval-tasks
Serving API models
 GET /v1/models/{model_id}
 GET /v1/models
 POST /v1/models
 DELETE /v1/models/{model_id}
Serving API datasets
 GET /v1/datasets/{dataset_id}
 GET /v1/datasets
 POST /v1/datasets
 DELETE /v1/datasets/{dataset_id}
Serving API vector_io
 POST /v1/vector-io/insert
 POST /v1/vector-io/query
Serving API inference
 POST /v1/inference/chat-completion
 POST /v1/inference/completion
 POST /v1/inference/embeddings
Serving API tool_groups
 GET /v1/tools/{tool_name}
 GET /v1/toolgroups/{toolgroup_id}
 GET /v1/toolgroups
 GET /v1/tools
 POST /v1/toolgroups
 DELETE /v1/toolgroups/{toolgroup_id}
Serving API vector_dbs
 GET /v1/vector-dbs/{vector_db_id}
 GET /v1/vector-dbs
 POST /v1/vector-dbs
 DELETE /v1/vector-dbs/{vector_db_id}
Serving API scoring
 POST /v1/scoring/score
 POST /v1/scoring/score-batch
Serving API telemetry
 GET /v1/telemetry/traces/{trace_id}/spans/{span_id}
 GET /v1/telemetry/spans/{span_id}/tree
 GET /v1/telemetry/traces/{trace_id}
 POST /v1/telemetry/events
 GET /v1/telemetry/spans
 GET /v1/telemetry/traces
 POST /v1/telemetry/spans/export

Listening on ['::', '0.0.0.0']:5001
INFO:     Started server process [65372]
INFO:     Waiting for application startup.
INFO:     ASGI 'lifespan' protocol appears unsupported.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://['::', '0.0.0.0']:5001 (Press CTRL+C to quit)
^CINFO:     Shutting down
INFO:     Finished server process [65372]
Received signal SIGINT (2). Exiting gracefully...
INFO 2025-02-12 10:21:11,215 __main__:151: Shutting down ModelsRoutingTable
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down InferenceRouter
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ShieldsRoutingTable
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down SafetyRouter
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down VectorDBsRoutingTable
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down VectorIORouter
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ToolGroupsRoutingTable
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ToolRuntimeRouter
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down MetaReferenceAgentsImpl
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DatasetsRoutingTable
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DatasetIORouter
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down TelemetryAdapter
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ScoringFunctionsRoutingTable
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ScoringRouter
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down EvalTasksRoutingTable
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down EvalRouter
INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DistributionInspectImpl
```

[//]: # (## Documentation)
[//]: # (- [ ] Added a Changelog entry if the change is significant)

Signed-off-by: Sébastien Han <seb@redhat.com>
2025-02-13 08:07:59 -08:00

487 lines
17 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import functools
import logging
import inspect
import json
import os
import signal
import sys
import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, List, Union
import yaml
from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,
replace_env_vars,
validate_env_pair,
)
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
start_trace,
)
from .endpoints import get_all_api_endpoints
# Four directory levels above this file; main() resolves template configs
# under REPO_ROOT / "llama_stack" / "templates".
REPO_ROOT = Path(__file__).parent.parent.parent.parent

# Configure the root logger at INFO; each record shows level, timestamp,
# logger name and source line number before the message.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
# Logger named after this module.
logger = logging.getLogger(__name__)
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    """Emit a warning preceded by the current call stack.

    Has the same parameters as ``warnings.showwarning`` so it can be
    installed in its place. Output goes to ``file`` when that object exposes
    a ``write`` attribute, and to ``sys.stderr`` otherwise.
    """
    stream = sys.stderr
    if hasattr(file, "write"):
        stream = file

    # Stack first, then the regular one-line warning text.
    traceback.print_stack(file=stream)
    formatted = warnings.formatwarning(message, category, filename, lineno, line)
    stream.write(formatted)
# Opt-in debugging aid: when LLAMA_STACK_TRACE_WARNINGS is set to a non-empty
# value, every warning is reported together with the stack that triggered it.
if os.getenv("LLAMA_STACK_TRACE_WARNINGS"):
    warnings.showwarning = warn_with_traceback
def create_sse_event(data: Any) -> str:
    """Serialize ``data`` as a single Server-Sent Events ``data:`` frame.

    Pydantic models are serialized with ``model_dump_json()``; anything else
    goes through ``json.dumps``. The frame is terminated by a blank line.
    """
    payload = data.model_dump_json() if isinstance(data, BaseModel) else json.dumps(data)
    return f"data: {payload}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
    """Turn an unhandled exception into a JSON error response.

    The traceback is printed, the exception is mapped to an HTTP error by
    ``translate_exception``, and that error's status code and detail are
    returned as ``{"error": {"detail": ...}}``.
    """
    traceback.print_exception(exc)
    translated = translate_exception(exc)
    body = {"error": {"detail": translated.detail}}
    return JSONResponse(status_code=translated.status_code, content=body)
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
    """Map an arbitrary exception onto the HTTPException the API should return.

    Mapping:
        - pydantic ``ValidationError`` / ``RequestValidationError`` -> 400 with
          a list of ``loc`` / ``msg`` / ``type`` entries
        - ``ValueError`` -> 400
        - ``PermissionError`` -> 403
        - ``TimeoutError`` (including ``asyncio.TimeoutError``) -> 504
        - ``NotImplementedError`` -> 501
        - anything else -> 500 with a generic message (the original text is
          deliberately not echoed back)

    Args:
        exc: The exception raised while serving a request.

    Returns:
        An ``HTTPException`` carrying the status code and detail to send.
    """
    if isinstance(exc, ValidationError):
        # pydantic v2's ValidationError has no `raw_errors` attribute (that was
        # the v1 API); `errors()` is the supported accessor and yields dicts
        # with the "loc" / "msg" / "type" keys consumed below.
        exc = RequestValidationError(exc.errors())

    if isinstance(exc, RequestValidationError):
        return HTTPException(
            status_code=400,
            detail={
                "errors": [
                    {
                        "loc": list(error["loc"]),
                        "msg": error["msg"],
                        "type": error["type"],
                    }
                    for error in exc.errors()
                ]
            },
        )
    elif isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
    elif isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
    elif isinstance(exc, (TimeoutError, asyncio.TimeoutError)):
        # Before Python 3.11 asyncio.TimeoutError is a separate class from the
        # builtin TimeoutError, so it has to be listed explicitly.
        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
    elif isinstance(exc, NotImplementedError):
        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
    else:
        return HTTPException(
            status_code=500,
            detail="Internal server error: An unexpected error occurred.",
        )
# Strong references to in-flight shutdown tasks: the event loop only keeps weak
# references to tasks, so an otherwise unreferenced task may be collected
# before it finishes.
_shutdown_tasks: set = set()


def handle_signal(app, signum, _) -> None:
    """
    Handle incoming signals and initiate a graceful shutdown of the application.

    This function is intended to be used as a signal handler for various signals
    (e.g., SIGINT, SIGTERM). Upon receiving a signal, it prints a message
    naming the received signal and schedules the shutdown coroutine on the
    running event loop.

    Args:
        app: The application instance; its ``__llama_stack_impls__`` mapping
            holds the implementations to shut down (treated as empty if the
            attribute has not been set yet).
        signum (int): The signal number received.
        _: The current stack frame (not used in this function).

    Raises:
        RuntimeError: If there is no running event loop in the current thread.

    The shutdown process involves:
        - Shutting down all implementations registered in the application
          (5 second timeout each; failures are logged and skipped).
        - Gathering all running asyncio tasks.
        - Cancelling all gathered tasks.
        - Waiting up to 10 seconds for all tasks to finish.
        - Stopping the event loop.

    Note:
        This function schedules the shutdown process as an asyncio task and does
        not block the current execution.
    """
    signame = signal.Signals(signum).name
    print(f"Received signal {signame} ({signum}). Exiting gracefully...")

    async def shutdown():
        # Resolve the loop before entering the try block so that the `finally`
        # clause can stop it no matter where a failure or cancellation occurs.
        loop = asyncio.get_running_loop()
        try:
            # Gracefully shut down implementations
            impls = getattr(app, "__llama_stack_impls__", {})
            for impl in impls.values():
                impl_name = impl.__class__.__name__
                logger.info("Shutting down %s", impl_name)
                try:
                    if hasattr(impl, "shutdown"):
                        await asyncio.wait_for(impl.shutdown(), timeout=5)
                    else:
                        logger.warning("No shutdown method for %s", impl_name)
                except asyncio.TimeoutError:
                    logger.exception("Shutdown timeout for %s ", impl_name)
                except Exception as e:
                    logger.exception("Failed to shutdown %s: %s", impl_name, e)

            # Gather all running tasks except this one
            current = asyncio.current_task()
            tasks = [task for task in asyncio.all_tasks(loop) if task is not current]

            # Cancel all tasks
            for task in tasks:
                task.cancel()

            # Wait for all tasks to finish
            try:
                await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
            except asyncio.TimeoutError:
                logger.exception("Timeout while waiting for tasks to finish")
        except asyncio.CancelledError:
            # This task was itself cancelled (e.g. by a second signal); still stop the loop.
            pass
        finally:
            loop.stop()

    loop = asyncio.get_running_loop()
    task = loop.create_task(shutdown())
    _shutdown_tasks.add(task)
    task.add_done_callback(_shutdown_tasks.discard)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: announce startup, then shut implementations down.

    After the application finishes, every implementation stored in
    ``app.__llama_stack_impls__`` that defines ``shutdown()`` is awaited in
    turn. Exceptions raised by a ``shutdown()`` call propagate to the caller.

    Args:
        app: The application whose ``__llama_stack_impls__`` mapping holds the
            implementations to shut down.
    """
    print("Starting up")
    yield
    print("Shutting down")
    for impl in app.__llama_stack_impls__.values():
        # Not every implementation defines shutdown() (handle_signal guards for
        # the same case); skip those instead of aborting the remaining ones
        # with an AttributeError.
        shutdown = getattr(impl, "shutdown", None)
        if shutdown is None:
            continue
        await shutdown()
def is_streaming_request(func_name: str, request: Request, **kwargs):
    """Return the caller-supplied ``stream`` argument, or ``False`` if absent.

    ``func_name`` and ``request`` are accepted but not consulted.
    """
    # TODO: pass the api method and punt it to the Protocol definition directly
    if "stream" in kwargs:
        return kwargs["stream"]
    return False
async def maybe_await(value):
    """Return ``value``, awaiting it first when it is a coroutine object."""
    if not inspect.iscoroutine(value):
        return value
    return await value
async def sse_generator(event_gen):
    """Adapt an awaitable that resolves to an async iterator into SSE frames.

    Args:
        event_gen: An awaitable whose result is asynchronously iterable; each
            item it produces is emitted as one SSE ``data:`` frame.

    Yields:
        SSE-formatted strings. If an exception other than cancellation is
        raised, its traceback is printed and one final frame carrying an
        ``error.message`` is yielded.

    On cancellation the resolved iterator (if any) is closed and the generator
    ends without re-raising.
    """
    stream = None
    try:
        stream = await event_gen
        async for item in stream:
            yield create_sse_event(item)
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        print("Generator cancelled")
        # Cancellation can arrive while `event_gen` is still being awaited; in
        # that case there is no resolved iterator yet and nothing to close.
        if stream is not None:
            await stream.aclose()
    except Exception as e:
        traceback.print_exception(e)
        yield create_sse_event(
            {
                "error": {
                    "message": str(translate_exception(e)),
                },
            }
        )
def create_dynamic_typed_route(func: Any, method: str, route: str):
    """Build an endpoint coroutine that forwards to ``func``.

    The returned ``endpoint`` takes a ``request`` plus ``func``'s own
    parameters. Its ``__signature__`` is rewritten to advertise exactly that,
    and for ``post`` routes each of ``func``'s parameters is annotated as a
    path parameter (when its name appears in ``route``) or as an embedded
    body field (otherwise).

    Args:
        func: The implementation method to expose.
        method: HTTP method name as used by the router (e.g. ``"post"``).
        route: Route template, e.g. ``/v1/models/{model_id}``.

    Returns:
        The endpoint coroutine function.
    """

    async def endpoint(request: Request, **kwargs):
        set_request_provider_data(request.headers)

        streaming = is_streaming_request(func.__name__, request, **kwargs)
        try:
            if not streaming:
                return await maybe_await(func(**kwargs))
            return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
        except Exception as e:
            traceback.print_exception(e)
            raise translate_exception(e) from e

    func_sig = inspect.signature(func)
    request_param = inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)
    func_params = list(func_sig.parameters.values())
    path_params = extract_path_params(route)

    if method == "post":
        # Annotate parameters that are in the path with Path(...) and others with Body(...)
        annotated = []
        for param in func_params:
            if param.name in path_params:
                marker = FastapiPath(..., title=param.name)
            else:
                marker = Body(..., embed=True)
            annotated.append(param.replace(annotation=Annotated[param.annotation, marker]))
        func_params = annotated

    endpoint.__signature__ = func_sig.replace(parameters=[request_param, *func_params])
    return endpoint
class TracingMiddleware:
    """ASGI middleware that wraps each call to the inner app in a trace.

    A trace named after the request path is started before the inner
    application runs and is always ended afterwards, even if the application
    raises.
    """

    def __init__(self, app):
        # The wrapped ASGI application.
        self.app = app

    async def __call__(self, scope, receive, send):
        # Lifespan scopes carry no "path" key; indexing it raised KeyError and
        # made the server treat the lifespan protocol as unsupported. There is
        # no request to trace in that case, so pass straight through.
        if scope.get("type") == "lifespan":
            return await self.app(scope, receive, send)

        path = scope.get("path", "")
        await start_trace(path, {"__location__": "server"})
        try:
            return await self.app(scope, receive, send)
        finally:
            await end_trace()
class ClientVersionMiddleware:
    """ASGI middleware rejecting HTTP clients whose major.minor version differs.

    The client advertises its version in the ``x-llamastack-client-version``
    header. Only the first two dot-separated components are compared with the
    installed ``llama-stack`` version; a mismatch is answered with HTTP 426
    and a JSON error body. Requests without the header, with an unparseable
    version, or of a non-HTTP scope type are passed through unchanged.
    """

    def __init__(self, app):
        # The wrapped ASGI application.
        self.app = app
        # Installed distribution version string of "llama-stack".
        self.server_version = parse_version("llama-stack")

    async def _send_version_error(self, send, client_version):
        """Send the 426 response describing the version mismatch."""
        await send(
            {
                "type": "http.response.start",
                "status": 426,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps(
            {
                "error": {
                    "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
                }
            }
        ).encode()
        await send({"type": "http.response.body", "body": error_msg})

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            headers = dict(scope.get("headers", []))
            raw_version = headers.get(b"x-llamastack-client-version", b"")
            if raw_version:
                try:
                    # Decoding happens inside the try block: UnicodeDecodeError
                    # is a ValueError, so an undecodable header is treated like
                    # any other unparseable version instead of crashing.
                    client_version = raw_version.decode()
                    client_version_parts = tuple(map(int, client_version.split(".")[:2]))
                    server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
                except (ValueError, IndexError):
                    # If version parsing fails, let the request through
                    pass
                else:
                    if client_version_parts != server_version_parts:
                        return await self._send_version_error(send, client_version)

        return await self.app(scope, receive, send)
def main():
    """Start the LlamaStack server.

    Parses the command line, loads the run configuration (from --yaml-config
    or a named --template), builds the provider stack, registers one route
    per API endpoint on a FastAPI app, installs signal and exception
    handlers, and finally hands the app to uvicorn.

    Raises:
        ValueError: If neither --yaml-config nor --template is given, if the
            selected config file does not exist, or if an implementation
            lacks a method for one of its API's endpoints.
        SystemExit: With status 1 for a malformed --env pair or when
            constructing the stack raises InvalidProviderError.
    """
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
    parser.add_argument(
        "--yaml-config",
        help="Path to YAML configuration file",
    )
    parser.add_argument(
        "--template",
        help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
    )
    # NOTE(review): this default is always an int, so the `config.server.port`
    # fallback further down is only reached when the port evaluates to 0 —
    # confirm whether the config value is meant to apply when --port is omitted.
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
        help="Port to listen on",
    )
    parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
    parser.add_argument(
        "--env",
        action="append",
        help="Environment variables in KEY=value format. Can be specified multiple times.",
    )
    # Each TLS flag becomes required as soon as the other one appears on the
    # command line, so they can only be passed as a pair.
    parser.add_argument(
        "--tls-keyfile",
        help="Path to TLS key file for HTTPS",
        required="--tls-certfile" in sys.argv,
    )
    parser.add_argument(
        "--tls-certfile",
        help="Path to TLS certificate file for HTTPS",
        required="--tls-keyfile" in sys.argv,
    )

    args = parser.parse_args()

    # Apply --env pairs to the process environment before the config is read.
    # NOTE(review): the value is printed verbatim — confirm that is acceptable
    # for secrets passed this way.
    if args.env:
        for env_pair in args.env:
            try:
                key, value = validate_env_pair(env_pair)
                print(f"Setting CLI environment variable {key} => {value}")
                os.environ[key] = value
            except ValueError as e:
                print(f"Error: {str(e)}")
                sys.exit(1)

    if args.yaml_config:
        # if the user provided a config file, use it, even if template was specified
        config_file = Path(args.yaml_config)
        if not config_file.exists():
            raise ValueError(f"Config file {config_file} does not exist")
        print(f"Using config file: {config_file}")
    elif args.template:
        config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
        if not config_file.exists():
            raise ValueError(f"Template {args.template} does not exist")
        print(f"Using template {args.template} config file: {config_file}")
    else:
        raise ValueError("Either --yaml-config or --template must be provided")

    # Parse the YAML, pass it through replace_env_vars, then build the typed config.
    with open(config_file, "r") as fp:
        config = replace_env_vars(yaml.safe_load(fp))
        config = StackRunConfig(**config)

    # Echo the configuration after passing it through redact_sensitive_fields.
    print("Run configuration:")
    safe_config = redact_sensitive_fields(config.model_dump())
    print(yaml.dump(safe_config, indent=2))

    app = FastAPI(lifespan=lifespan)
    app.add_middleware(TracingMiddleware)
    # The client/server version check can be switched off via the environment.
    if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
        app.add_middleware(ClientVersionMiddleware)

    # Build the implementations in a temporary event loop, before uvicorn starts its own.
    try:
        impls = asyncio.run(construct_stack(config))
    except InvalidProviderError:
        sys.exit(1)

    # Use the configured telemetry implementation if there is one, else a default adapter.
    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
    else:
        setup_logger(TelemetryAdapter(TelemetryConfig()))

    all_endpoints = get_all_api_endpoints()

    # Serve the APIs listed in the config, or every constructed implementation if none are listed.
    if config.apis:
        apis_to_serve = set(config.apis)
    else:
        apis_to_serve = set(impls.keys())

    for inf in builtin_automatically_routed_apis():
        # if we do not serve the corresponding router API, we should not serve the routing table API
        if inf.router_api.value not in apis_to_serve:
            continue
        apis_to_serve.add(inf.routing_table_api.value)

    # The inspect API is always served.
    apis_to_serve.add("inspect")
    for api_str in apis_to_serve:
        api = Api(api_str)

        endpoints = all_endpoints[api]
        impl = impls[api]

        for endpoint in endpoints:
            if not hasattr(impl, endpoint.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")

            impl_method = getattr(impl, endpoint.name)

            with warnings.catch_warnings():
                # Silence pydantic field UserWarnings raised while registering the route.
                warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
                # Equivalent to decorating the generated endpoint with
                # app.<method>(route, response_model=None).
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(
                        impl_method,
                        endpoint.method,
                        endpoint.route,
                    )
                )

        cprint(f"Serving API {api_str}", "white", attrs=["bold"])
        for endpoint in endpoints:
            cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")

    print("")
    # Validation errors and all other exceptions share one handler.
    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)
    # Bind the app as the first argument of the signal handler.
    signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
    signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))

    # Read back by handle_signal and lifespan when shutting down.
    app.__llama_stack_impls__ = impls

    import uvicorn

    # Configure SSL if certificates are provided
    port = args.port or config.server.port

    ssl_config = None
    # Command-line TLS files take precedence over the ones from the config.
    if args.tls_keyfile:
        keyfile = args.tls_keyfile
        certfile = args.tls_certfile
    else:
        keyfile = config.server.tls_keyfile
        certfile = config.server.tls_certfile

    # HTTPS is only enabled when both a key and a certificate are available.
    if keyfile and certfile:
        ssl_config = {
            "ssl_keyfile": keyfile,
            "ssl_certfile": certfile,
        }
        print(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

    # Listen on both the IPv6 and IPv4 wildcard addresses unless IPv6 is disabled.
    listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")

    uvicorn_config = {
        "app": app,
        "host": listen_host,
        "port": port,
    }
    if ssl_config:
        uvicorn_config.update(ssl_config)

    uvicorn.run(**uvicorn_config)
def extract_path_params(route: str) -> List[str]:
    """Return the names of the ``{placeholder}`` segments of ``route``, in order.

    Only segments that are entirely a placeholder count: ``/a/{x}/b`` yields
    ``["x"]`` while ``/a/{x}b`` yields ``[]``.
    """
    names = []
    for segment in route.split("/"):
        if not (segment.startswith("{") and segment.endswith("}")):
            continue
        # Drop the surrounding braces.
        names.append(segment[1:-1])
    return names
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()