mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 04:21:58 +00:00
refactor: Add ProviderContext for a flexible storage directory
- Introduce ProviderContext class to decouple provider storage paths from absolute paths - Add storage_dir attribute to StackRunConfig to accept CLI options - Implement storage directory resolution with prioritized fallbacks: 1. CLI option (--state-directory) 2. Environment variable (LLAMA_STACK_STATE_DIR) 3. Default distribution directory - Standardize provider signatures to follow context, config, deps pattern - Update provider implementations to use the new context-based approach - Add comprehensive tests to verify state directory resolution
This commit is contained in:
parent
dd07c7a5b5
commit
e6c9aebe47
41 changed files with 242 additions and 81 deletions
|
|
@ -317,6 +317,11 @@ a default SQLite store will be used.""",
|
|||
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
storage_dir: str | None = Field(
|
||||
default=None,
|
||||
description="Directory to use for provider state. Can be set by CLI, environment variable and default to the distribution directory",
|
||||
)
|
||||
|
||||
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class DistributionInspectConfig(BaseModel):
|
|||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
async def get_provider_impl(context, config, deps):
|
||||
impl = DistributionInspectImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class ProviderImplConfig(BaseModel):
|
|||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config, deps):
|
||||
async def get_provider_impl(context, config, deps):
|
||||
impl = ProviderImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
|
|
@ -42,6 +44,7 @@ from llama_stack.providers.datatypes import (
|
|||
BenchmarksProtocolPrivate,
|
||||
DatasetsProtocolPrivate,
|
||||
ModelsProtocolPrivate,
|
||||
ProviderContext,
|
||||
ProviderSpec,
|
||||
RemoteProviderConfig,
|
||||
RemoteProviderSpec,
|
||||
|
|
@ -334,7 +337,15 @@ async def instantiate_provider(
|
|||
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider.config)
|
||||
args = [config, deps]
|
||||
# Build ProviderContext for every provider
|
||||
distro_name = (
|
||||
dist_registry.run_config.image_name
|
||||
if dist_registry and hasattr(dist_registry, "run_config")
|
||||
else provider.spec.api.value
|
||||
)
|
||||
storage_dir = resolve_storage_dir(config, distro_name)
|
||||
context = ProviderContext(storage_dir=storage_dir)
|
||||
args = [context, config, deps]
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
|
|
@ -413,3 +424,31 @@ async def resolve_remote_stack_impls(
|
|||
)
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
def resolve_storage_dir(config, distro_name: str) -> Path:
|
||||
"""
|
||||
Resolves the storage directory for a provider in the following order of precedence:
|
||||
1. CLI option (config.storage_dir)
|
||||
2. Environment variable (LLAMA_STACK_STORAGE_DIR)
|
||||
3. Fallback to <DISTRIBS_BASE_DIR>/<distro_name>
|
||||
|
||||
Args:
|
||||
config: Provider configuration object
|
||||
distro_name: Distribution name used for the fallback path
|
||||
|
||||
Returns:
|
||||
Path: Resolved storage directory path
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
# 1. CLI option
|
||||
storage_dir = getattr(config, "storage_dir", None)
|
||||
# 2. Environment variable
|
||||
if not storage_dir:
|
||||
storage_dir = os.environ.get("LLAMA_STACK_STORAGE_DIR")
|
||||
# 3. Fallback
|
||||
if not storage_dir:
|
||||
storage_dir = str(DISTRIBS_BASE_DIR / distro_name)
|
||||
return Path(storage_dir)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from importlib.metadata import version as parse_version
|
|||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Path as FastapiPath
|
||||
|
|
@ -33,7 +32,7 @@ from llama_stack.distribution.request_headers import (
|
|||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.resolver import InvalidProviderError, resolve_storage_dir
|
||||
from llama_stack.distribution.server.endpoints import (
|
||||
find_matching_endpoint,
|
||||
initialize_endpoint_impls,
|
||||
|
|
@ -46,7 +45,7 @@ from llama_stack.distribution.stack import (
|
|||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.datatypes import Api, ProviderContext
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
TelemetryAdapter,
|
||||
|
|
@ -188,30 +187,11 @@ async def sse_generator(event_gen_coroutine):
|
|||
)
|
||||
|
||||
|
||||
async def log_request_pre_validation(request: Request):
|
||||
if request.method in ("POST", "PUT", "PATCH"):
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
if body_bytes:
|
||||
try:
|
||||
parsed_body = json.loads(body_bytes.decode())
|
||||
log_output = rich.pretty.pretty_repr(parsed_body)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
log_output = repr(body_bytes)
|
||||
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
|
||||
else:
|
||||
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
|
@ -442,7 +422,10 @@ def main(args: argparse.Namespace | None = None):
|
|||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
# Resolve storage directory using the same logic as other providers
|
||||
storage_dir = resolve_storage_dir(config, config.image_name)
|
||||
context = ProviderContext(storage_dir=storage_dir)
|
||||
setup_logger(TelemetryAdapter(context, TelemetryConfig(), {}))
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue