refactor: Add ProviderContext for a flexible storage directory

- Introduce ProviderContext class to decouple provider storage paths from absolute paths
- Add storage_dir attribute to StackRunConfig to accept CLI options
- Implement storage directory resolution with prioritized fallbacks:
  1. CLI option (--state-directory)
  2. Environment variable (LLAMA_STACK_STATE_DIR)
  3. Default distribution directory
- Standardize provider signatures to follow context, config, deps pattern
- Update provider implementations to use the new context-based approach
- Add comprehensive tests to verify state directory resolution
This commit is contained in:
Roland Huß 2025-05-12 11:44:21 +02:00
parent dd07c7a5b5
commit e6c9aebe47
41 changed files with 242 additions and 81 deletions

View file

@ -18,7 +18,6 @@ from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any
import rich.pretty
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
@ -33,7 +32,7 @@ from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.resolver import InvalidProviderError, resolve_storage_dir
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
@ -46,7 +45,7 @@ from llama_stack.distribution.stack import (
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.datatypes import Api, ProviderContext
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
@ -188,30 +187,11 @@ async def sse_generator(event_gen_coroutine):
)
async def log_request_pre_validation(request: Request):
if request.method in ("POST", "PUT", "PATCH"):
try:
body_bytes = await request.body()
if body_bytes:
try:
parsed_body = json.loads(body_bytes.decode())
log_output = rich.pretty.pretty_repr(parsed_body)
except (json.JSONDecodeError, UnicodeDecodeError):
log_output = repr(body_bytes)
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
else:
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
except Exception as e:
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
await log_request_pre_validation(request)
# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@ -442,7 +422,10 @@ def main(args: argparse.Namespace | None = None):
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
# Resolve storage directory using the same logic as other providers
storage_dir = resolve_storage_dir(config, config.image_name)
context = ProviderContext(storage_dir=storage_dir)
setup_logger(TelemetryAdapter(context, TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints()