From e6c9aebe473a7ce9a56a38f0f98da7043830fb5b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Roland=20Hu=C3=9F?=
Date: Mon, 12 May 2025 11:44:21 +0200
Subject: [PATCH] refactor: Add ProviderContext for a flexible storage directory

- Introduce ProviderContext class to decouple provider storage paths from absolute paths
- Add storage_dir attribute to StackRunConfig to accept CLI options
- Implement storage directory resolution with prioritized fallbacks:
  1. CLI option (--storage-directory)
  2. Environment variable (LLAMA_STACK_STORAGE_DIR)
  3. Default distribution directory
- Standardize provider signatures to follow context, config, deps pattern
- Update provider implementations to use the new context-based approach
- Add comprehensive tests to verify storage directory resolution
---
 llama_stack/cli/stack/run.py                  |  8 ++++
 llama_stack/distribution/datatypes.py         |  5 +++
 llama_stack/distribution/inspect.py           |  2 +-
 llama_stack/distribution/providers.py         |  2 +-
 llama_stack/distribution/resolver.py          | 41 +++++++++++++++++-
 llama_stack/distribution/server/server.py     | 29 +++----------
 llama_stack/providers/datatypes.py            | 20 ++++++++-
 .../inline/agents/meta_reference/__init__.py  |  4 +-
 .../inline/agents/meta_reference/agents.py    |  3 ++
 .../inline/datasetio/localfs/__init__.py      |  5 ++-
 .../inline/datasetio/localfs/datasetio.py     |  5 ++-
 .../inline/eval/meta_reference/__init__.py    |  2 +
 .../inference/meta_reference/__init__.py      |  9 ++--
 .../inference/meta_reference/inference.py     |  5 ++-
 .../sentence_transformers/__init__.py         |  2 +
 .../inline/inference/vllm/__init__.py         |  6 ++-
 .../post_training/torchtune/__init__.py       |  3 ++
 .../post_training/torchtune/post_training.py  |  3 ++
 .../inline/safety/code_scanner/__init__.py    |  6 ++-
 .../safety/code_scanner/code_scanner.py       |  5 ++-
 .../inline/safety/llama_guard/__init__.py     |  6 ++-
 .../inline/safety/llama_guard/llama_guard.py  |  5 ++-
 .../inline/safety/prompt_guard/__init__.py    |  6 ++-
 .../safety/prompt_guard/prompt_guard.py       |  6 ++-
 .../inline/scoring/basic/__init__.py          |  3 ++
 .../providers/inline/scoring/basic/scoring.py |  4 +-
 .../inline/scoring/braintrust/__init__.py     |  2 +
 .../inline/scoring/llm_as_judge/__init__.py   |  2 +
 .../telemetry/meta_reference/__init__.py      |  5 ++-
 .../telemetry/meta_reference/telemetry.py     |  7 ++-
 .../inline/tool_runtime/rag/__init__.py       |  4 +-
 .../inline/vector_io/chroma/__init__.py       |  5 ++-
 .../inline/vector_io/faiss/__init__.py        |  6 +--
 .../providers/inline/vector_io/faiss/faiss.py |  6 ++-
 .../inline/vector_io/milvus/__init__.py       |  5 ++-
 .../inline/vector_io/sqlite_vec/__init__.py   |  6 +--
 .../inline/vector_io/sqlite_vec/sqlite_vec.py | 25 +++++++----
 .../remote/inference/vllm/__init__.py         |  2 +-
 .../agent/test_meta_reference_agent.py        |  4 ++
 .../providers/inference/test_remote_vllm.py   |  6 ++-
 tests/unit/test_state_dir_resolution.py       | 43 +++++++++++++++++++
 41 files changed, 242 insertions(+), 81 deletions(-)
 create mode 100644 tests/unit/test_state_dir_resolution.py

diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py
index f3a6a9865..c7be5a86e 100644
--- a/llama_stack/cli/stack/run.py
+++ b/llama_stack/cli/stack/run.py
@@ -59,6 +59,11 @@ class StackRun(Subcommand):
             help="Image Type used during the build. This can be either conda or container or venv.",
             choices=[e.value for e in ImageType],
         )
+        self.parser.add_argument(
+            "--storage-directory",
+            type=str,
+            help="Directory to use for provider state (overrides environment variable and default).",
+        )

        # If neither image type nor image name is provided, but at the same time
        # the current environment has conda breadcrumbs, then assume what the user
@@ -118,6 +123,9 @@ class StackRun(Subcommand):
         except AttributeError as e:
             self.parser.error(f"failed to parse config file '{config_file}':\n {e}")

+        # Pass the CLI storage directory option to run_config for resolver use
+        config.storage_dir = args.storage_directory
+
         image_type, image_name = self._get_image_type_and_name(args)

         # If neither image type nor image name is provided, assume the server should be run directly
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index d36e21c6d..36ab4677f 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -317,6 +317,11 @@ a default SQLite store will be used.""",
         description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
     )

+    storage_dir: str | None = Field(
+        default=None,
+        description="Directory to use for provider state. Can be set via the CLI option or the LLAMA_STACK_STORAGE_DIR environment variable; defaults to the distribution directory.",
+    )
+

 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py
index 23f644ec6..e45cc7dd7 100644
--- a/llama_stack/distribution/inspect.py
+++ b/llama_stack/distribution/inspect.py
@@ -24,7 +24,7 @@ class DistributionInspectConfig(BaseModel):
     run_config: StackRunConfig


-async def get_provider_impl(config, deps):
+async def get_provider_impl(context, config, deps):
     impl = DistributionInspectImpl(config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py
index 29b7109dd..ad7aac5e8 100644
--- a/llama_stack/distribution/providers.py
+++ b/llama_stack/distribution/providers.py
@@ -23,7 +23,7 @@ class ProviderImplConfig(BaseModel):
     run_config: StackRunConfig


-async def get_provider_impl(config, deps):
+async def get_provider_impl(context, config, deps):
     impl = ProviderImpl(config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 37588ea64..f137f66e6 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 import importlib
 import inspect
+import os
+from pathlib import Path
 from typing import Any

 from llama_stack.apis.agents import Agents
@@ -42,6 +44,7 @@ from llama_stack.providers.datatypes import (
     BenchmarksProtocolPrivate,
     DatasetsProtocolPrivate,
     ModelsProtocolPrivate,
+    ProviderContext,
     ProviderSpec,
     RemoteProviderConfig,
     RemoteProviderSpec,
@@ -334,7 +337,15 @@ async def instantiate_provider(
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider.config)

-        args = [config, deps]
+        # Build ProviderContext for every provider
+        distro_name = (
+            dist_registry.run_config.image_name
+            if dist_registry and hasattr(dist_registry, "run_config")
+            else provider.spec.api.value
+        )
+        storage_dir = resolve_storage_dir(config, distro_name)
+        context = ProviderContext(storage_dir=storage_dir)
+        args = [context, config, deps]

     fn = getattr(module, method)
     impl = await fn(*args)
@@ -413,3 +424,31 @@ async def resolve_remote_stack_impls(
     )

     return impls
+
+
+def resolve_storage_dir(config, distro_name: str) -> Path:
+    """
+    Resolves the storage directory for a provider in the following order of precedence:
+    1. CLI option (config.storage_dir)
+    2. Environment variable (LLAMA_STACK_STORAGE_DIR)
+    3. Fallback to DISTRIBS_BASE_DIR / distro_name (the default distribution directory)
+
+    Args:
+        config: Provider configuration object
+        distro_name: Distribution name used for the fallback path
+
+    Returns:
+        Path: Resolved storage directory path
+    """
+    # Import here to avoid circular imports
+    from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+
+    # 1. CLI option
+    storage_dir = getattr(config, "storage_dir", None)
+    # 2. Environment variable
+    if not storage_dir:
+        storage_dir = os.environ.get("LLAMA_STACK_STORAGE_DIR")
+    # 3. Fallback
+    if not storage_dir:
+        storage_dir = str(DISTRIBS_BASE_DIR / distro_name)
+    return Path(storage_dir)
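The precedence implemented by resolve_storage_dir can be exercised directly. A minimal sketch, assuming only that the function is importable as above; DummyConfig and all paths are illustrative, not part of the patch:

    # Illustration of the three-level precedence; DummyConfig and the paths
    # here are made up for the example.
    import os
    from pathlib import Path

    from llama_stack.distribution.resolver import resolve_storage_dir

    class DummyConfig:
        storage_dir = "/from/cli"

    os.environ["LLAMA_STACK_STORAGE_DIR"] = "/from/env"

    # 1. An explicit storage_dir on the config wins over the environment.
    assert resolve_storage_dir(DummyConfig(), "my-distro") == Path("/from/cli")

    # 2. Without it, the environment variable is used.
    assert resolve_storage_dir(object(), "my-distro") == Path("/from/env")

    # 3. With neither, the result falls back to DISTRIBS_BASE_DIR / "my-distro".
    del os.environ["LLAMA_STACK_STORAGE_DIR"]
    print(resolve_storage_dir(object(), "my-distro"))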
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index f4d323607..202a1c866 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -18,7 +18,6 @@ from importlib.metadata import version as parse_version
 from pathlib import Path
 from typing import Annotated, Any

-import rich.pretty
 import yaml
 from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi import Path as FastapiPath
@@ -33,7 +32,7 @@ from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
-from llama_stack.distribution.resolver import InvalidProviderError
+from llama_stack.distribution.resolver import InvalidProviderError, resolve_storage_dir
 from llama_stack.distribution.server.endpoints import (
     find_matching_endpoint,
     initialize_endpoint_impls,
@@ -46,7 +45,7 @@ from llama_stack.distribution.stack import (
 from llama_stack.distribution.utils.config import redact_sensitive_fields
 from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
 from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
     TelemetryAdapter,
@@ -188,30 +187,11 @@ async def sse_generator(event_gen_coroutine):
         )


-async def log_request_pre_validation(request: Request):
-    if request.method in ("POST", "PUT", "PATCH"):
-        try:
-            body_bytes = await request.body()
-            if body_bytes:
-                try:
-                    parsed_body = json.loads(body_bytes.decode())
-                    log_output = rich.pretty.pretty_repr(parsed_body)
-                except (json.JSONDecodeError, UnicodeDecodeError):
-                    log_output = repr(body_bytes)
-                logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
-            else:
-                logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
-        except Exception as e:
-            logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
-
-
 def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
         # Get auth attributes from the request scope
         user_attributes = request.scope.get("user_attributes", {})

-        await log_request_pre_validation(request)
-
         # Use context manager with both provider data and auth attributes
         with request_provider_data_context(request.headers, user_attributes):
             is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@@ -442,7 +422,10 @@ def main(args: argparse.Namespace | None = None):

     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
     else:
-        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
+        # Resolve storage directory using the same logic as other providers
+        storage_dir = resolve_storage_dir(config, config.image_name)
+        context = ProviderContext(storage_dir=storage_dir)
+        setup_logger(TelemetryAdapter(context, TelemetryConfig(), {}))

     all_endpoints = get_all_api_endpoints()
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 3e9806f23..876e374c6 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -4,7 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from dataclasses import dataclass
 from enum import Enum
+from pathlib import Path
 from typing import Any, Protocol
 from urllib.parse import urlparse

@@ -161,7 +163,7 @@ If a provider depends on other providers, the dependencies MUST NOT specify a co
     description="""
 Fully-qualified name of the module to import. The module is expected to have:

- - `get_provider_impl(config, deps)`: returns the local implementation
+ - `get_provider_impl(context, config, deps)`: returns the local implementation
 """,
     )
     provider_data_validator: str | None = Field(
@@ -232,3 +234,19 @@ class HealthStatus(str, Enum):


 HealthResponse = dict[str, Any]
+
+
+@dataclass
+class ProviderContext:
+    """
+    Runtime context for provider instantiation.
+
+    This object is constructed by the Llama Stack runtime and injected into every provider.
+    It contains environment- and deployment-specific information that should not be part of the static config file.
+
+    Attributes:
+        storage_dir (Path): Directory for provider state (persistent or ephemeral),
+            resolved from CLI option, environment variable, or default distribution directory.
+    """
+
+    storage_dir: Path
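For reference, the smallest provider module satisfying the updated get_provider_impl(context, config, deps) contract might look like the sketch below; MyConfig and MyImpl are hypothetical stand-ins, not code from this patch:

    # Hypothetical provider module illustrating the (context, config, deps) contract.
    from typing import Any

    from llama_stack.providers.datatypes import Api, ProviderContext

    class MyConfig:  # normally a pydantic model referenced by the ProviderSpec
        pass

    class MyImpl:
        def __init__(self, context: ProviderContext, config: MyConfig) -> None:
            self.context = context
            self.config = config
            # keep all on-disk state under the injected directory
            self.state_root = context.storage_dir

        async def initialize(self) -> None:
            self.state_root.mkdir(parents=True, exist_ok=True)

    async def get_provider_impl(context: ProviderContext, config: MyConfig, deps: dict[Api, Any]):
        impl = MyImpl(context, config)
        await impl.initialize()
        return impl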
diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py
index 7503b8c90..e74ef4c29 100644
--- a/llama_stack/providers/inline/agents/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py
@@ -7,14 +7,16 @@
 from typing import Any

 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

 from .config import MetaReferenceAgentsImplConfig


-async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
     from .agents import MetaReferenceAgentsImpl

     impl = MetaReferenceAgentsImpl(
+        context,
         config,
         deps[Api.inference],
         deps[Api.vector_io],
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index 86780fd61..06c079126 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -37,6 +37,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records

@@ -51,6 +52,7 @@ logger = logging.getLogger()
 class MetaReferenceAgentsImpl(Agents):
     def __init__(
         self,
+        context: ProviderContext,
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
         vector_io_api: VectorIO,
@@ -58,6 +60,7 @@ class MetaReferenceAgentsImpl(Agents):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
     ):
+        self.context = context
         self.config = config
         self.inference_api = inference_api
         self.vector_io_api = vector_io_api
diff --git a/llama_stack/providers/inline/datasetio/localfs/__init__.py b/llama_stack/providers/inline/datasetio/localfs/__init__.py
index 58aa6ffaf..1a6e1a344 100644
--- a/llama_stack/providers/inline/datasetio/localfs/__init__.py
+++ b/llama_stack/providers/inline/datasetio/localfs/__init__.py
@@ -6,15 +6,18 @@

 from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import LocalFSDatasetIOConfig


 async def get_provider_impl(
+    context: ProviderContext,
     config: LocalFSDatasetIOConfig,
     _deps: dict[str, Any],
 ):
     from .datasetio import LocalFSDatasetIOImpl

-    impl = LocalFSDatasetIOImpl(config)
+    impl = LocalFSDatasetIOImpl(context, config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index da71ecb17..0de86b93a 100644
--- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -10,7 +10,7 @@ import pandas
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset
-from llama_stack.providers.datatypes import DatasetsProtocolPrivate
+from llama_stack.providers.datatypes import DatasetsProtocolPrivate, ProviderContext
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
@@ -53,7 +53,8 @@ class PandasDataframeDataset:


 class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
-    def __init__(self, config: LocalFSDatasetIOConfig) -> None:
+    def __init__(self, context: ProviderContext, config: LocalFSDatasetIOConfig) -> None:
+        self.context = context
         self.config = config
         # local registry for keeping track of datasets within the provider
         self.dataset_infos = {}
diff --git a/llama_stack/providers/inline/eval/meta_reference/__init__.py b/llama_stack/providers/inline/eval/meta_reference/__init__.py
index 7afe7f33b..d1f6bf6d3 100644
--- a/llama_stack/providers/inline/eval/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/eval/meta_reference/__init__.py
@@ -6,11 +6,13 @@
 from typing import Any

 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

 from .config import MetaReferenceEvalConfig


 async def get_provider_impl(
+    context: ProviderContext,
     config: MetaReferenceEvalConfig,
     deps: dict[Api, Any],
 ):
diff --git a/llama_stack/providers/inline/inference/meta_reference/__init__.py b/llama_stack/providers/inline/inference/meta_reference/__init__.py
index 5eb822429..98393539f 100644
--- a/llama_stack/providers/inline/inference/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/inference/meta_reference/__init__.py
@@ -6,15 +6,14 @@

 from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import MetaReferenceInferenceConfig


-async def get_provider_impl(
-    config: MetaReferenceInferenceConfig,
-    _deps: dict[str, Any],
-):
+async def get_provider_impl(context: ProviderContext, config: MetaReferenceInferenceConfig, _deps: dict[str, Any]):
     from .inference import MetaReferenceInferenceImpl

-    impl = MetaReferenceInferenceImpl(config)
+    impl = MetaReferenceInferenceImpl(context, config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 8dd594869..20b30ba5b 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -50,7 +50,7 @@ from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4Chat
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import ModelFamily
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate, ProviderContext
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
@@ -89,7 +89,8 @@ class MetaReferenceInferenceImpl(
     Inference,
     ModelsProtocolPrivate,
 ):
-    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
+    def __init__(self, context: ProviderContext, config: MetaReferenceInferenceConfig) -> None:
+        self.context = context
         self.config = config
         self.model_id = None
         self.llama_model = None
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/__init__.py b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py
index 1719cbacc..e024fd4e4 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/__init__.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py
@@ -6,12 +6,14 @@

 from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.inline.inference.sentence_transformers.config import (
     SentenceTransformersInferenceConfig,
 )


 async def get_provider_impl(
+    context: ProviderContext,
     config: SentenceTransformersInferenceConfig,
     _deps: dict[str, Any],
 ):
diff --git a/llama_stack/providers/inline/inference/vllm/__init__.py b/llama_stack/providers/inline/inference/vllm/__init__.py
index d0ec3e084..254a618ef 100644
--- a/llama_stack/providers/inline/inference/vllm/__init__.py
+++ b/llama_stack/providers/inline/inference/vllm/__init__.py
@@ -6,12 +6,14 @@

 from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import VLLMConfig


-async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: VLLMConfig, deps: dict[str, Any]):
     from .vllm import VLLMInferenceImpl

-    impl = VLLMInferenceImpl(config)
+    impl = VLLMInferenceImpl(context, config)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/post_training/torchtune/__init__.py b/llama_stack/providers/inline/post_training/torchtune/__init__.py
index 7a2f9eba2..47b9e5f5c 100644
--- a/llama_stack/providers/inline/post_training/torchtune/__init__.py
+++ b/llama_stack/providers/inline/post_training/torchtune/__init__.py
@@ -7,6 +7,7 @@
 from typing import Any

 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

 from .config import TorchtunePostTrainingConfig

@@ -14,12 +15,14 @@ from .config import TorchtunePostTrainingConfig


 async def get_provider_impl(
+    context: ProviderContext,
     config: TorchtunePostTrainingConfig,
     deps: dict[Api, Any],
 ):
     from .post_training import TorchtunePostTrainingImpl

     impl = TorchtunePostTrainingImpl(
+        context,
         config,
         deps[Api.datasetio],
         deps[Api.datasets],
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index c7d8d6758..4686cb37c 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -20,6 +20,7 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
@@ -42,10 +43,12 @@ _JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
 class TorchtunePostTrainingImpl:
     def __init__(
         self,
+        context: ProviderContext,
         config: TorchtunePostTrainingConfig,
         datasetio_api: DatasetIO,
         datasets: Datasets,
     ) -> None:
+        self.context = context
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets
diff --git a/llama_stack/providers/inline/safety/code_scanner/__init__.py b/llama_stack/providers/inline/safety/code_scanner/__init__.py
index 68e32b747..6ad83f1c7 100644
--- a/llama_stack/providers/inline/safety/code_scanner/__init__.py
+++ b/llama_stack/providers/inline/safety/code_scanner/__init__.py
@@ -6,12 +6,14 @@

 from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import CodeScannerConfig


-async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: CodeScannerConfig, deps: dict[str, Any]):
     from .code_scanner import MetaReferenceCodeScannerSafetyImpl

-    impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
+    impl = MetaReferenceCodeScannerSafetyImpl(context, config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
index be05ee436..f0acf8ae9 100644
--- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
+++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
@@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -30,8 +31,10 @@ ALLOWED_CODE_SCANNER_MODEL_IDS = [


 class MetaReferenceCodeScannerSafetyImpl(Safety):
-    def __init__(self, config: CodeScannerConfig, deps) -> None:
+    def __init__(self, context: ProviderContext, config: CodeScannerConfig, deps) -> None:
+        self.context = context
         self.config = config
+        self.deps = deps

     async def initialize(self) -> None:
         pass
diff --git a/llama_stack/providers/inline/safety/llama_guard/__init__.py b/llama_stack/providers/inline/safety/llama_guard/__init__.py
index 8865cc344..54ad1e5bc 100644
--- a/llama_stack/providers/inline/safety/llama_guard/__init__.py
+++ b/llama_stack/providers/inline/safety/llama_guard/__init__.py
@@ -6,14 +6,16 @@

 from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import LlamaGuardConfig


-async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: LlamaGuardConfig, deps: dict[str, Any]):
     from .llama_guard import LlamaGuardSafetyImpl

     assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

-    impl = LlamaGuardSafetyImpl(config, deps)
+    impl = LlamaGuardSafetyImpl(context, config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 937301c2e..fc33a3ba2 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
 from llama_stack.models.llama.datatypes import Role
 from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -130,7 +130,8 @@ PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATIO


 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+    def __init__(self, context: ProviderContext, config: LlamaGuardConfig, deps) -> None:
+        self.context = context
         self.config = config
         self.inference_api = deps[Api.inference]
diff --git a/llama_stack/providers/inline/safety/prompt_guard/__init__.py b/llama_stack/providers/inline/safety/prompt_guard/__init__.py
index 1761c9138..4e30e7611 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/__init__.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/__init__.py
@@ -6,12 +6,14 @@

 from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import PromptGuardConfig


-async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: PromptGuardConfig, deps: dict[str, Any]):
     from .prompt_guard import PromptGuardSafetyImpl

-    impl = PromptGuardSafetyImpl(config, deps)
+    impl = PromptGuardSafetyImpl(context, config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 56ce8285f..b6dba962c 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -19,7 +19,7 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -32,8 +32,10 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"


 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: PromptGuardConfig, _deps) -> None:
+    def __init__(self, context: ProviderContext, config: PromptGuardConfig, _deps) -> None:
+        self.context = context
         self.config = config
+        self.deps = _deps

     async def initialize(self) -> None:
         model_dir = model_local_dir(PROMPT_GUARD_MODEL)
diff --git a/llama_stack/providers/inline/scoring/basic/__init__.py b/llama_stack/providers/inline/scoring/basic/__init__.py
index d9d150b1a..f4e396b18 100644
--- a/llama_stack/providers/inline/scoring/basic/__init__.py
+++ b/llama_stack/providers/inline/scoring/basic/__init__.py
@@ -6,17 +6,20 @@
 from typing import Any

 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

 from .config import BasicScoringConfig


 async def get_provider_impl(
+    context: ProviderContext,
     config: BasicScoringConfig,
     deps: dict[Api, Any],
 ):
     from .scoring import BasicScoringImpl

     impl = BasicScoringImpl(
+        context,
         config,
         deps[Api.datasetio],
         deps[Api.datasets],
diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py
index 09f89be5e..2dc13874a 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring.py
@@ -15,7 +15,7 @@ from llama_stack.apis.scoring import (
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.distribution.datatypes import Api
-from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
     get_valid_schemas,
     validate_dataset_schema,
@@ -49,10 +49,12 @@ class BasicScoringImpl(
 ):
     def __init__(
         self,
+        context: ProviderContext,
         config: BasicScoringConfig,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        self.context = context
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
diff --git a/llama_stack/providers/inline/scoring/braintrust/__init__.py b/llama_stack/providers/inline/scoring/braintrust/__init__.py
index 8ea6e9b96..ed37de611 100644
--- a/llama_stack/providers/inline/scoring/braintrust/__init__.py
+++ b/llama_stack/providers/inline/scoring/braintrust/__init__.py
@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel

 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

 from .config import BraintrustScoringConfig

@@ -17,6 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):


 async def get_provider_impl(
+    context: ProviderContext,
     config: BraintrustScoringConfig,
     deps: dict[Api, Any],
 ):
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py
index 88bf10737..ee0f0408b 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py
@@ -6,11 +6,13 @@
 from typing import Any

 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

 from .config import LlmAsJudgeScoringConfig


 async def get_provider_impl(
+    context: ProviderContext,
     config: LlmAsJudgeScoringConfig,
     deps: dict[Api, Any],
 ):
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
index 09e97136a..36250ac26 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
@@ -7,15 +7,16 @@
 from typing import Any

 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

 from .config import TelemetryConfig, TelemetrySink

 __all__ = ["TelemetryConfig", "TelemetrySink"]


-async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: TelemetryConfig, deps: dict[Api, Any]):
     from .telemetry import TelemetryAdapter

-    impl = TelemetryAdapter(config, deps)
+    impl = TelemetryAdapter(context, config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
index 67362dd36..ad06f6b3d 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
@@ -36,6 +36,7 @@ from llama_stack.apis.telemetry import (
     UnstructuredLogEvent,
 )
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )
@@ -45,7 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
 from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
 from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore

-from .config import TelemetryConfig, TelemetrySink
+from .config import TelemetrySink

 _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "active_spans": {},
@@ -63,8 +64,10 @@ def is_tracing_enabled(tracer):


 class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
-    def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
+    def __init__(self, context: ProviderContext, config, deps):
+        self.context = context
         self.config = config
+        self.deps = deps
         self.datasetio_api = deps.get(Api.datasetio)
         self.meter = None
diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py
index f9a6e5c55..9e12a3082 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py
@@ -6,12 +6,12 @@

 from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

 from .config import RagToolRuntimeConfig


-async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: RagToolRuntimeConfig, deps: dict[Api, Any]):
     from .memory import MemoryToolRuntimeImpl

     impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
diff --git a/llama_stack/providers/inline/vector_io/chroma/__init__.py b/llama_stack/providers/inline/vector_io/chroma/__init__.py
index 2e0efb8a1..a50f5d5a3 100644
--- a/llama_stack/providers/inline/vector_io/chroma/__init__.py
+++ b/llama_stack/providers/inline/vector_io/chroma/__init__.py
@@ -6,16 +6,17 @@

 from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

 from .config import ChromaVectorIOConfig


-async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: ChromaVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.chroma.chroma import (
         ChromaVectorIOAdapter,
     )

+    # Pass config directly since ChromaVectorIOAdapter doesn't accept context
     impl = ChromaVectorIOAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/faiss/__init__.py b/llama_stack/providers/inline/vector_io/faiss/__init__.py
index 68a1dee66..57c2d628b 100644
--- a/llama_stack/providers/inline/vector_io/faiss/__init__.py
+++ b/llama_stack/providers/inline/vector_io/faiss/__init__.py
@@ -6,16 +6,16 @@

 from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

 from .config import FaissVectorIOConfig


-async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: FaissVectorIOConfig, deps: dict[Api, Any]):
     from .faiss import FaissVectorIOAdapter

     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = FaissVectorIOAdapter(config, deps[Api.inference])
+    impl = FaissVectorIOAdapter(context, config, deps[Api.inference])
     await impl.initialize()
     return impl
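The faiss adapter below stores the injected context and its storage_dir. As a sketch of how a vector adapter could lay out its on-disk artifacts under that directory; the helper name and the "faiss" subdirectory layout are assumptions for illustration, not part of this patch:

    # Illustrative helper; the per-provider subdirectory layout is an assumption.
    from pathlib import Path

    def faiss_index_path(storage_dir: Path, bank_id: str) -> Path:
        # e.g. <storage_dir>/faiss/<bank_id>.index
        path = storage_dir / "faiss" / f"{bank_id}.index"
        path.parent.mkdir(parents=True, exist_ok=True)
        return path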
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index d3dc7e694..5268e5bbe 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -19,7 +19,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
@@ -114,9 +114,11 @@ class FaissIndex(EmbeddingIndex):


 class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
+    def __init__(self, context: ProviderContext, config: FaissVectorIOConfig, inference_api: Inference) -> None:
+        self.context = context
         self.config = config
         self.inference_api = inference_api
+        self.storage_dir = context.storage_dir if context else None
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.kvstore: KVStore | None = None
diff --git a/llama_stack/providers/inline/vector_io/milvus/__init__.py b/llama_stack/providers/inline/vector_io/milvus/__init__.py
index fe3a1f7f9..82a37e2c0 100644
--- a/llama_stack/providers/inline/vector_io/milvus/__init__.py
+++ b/llama_stack/providers/inline/vector_io/milvus/__init__.py
@@ -6,14 +6,15 @@

 from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

 from .config import MilvusVectorIOConfig


-async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: MilvusVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter

+    # Pass config directly since MilvusVectorIOAdapter doesn't accept context
     impl = MilvusVectorIOAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py
index 6db176eda..8eafe9da2 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py
@@ -6,15 +6,15 @@

 from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

 from .config import SQLiteVectorIOConfig


-async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
+    impl = SQLiteVecVectorIOAdapter(context, config, deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index ab4384021..61b7faeda 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -10,6 +10,7 @@ import logging
 import sqlite3
 import struct
 import uuid
+from pathlib import Path
 from typing import Any

 import numpy as np
@@ -19,7 +20,7 @@ from numpy.typing import NDArray
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

 logger = logging.getLogger(__name__)
@@ -206,15 +207,23 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
     """

-    def __init__(self, config, inference_api: Inference) -> None:
+    def __init__(self, context: ProviderContext, config, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
         self.cache: dict[str, VectorDBWithIndex] = {}
+        self.storage_dir = context.storage_dir
+        self.db_path = self._resolve_path(self.config.db_path)
+
+    def _resolve_path(self, path: str | Path) -> Path:
+        path = Path(path)
+        if path.is_absolute():
+            return path
+        return self.storage_dir / path

     async def initialize(self) -> None:
         def _setup_connection():
             # Open a connection to the SQLite database (the file is specified in the config).
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()
             try:
                 # Create a table to persist vector DB registrations.
@@ -237,9 +246,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
             for row in rows:
                 vector_db_data = row[0]
                 vector_db = VectorDB.model_validate_json(vector_db_data)
-                index = await SQLiteVecIndex.create(
-                    vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
-                )
+                index = await SQLiteVecIndex.create(vector_db.embedding_dimension, str(self.db_path), vector_db.identifier)
                 self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

     async def shutdown(self) -> None:
@@ -248,7 +255,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def register_vector_db(self, vector_db: VectorDB) -> None:
         def _register_db():
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()
             try:
                 cur.execute(
@@ -261,7 +268,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
                 connection.close()

         await asyncio.to_thread(_register_db)
-        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
+        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, str(self.db_path), vector_db.identifier)
         self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

     async def list_vector_dbs(self) -> list[VectorDB]:
@@ -275,7 +282,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         del self.cache[vector_db_id]

         def _delete_vector_db_from_registry():
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()
             try:
                 cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
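The relative-versus-absolute rule in _resolve_path above, demonstrated standalone with made-up paths:

    # Standalone illustration of _resolve_path's behavior; all paths are examples.
    from pathlib import Path

    storage_dir = Path("/var/lib/llama-stack/my-distro")

    def resolve(path: str) -> Path:
        p = Path(path)
        return p if p.is_absolute() else storage_dir / p

    # A relative db_path lands inside the injected storage directory...
    assert resolve("sqlite_vec.db") == Path("/var/lib/llama-stack/my-distro/sqlite_vec.db")
    # ...while an absolute db_path is left untouched.
    assert resolve("/data/vectors.db") == Path("/data/vectors.db")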
diff --git a/llama_stack/providers/remote/inference/vllm/__init__.py b/llama_stack/providers/remote/inference/vllm/__init__.py
index e4322a6aa..3e3900113 100644
--- a/llama_stack/providers/remote/inference/vllm/__init__.py
+++ b/llama_stack/providers/remote/inference/vllm/__init__.py
@@ -7,7 +7,7 @@
 from .config import VLLMInferenceAdapterConfig


-async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
+async def get_adapter_impl(config: VLLMInferenceAdapterConfig, deps):
     from .vllm import VLLMInferenceAdapter

     assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py
index bef24e123..aec587d40 100644
--- a/tests/unit/providers/agent/test_meta_reference_agent.py
+++ b/tests/unit/providers/agent/test_meta_reference_agent.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from datetime import datetime
+from pathlib import Path
 from unittest.mock import AsyncMock

 import pytest
@@ -20,6 +21,7 @@ from llama_stack.apis.inference import Inference
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
 from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
 from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@@ -48,7 +50,9 @@ def config(tmp_path):

 @pytest_asyncio.fixture
 async def agents_impl(config, mock_apis):
+    context = ProviderContext(storage_dir=Path("/tmp"))
     impl = MetaReferenceAgentsImpl(
+        context,
         config,
         mock_apis["inference_api"],
         mock_apis["vector_io_api"],
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index a2e3b64c2..3d022bc00 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -62,9 +62,13 @@ class MockInferenceAdapterWithSleep:
             # ruff: noqa: N802
             def do_POST(self):
                 time.sleep(sleep_time)
+                response_json = json.dumps(response).encode("utf-8")
                 self.send_response(code=200)
+                self.send_header("Content-Type", "application/json")
+                self.send_header("Content-Length", str(len(response_json)))
                 self.end_headers()
-                self.wfile.write(json.dumps(response).encode("utf-8"))
+                self.wfile.write(response_json)
+                self.wfile.flush()

         self.request_handler = DelayedRequestHandler
diff --git a/tests/unit/test_state_dir_resolution.py b/tests/unit/test_state_dir_resolution.py
new file mode 100644
index 000000000..a038ebcec
--- /dev/null
+++ b/tests/unit/test_state_dir_resolution.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from llama_stack.distribution.resolver import resolve_storage_dir
+
+
+class DummyConfig:
+    pass
+
+
+def test_storage_dir_cli(monkeypatch):
+    config = DummyConfig()
+    config.storage_dir = "/cli/dir"
+    monkeypatch.delenv("LLAMA_STACK_STORAGE_DIR", raising=False)
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/cli/dir")
+
+
+def test_storage_dir_env(monkeypatch):
+    config = DummyConfig()
+    if hasattr(config, "storage_dir"):
+        delattr(config, "storage_dir")
+    monkeypatch.setenv("LLAMA_STACK_STORAGE_DIR", "/env/dir")
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/env/dir")
+
+
+def test_storage_dir_fallback(monkeypatch):
+    # Mock the DISTRIBS_BASE_DIR
+    monkeypatch.setattr("llama_stack.distribution.utils.config_dirs.DISTRIBS_BASE_DIR", Path("/mock/distribs"))
+
+    config = DummyConfig()
+    if hasattr(config, "storage_dir"):
+        delattr(config, "storage_dir")
+    monkeypatch.delenv("LLAMA_STACK_STORAGE_DIR", raising=False)
+
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/mock/distribs/distro")
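One ordering these tests leave unpinned is CLI-over-environment when both are set; a possible follow-up test in the same style (a sketch, not part of this patch):

    # Hypothetical additional test: the CLI option should win even when the
    # environment variable is also set.
    def test_storage_dir_cli_overrides_env(monkeypatch):
        config = DummyConfig()
        config.storage_dir = "/cli/dir"
        monkeypatch.setenv("LLAMA_STACK_STORAGE_DIR", "/env/dir")
        assert resolve_storage_dir(config, "distro") == Path("/cli/dir")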