refactor: Add ProviderContext for a flexible storage directory

- Introduce ProviderContext class to decouple provider storage paths from absolute paths
- Add storage_dir attribute to StackRunConfig to accept CLI options
- Implement storage directory resolution with prioritized fallbacks:
  1. CLI option (--state-directory)
  2. Environment variable (LLAMA_STACK_STATE_DIR)
  3. Default distribution directory
- Standardize provider signatures to follow context, config, deps pattern
- Update provider implementations to use the new context-based approach
- Add comprehensive tests to verify state directory resolution
This commit is contained in:
Roland Huß 2025-05-12 11:44:21 +02:00
parent dd07c7a5b5
commit e6c9aebe47
41 changed files with 242 additions and 81 deletions

View file

@ -5,6 +5,8 @@
# the root directory of this source tree.
import importlib
import inspect
import os
from pathlib import Path
from typing import Any
from llama_stack.apis.agents import Agents
@ -42,6 +44,7 @@ from llama_stack.providers.datatypes import (
BenchmarksProtocolPrivate,
DatasetsProtocolPrivate,
ModelsProtocolPrivate,
ProviderContext,
ProviderSpec,
RemoteProviderConfig,
RemoteProviderSpec,
@ -334,7 +337,15 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
args = [config, deps]
# Build ProviderContext for every provider
distro_name = (
dist_registry.run_config.image_name
if dist_registry and hasattr(dist_registry, "run_config")
else provider.spec.api.value
)
storage_dir = resolve_storage_dir(config, distro_name)
context = ProviderContext(storage_dir=storage_dir)
args = [context, config, deps]
fn = getattr(module, method)
impl = await fn(*args)
@ -413,3 +424,31 @@ async def resolve_remote_stack_impls(
)
return impls
def resolve_storage_dir(config, distro_name: str) -> Path:
"""
Resolves the storage directory for a provider in the following order of precedence:
1. CLI option (config.storage_dir)
2. Environment variable (LLAMA_STACK_STORAGE_DIR)
3. Fallback to <DISTRIBS_BASE_DIR>/<distro_name>
Args:
config: Provider configuration object
distro_name: Distribution name used for the fallback path
Returns:
Path: Resolved storage directory path
"""
# Import here to avoid circular imports
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
# 1. CLI option
storage_dir = getattr(config, "storage_dir", None)
# 2. Environment variable
if not storage_dir:
storage_dir = os.environ.get("LLAMA_STACK_STORAGE_DIR")
# 3. Fallback
if not storage_dir:
storage_dir = str(DISTRIBS_BASE_DIR / distro_name)
return Path(storage_dir)