commit 8154dc7500
Roland Huß 2025-05-15 11:50:47 +02:00, committed by GitHub
41 changed files with 242 additions and 81 deletions

View file

@@ -59,6 +59,11 @@ class StackRun(Subcommand):
            help="Image Type used during the build. This can be either conda or container or venv.",
            choices=[e.value for e in ImageType],
        )
+        self.parser.add_argument(
+            "--storage-directory",
+            type=str,
+            help="Directory to use for provider state (overrides environment variable and default).",
+        )

        # If neither image type nor image name is provided, but at the same time
        # the current environment has conda breadcrumbs, then assume what the user
@@ -118,6 +123,9 @@ class StackRun(Subcommand):
        except AttributeError as e:
            self.parser.error(f"failed to parse config file '{config_file}':\n {e}")

+        # Pass the CLI storage directory option to run_config for resolver use
+        config.storage_dir = args.storage_directory
+
        image_type, image_name = self._get_image_type_and_name(args)

        # If neither image type nor image name is provided, assume the server should be run directly
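For illustration, the new flag is passed on the `llama stack run` command line like this (a sketch; the config path is hypothetical):

    llama stack run ./run.yaml --storage-directory /var/lib/llama-stack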

View file

@@ -317,6 +317,11 @@ a default SQLite store will be used.""",
        description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
    )
+    storage_dir: str | None = Field(
+        default=None,
+        description="Directory to use for provider state. Can be set via CLI or environment variable; defaults to the distribution directory.",
+    )

class BuildConfig(BaseModel):
    version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
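In a run config this surfaces as an optional top-level key; a sketch (other keys omitted, values hypothetical):

    # run.yaml (excerpt)
    image_name: my-distro
    storage_dir: /var/lib/llama-stack/my-distro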

View file

@@ -24,7 +24,7 @@ class DistributionInspectConfig(BaseModel):
    run_config: StackRunConfig

-async def get_provider_impl(config, deps):
+async def get_provider_impl(context, config, deps):
    impl = DistributionInspectImpl(config, deps)
    await impl.initialize()
    return impl

View file

@@ -23,7 +23,7 @@ class ProviderImplConfig(BaseModel):
    run_config: StackRunConfig

-async def get_provider_impl(config, deps):
+async def get_provider_impl(context, config, deps):
    impl = ProviderImpl(config, deps)
    await impl.initialize()
    return impl

View file

@@ -5,6 +5,8 @@
# the root directory of this source tree.
import importlib
import inspect
+import os
+from pathlib import Path
from typing import Any

from llama_stack.apis.agents import Agents
@@ -42,6 +44,7 @@ from llama_stack.providers.datatypes import (
    BenchmarksProtocolPrivate,
    DatasetsProtocolPrivate,
    ModelsProtocolPrivate,
+    ProviderContext,
    ProviderSpec,
    RemoteProviderConfig,
    RemoteProviderSpec,
@@ -334,7 +337,15 @@ async def instantiate_provider(
        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider.config)

-        args = [config, deps]
+        # Build ProviderContext for every provider
+        distro_name = (
+            dist_registry.run_config.image_name
+            if dist_registry and hasattr(dist_registry, "run_config")
+            else provider.spec.api.value
+        )
+        storage_dir = resolve_storage_dir(config, distro_name)
+        context = ProviderContext(storage_dir=storage_dir)
+        args = [context, config, deps]

    fn = getattr(module, method)
    impl = await fn(*args)
@@ -413,3 +424,31 @@ async def resolve_remote_stack_impls(
    )

    return impls
+
+
+def resolve_storage_dir(config, distro_name: str) -> Path:
+    """
+    Resolves the storage directory for a provider in the following order of precedence:
+
+    1. CLI option (config.storage_dir)
+    2. Environment variable (LLAMA_STACK_STORAGE_DIR)
+    3. Fallback to <DISTRIBS_BASE_DIR>/<distro_name>
+
+    Args:
+        config: Provider configuration object
+        distro_name: Distribution name used for the fallback path
+
+    Returns:
+        Path: Resolved storage directory path
+    """
+    # Import here to avoid circular imports
+    from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+
+    # 1. CLI option
+    storage_dir = getattr(config, "storage_dir", None)
+
+    # 2. Environment variable
+    if not storage_dir:
+        storage_dir = os.environ.get("LLAMA_STACK_STORAGE_DIR")
+
+    # 3. Fallback
+    if not storage_dir:
+        storage_dir = str(DISTRIBS_BASE_DIR / distro_name)
+
+    return Path(storage_dir)
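As a quick illustration of the precedence order (a sketch; `SimpleNamespace` stands in for a real provider config object):

import os
from pathlib import Path
from types import SimpleNamespace

from llama_stack.distribution.resolver import resolve_storage_dir

# 1. An explicit CLI-provided value wins
cfg = SimpleNamespace(storage_dir="/custom/state")
assert resolve_storage_dir(cfg, "my-distro") == Path("/custom/state")

# 2. Otherwise the environment variable is consulted
os.environ["LLAMA_STACK_STORAGE_DIR"] = "/env/state"
assert resolve_storage_dir(SimpleNamespace(storage_dir=None), "my-distro") == Path("/env/state")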

View file

@@ -18,7 +18,6 @@ from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any

-import rich.pretty
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
@@ -33,7 +32,7 @@ from llama_stack.distribution.request_headers import (
    PROVIDER_DATA_VAR,
    request_provider_data_context,
)
-from llama_stack.distribution.resolver import InvalidProviderError
+from llama_stack.distribution.resolver import InvalidProviderError, resolve_storage_dir
from llama_stack.distribution.server.endpoints import (
    find_matching_endpoint,
    initialize_endpoint_impls,
@@ -46,7 +45,7 @@ from llama_stack.distribution.stack import (
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
    TelemetryAdapter,
@@ -188,30 +187,11 @@ async def sse_generator(event_gen_coroutine):
    )

-async def log_request_pre_validation(request: Request):
-    if request.method in ("POST", "PUT", "PATCH"):
-        try:
-            body_bytes = await request.body()
-            if body_bytes:
-                try:
-                    parsed_body = json.loads(body_bytes.decode())
-                    log_output = rich.pretty.pretty_repr(parsed_body)
-                except (json.JSONDecodeError, UnicodeDecodeError):
-                    log_output = repr(body_bytes)
-                logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
-            else:
-                logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
-        except Exception as e:
-            logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
-
def create_dynamic_typed_route(func: Any, method: str, route: str):
    async def endpoint(request: Request, **kwargs):
        # Get auth attributes from the request scope
        user_attributes = request.scope.get("user_attributes", {})

-        await log_request_pre_validation(request)
-
        # Use context manager with both provider data and auth attributes
        with request_provider_data_context(request.headers, user_attributes):
            is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@@ -442,7 +422,10 @@ def main(args: argparse.Namespace | None = None):
    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
    else:
-        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
+        # Resolve storage directory using the same logic as other providers
+        storage_dir = resolve_storage_dir(config, config.image_name)
+        context = ProviderContext(storage_dir=storage_dir)
+        setup_logger(TelemetryAdapter(context, TelemetryConfig(), {}))

    all_endpoints = get_all_api_endpoints()

View file

@@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

+from dataclasses import dataclass
from enum import Enum
+from pathlib import Path
from typing import Any, Protocol
from urllib.parse import urlparse
@@ -161,7 +163,7 @@ If a provider depends on other providers, the dependencies MUST NOT specify a co
        description="""
Fully-qualified name of the module to import. The module is expected to have:

- - `get_provider_impl(config, deps)`: returns the local implementation
+ - `get_provider_impl(context, config, deps)`: returns the local implementation
""",
    )
    provider_data_validator: str | None = Field(
@@ -232,3 +234,19 @@ class HealthStatus(str, Enum):

HealthResponse = dict[str, Any]
+
+
+@dataclass
+class ProviderContext:
+    """
+    Runtime context for provider instantiation.
+
+    This object is constructed by the Llama Stack runtime and injected into every provider.
+    It contains environment- and deployment-specific information that should not be part of
+    the static config file.
+
+    Attributes:
+        storage_dir (Path): Directory for provider state (persistent or ephemeral),
+            resolved from CLI option, environment variable, or default distribution directory.
+    """
+
+    storage_dir: Path
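A provider entry point then receives the context ahead of its config. A minimal sketch of the new shape (the provider body here is hypothetical):

from pathlib import Path

from llama_stack.providers.datatypes import ProviderContext


async def get_provider_impl(context: ProviderContext, config, deps):
    # Hypothetical provider: derive on-disk state paths from the injected context
    state_db = context.storage_dir / "state.db"
    ...


# The runtime normally builds the context; constructed by hand it is just a dataclass
context = ProviderContext(storage_dir=Path("/var/lib/llama-stack/my-distro"))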

View file

@@ -7,14 +7,16 @@
from typing import Any

from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

from .config import MetaReferenceAgentsImplConfig

-async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
    from .agents import MetaReferenceAgentsImpl

    impl = MetaReferenceAgentsImpl(
+        context,
        config,
        deps[Api.inference],
        deps[Api.vector_io],

View file

@@ -37,6 +37,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.datatypes import ProviderContext
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
@@ -51,6 +52,7 @@ logger = logging.getLogger()
class MetaReferenceAgentsImpl(Agents):
    def __init__(
        self,
+        context: ProviderContext,
        config: MetaReferenceAgentsImplConfig,
        inference_api: Inference,
        vector_io_api: VectorIO,
@@ -58,6 +60,7 @@ class MetaReferenceAgentsImpl(Agents):
        tool_runtime_api: ToolRuntime,
        tool_groups_api: ToolGroups,
    ):
+        self.context = context
        self.config = config
        self.inference_api = inference_api
        self.vector_io_api = vector_io_api

View file

@@ -6,15 +6,18 @@
from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
from .config import LocalFSDatasetIOConfig

async def get_provider_impl(
+    context: ProviderContext,
    config: LocalFSDatasetIOConfig,
    _deps: dict[str, Any],
):
    from .datasetio import LocalFSDatasetIOImpl

-    impl = LocalFSDatasetIOImpl(config)
+    impl = LocalFSDatasetIOImpl(context, config)
    await impl.initialize()
    return impl

View file

@@ -10,7 +10,7 @@ import pandas
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
-from llama_stack.providers.datatypes import DatasetsProtocolPrivate
+from llama_stack.providers.datatypes import DatasetsProtocolPrivate, ProviderContext
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
@@ -53,7 +53,8 @@ class PandasDataframeDataset:
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
-    def __init__(self, config: LocalFSDatasetIOConfig) -> None:
+    def __init__(self, context: ProviderContext, config: LocalFSDatasetIOConfig) -> None:
+        self.context = context
        self.config = config
        # local registry for keeping track of datasets within the provider
        self.dataset_infos = {}

View file

@@ -6,11 +6,13 @@
from typing import Any

from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

from .config import MetaReferenceEvalConfig

async def get_provider_impl(
+    context: ProviderContext,
    config: MetaReferenceEvalConfig,
    deps: dict[Api, Any],
):

View file

@@ -6,15 +6,14 @@
from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
from .config import MetaReferenceInferenceConfig

-async def get_provider_impl(
-    config: MetaReferenceInferenceConfig,
-    _deps: dict[str, Any],
-):
+async def get_provider_impl(context: ProviderContext, config: MetaReferenceInferenceConfig, _deps: dict[str, Any]):
    from .inference import MetaReferenceInferenceImpl

-    impl = MetaReferenceInferenceImpl(config)
+    impl = MetaReferenceInferenceImpl(context, config)
    await impl.initialize()
    return impl

View file

@@ -50,7 +50,7 @@ from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4Chat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate, ProviderContext
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)
@@ -89,7 +89,8 @@ class MetaReferenceInferenceImpl(
    Inference,
    ModelsProtocolPrivate,
):
-    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
+    def __init__(self, context: ProviderContext, config: MetaReferenceInferenceConfig) -> None:
+        self.context = context
        self.config = config
        self.model_id = None
        self.llama_model = None

View file

@@ -6,12 +6,14 @@
from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
from llama_stack.providers.inline.inference.sentence_transformers.config import (
    SentenceTransformersInferenceConfig,
)

async def get_provider_impl(
+    context: ProviderContext,
    config: SentenceTransformersInferenceConfig,
    _deps: dict[str, Any],
):

View file

@@ -6,12 +6,14 @@
from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
from .config import VLLMConfig

-async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: VLLMConfig, deps: dict[str, Any]):
    from .vllm import VLLMInferenceImpl

-    impl = VLLMInferenceImpl(config)
+    impl = VLLMInferenceImpl(context, config)
    await impl.initialize()
    return impl

View file

@@ -7,6 +7,7 @@
from typing import Any

from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

from .config import TorchtunePostTrainingConfig

@@ -14,12 +15,14 @@ from .config import TorchtunePostTrainingConfig
async def get_provider_impl(
+    context: ProviderContext,
    config: TorchtunePostTrainingConfig,
    deps: dict[Api, Any],
):
    from .post_training import TorchtunePostTrainingImpl

    impl = TorchtunePostTrainingImpl(
+        context,
        config,
        deps[Api.datasetio],
        deps[Api.datasets],

View file

@@ -20,6 +20,7 @@ from llama_stack.apis.post_training import (
    PostTrainingJobStatusResponse,
    TrainingConfig,
)
+from llama_stack.providers.datatypes import ProviderContext
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)
@@ -42,10 +43,12 @@ _JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class TorchtunePostTrainingImpl:
    def __init__(
        self,
+        context: ProviderContext,
        config: TorchtunePostTrainingConfig,
        datasetio_api: DatasetIO,
        datasets: Datasets,
    ) -> None:
+        self.context = context
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets

View file

@@ -6,12 +6,14 @@
from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
from .config import CodeScannerConfig

-async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: CodeScannerConfig, deps: dict[str, Any]):
    from .code_scanner import MetaReferenceCodeScannerSafetyImpl

-    impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
+    impl = MetaReferenceCodeScannerSafetyImpl(context, config, deps)
    await impl.initialize()
    return impl

View file

@@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
    ViolationLevel,
)
from llama_stack.apis.shields import Shield
+from llama_stack.providers.datatypes import ProviderContext
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
@@ -30,8 +31,10 @@ ALLOWED_CODE_SCANNER_MODEL_IDS = [
class MetaReferenceCodeScannerSafetyImpl(Safety):
-    def __init__(self, config: CodeScannerConfig, deps) -> None:
+    def __init__(self, context: ProviderContext, config: CodeScannerConfig, deps) -> None:
+        self.context = context
        self.config = config
+        self.deps = deps

    async def initialize(self) -> None:
        pass

View file

@@ -6,14 +6,16 @@
from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
from .config import LlamaGuardConfig

-async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: LlamaGuardConfig, deps: dict[str, Any]):
    from .llama_guard import LlamaGuardSafetyImpl

    assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

-    impl = LlamaGuardSafetyImpl(config, deps)
+    impl = LlamaGuardSafetyImpl(context, config, deps)
    await impl.initialize()
    return impl

View file

@@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api
from llama_stack.models.llama.datatypes import Role
from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
@@ -130,7 +130,8 @@ PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATIO
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+    def __init__(self, context: ProviderContext, config: LlamaGuardConfig, deps) -> None:
+        self.context = context
        self.config = config
        self.inference_api = deps[Api.inference]

View file

@@ -6,12 +6,14 @@
from typing import Any

+from llama_stack.providers.datatypes import ProviderContext
+
from .config import PromptGuardConfig

-async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: PromptGuardConfig, deps: dict[str, Any]):
    from .prompt_guard import PromptGuardSafetyImpl

-    impl = PromptGuardSafetyImpl(config, deps)
+    impl = PromptGuardSafetyImpl(context, config, deps)
    await impl.initialize()
    return impl

View file

@@ -19,7 +19,7 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
@@ -32,8 +32,10 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: PromptGuardConfig, _deps) -> None:
+    def __init__(self, context: ProviderContext, config: PromptGuardConfig, _deps) -> None:
+        self.context = context
        self.config = config
+        self.deps = _deps

    async def initialize(self) -> None:
        model_dir = model_local_dir(PROMPT_GUARD_MODEL)

View file

@@ -6,17 +6,20 @@
from typing import Any

from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

from .config import BasicScoringConfig

async def get_provider_impl(
+    context: ProviderContext,
    config: BasicScoringConfig,
    deps: dict[Api, Any],
):
    from .scoring import BasicScoringImpl

    impl = BasicScoringImpl(
+        context,
        config,
        deps[Api.datasetio],
        deps[Api.datasets],

View file

@@ -15,7 +15,7 @@ from llama_stack.apis.scoring import (
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
-from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
    get_valid_schemas,
    validate_dataset_schema,
@@ -49,10 +49,12 @@ class BasicScoringImpl(
):
    def __init__(
        self,
+        context: ProviderContext,
        config: BasicScoringConfig,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
    ) -> None:
+        self.context = context
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api

View file

@@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel

from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

from .config import BraintrustScoringConfig

@@ -17,6 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
async def get_provider_impl(
+    context: ProviderContext,
    config: BraintrustScoringConfig,
    deps: dict[Api, Any],
):

View file

@@ -6,11 +6,13 @@
from typing import Any

from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

from .config import LlmAsJudgeScoringConfig

async def get_provider_impl(
+    context: ProviderContext,
    config: LlmAsJudgeScoringConfig,
    deps: dict[Api, Any],
):

View file

@@ -7,15 +7,16 @@
from typing import Any

from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext

from .config import TelemetryConfig, TelemetrySink

__all__ = ["TelemetryConfig", "TelemetrySink"]

-async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: TelemetryConfig, deps: dict[Api, Any]):
    from .telemetry import TelemetryAdapter

-    impl = TelemetryAdapter(config, deps)
+    impl = TelemetryAdapter(context, config, deps)
    await impl.initialize()
    return impl

View file

@@ -36,6 +36,7 @@ from llama_stack.apis.telemetry import (
    UnstructuredLogEvent,
)
from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
    ConsoleSpanProcessor,
)
@@ -45,7 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore

-from .config import TelemetryConfig, TelemetrySink
+from .config import TelemetrySink

_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
    "active_spans": {},
@@ -63,8 +64,10 @@ def is_tracing_enabled(tracer):
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
-    def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
+    def __init__(self, context: ProviderContext, config, deps):
+        self.context = context
        self.config = config
+        self.deps = deps
        self.datasetio_api = deps.get(Api.datasetio)
        self.meter = None

View file

@@ -6,12 +6,12 @@
from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

from .config import RagToolRuntimeConfig

-async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: RagToolRuntimeConfig, deps: dict[Api, Any]):
    from .memory import MemoryToolRuntimeImpl

    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])

View file

@@ -6,16 +6,17 @@
from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

from .config import ChromaVectorIOConfig

-async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: ChromaVectorIOConfig, deps: dict[Api, Any]):
    from llama_stack.providers.remote.vector_io.chroma.chroma import (
        ChromaVectorIOAdapter,
    )

+    # Pass config directly since ChromaVectorIOAdapter doesn't accept context
    impl = ChromaVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl

View file

@@ -6,16 +6,16 @@
from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

from .config import FaissVectorIOConfig

-async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: FaissVectorIOConfig, deps: dict[Api, Any]):
    from .faiss import FaissVectorIOAdapter

    assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = FaissVectorIOAdapter(config, deps[Api.inference])
+    impl = FaissVectorIOAdapter(context, config, deps[Api.inference])
    await impl.initialize()
    return impl

View file

@@ -19,7 +19,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
@@ -114,9 +114,11 @@ class FaissIndex(EmbeddingIndex):
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
+    def __init__(self, context: ProviderContext, config: FaissVectorIOConfig, inference_api: Inference) -> None:
+        self.context = context
        self.config = config
        self.inference_api = inference_api
+        self.storage_dir = context.storage_dir if context else None
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.kvstore: KVStore | None = None

View file

@@ -6,14 +6,15 @@
from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

from .config import MilvusVectorIOConfig

-async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: MilvusVectorIOConfig, deps: dict[Api, Any]):
    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter

+    # Pass config directly since MilvusVectorIOAdapter doesn't accept context
    impl = MilvusVectorIOAdapter(config, deps[Api.inference])
    await impl.initialize()
    return impl

View file

@@ -6,15 +6,15 @@
from typing import Any

-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext

from .config import SQLiteVectorIOConfig

-async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
    from .sqlite_vec import SQLiteVecVectorIOAdapter

    assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
+    impl = SQLiteVecVectorIOAdapter(context, config, deps[Api.inference])
    await impl.initialize()
    return impl

View file

@@ -10,6 +10,7 @@ import logging
import sqlite3
import struct
import uuid
+from pathlib import Path
from typing import Any

import numpy as np
@@ -19,7 +20,7 @@ from numpy.typing import NDArray
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

logger = logging.getLogger(__name__)
@@ -206,15 +207,23 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
    """

-    def __init__(self, config, inference_api: Inference) -> None:
+    def __init__(self, context: ProviderContext, config, inference_api: Inference) -> None:
        self.config = config
        self.inference_api = inference_api
        self.cache: dict[str, VectorDBWithIndex] = {}
+        self.storage_dir = context.storage_dir
+        self.db_path = self._resolve_path(self.config.db_path)
+
+    def _resolve_path(self, path: str | Path) -> Path:
+        path = Path(path)
+        if path.is_absolute():
+            return path
+        return self.storage_dir / path

    async def initialize(self) -> None:
        def _setup_connection():
            # Open a connection to the SQLite database (the file is specified in the config).
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
            cur = connection.cursor()
            try:
                # Create a table to persist vector DB registrations.
@@ -237,9 +246,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
        for row in rows:
            vector_db_data = row[0]
            vector_db = VectorDB.model_validate_json(vector_db_data)
-            index = await SQLiteVecIndex.create(
-                vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
-            )
+            index = await SQLiteVecIndex.create(vector_db.embedding_dimension, str(self.db_path), vector_db.identifier)
            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def shutdown(self) -> None:
@@ -248,7 +255,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    async def register_vector_db(self, vector_db: VectorDB) -> None:
        def _register_db():
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
            cur = connection.cursor()
            try:
                cur.execute(
@@ -261,7 +268,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
                connection.close()

        await asyncio.to_thread(_register_db)
-        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
+        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, str(self.db_path), vector_db.identifier)
        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

    async def list_vector_dbs(self) -> list[VectorDB]:
@@ -275,7 +282,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
            del self.cache[vector_db_id]

        def _delete_vector_db_from_registry():
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
            cur = connection.cursor()
            try:
                cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))

View file

@@ -7,7 +7,7 @@
from .config import VLLMInferenceAdapterConfig

-async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
+async def get_adapter_impl(config: VLLMInferenceAdapterConfig, deps):
    from .vllm import VLLMInferenceAdapter

    assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"

View file

@@ -5,6 +5,7 @@
# the root directory of this source tree.

from datetime import datetime
+from pathlib import Path
from unittest.mock import AsyncMock

import pytest
@@ -20,6 +21,7 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.datatypes import ProviderContext
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@@ -48,7 +50,9 @@ def config(tmp_path):
@pytest_asyncio.fixture
async def agents_impl(config, mock_apis):
+    context = ProviderContext(storage_dir=Path("/tmp"))
    impl = MetaReferenceAgentsImpl(
+        context,
        config,
        mock_apis["inference_api"],
        mock_apis["vector_io_api"],

View file

@@ -63,9 +63,13 @@ class MockInferenceAdapterWithSleep:
            # ruff: noqa: N802
            def do_POST(self):
                time.sleep(sleep_time)
+                response_json = json.dumps(response).encode("utf-8")
                self.send_response(code=200)
+                self.send_header("Content-Type", "application/json")
+                self.send_header("Content-Length", str(len(response_json)))
                self.end_headers()
-                self.wfile.write(json.dumps(response).encode("utf-8"))
+                self.wfile.write(response_json)
+                self.wfile.flush()

        self.request_handler = DelayedRequestHandler

View file

@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from llama_stack.distribution.resolver import resolve_storage_dir
+
+
+class DummyConfig:
+    pass
+
+
+def test_storage_dir_cli(monkeypatch):
+    config = DummyConfig()
+    config.storage_dir = "/cli/dir"
+    monkeypatch.delenv("LLAMA_STACK_STORAGE_DIR", raising=False)
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/cli/dir")
+
+
+def test_storage_dir_env(monkeypatch):
+    config = DummyConfig()
+    if hasattr(config, "storage_dir"):
+        delattr(config, "storage_dir")
+    monkeypatch.setenv("LLAMA_STACK_STORAGE_DIR", "/env/dir")
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/env/dir")
+
+
+def test_storage_dir_fallback(monkeypatch):
+    # Mock the DISTRIBS_BASE_DIR
+    monkeypatch.setattr("llama_stack.distribution.utils.config_dirs.DISTRIBS_BASE_DIR", Path("/mock/distribs"))
+    config = DummyConfig()
+    if hasattr(config, "storage_dir"):
+        delattr(config, "storage_dir")
+    monkeypatch.delenv("LLAMA_STACK_STORAGE_DIR", raising=False)
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/mock/distribs/distro")