Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Merge e6c9aebe47 into 8e7ab146f8
This commit is contained in commit 8154dc7500.
41 changed files with 242 additions and 81 deletions
@@ -59,6 +59,11 @@ class StackRun(Subcommand):
             help="Image Type used during the build. This can be either conda or container or venv.",
             choices=[e.value for e in ImageType],
         )
+        self.parser.add_argument(
+            "--storage-directory",
+            type=str,
+            help="Directory to use for provider state (overrides environment variable and default).",
+        )
 
         # If neither image type nor image name is provided, but at the same time
         # the current environment has conda breadcrumbs, then assume what the user
@@ -118,6 +123,9 @@ class StackRun(Subcommand):
         except AttributeError as e:
             self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
 
+        # Pass the CLI storage directory option to run_config for resolver use
+        config.storage_dir = args.storage_directory
+
         image_type, image_name = self._get_image_type_and_name(args)
 
         # If neither image type nor image name is provided, assume the server should be run directly
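Taken together, the two hunks above add a --storage-directory flag and copy its value onto the run config before provider resolution. A condensed, runnable sketch of that flow (a standalone argparse stand-in for illustration, not the real StackRun command; RunConfig here is a hypothetical stand-in for StackRunConfig):

import argparse

# stand-in parser mirroring the new flag
parser = argparse.ArgumentParser()
parser.add_argument(
    "--storage-directory",
    type=str,
    help="Directory to use for provider state (overrides environment variable and default).",
)

args = parser.parse_args(["--storage-directory", "/var/lib/llama-stack"])

class RunConfig:  # hypothetical stand-in for StackRunConfig
    storage_dir: str | None = None

config = RunConfig()
# mirrors `config.storage_dir = args.storage_directory` from the hunk above
config.storage_dir = args.storage_directory
assert config.storage_dir == "/var/lib/llama-stack"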
@@ -317,6 +317,11 @@ a default SQLite store will be used.""",
         description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
     )
 
+    storage_dir: str | None = Field(
+        default=None,
+        description="Directory to use for provider state. Can be set via CLI or environment variable; defaults to the distribution directory.",
+    )
+
 
 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
@@ -24,7 +24,7 @@ class DistributionInspectConfig(BaseModel):
     run_config: StackRunConfig
 
 
-async def get_provider_impl(config, deps):
+async def get_provider_impl(context, config, deps):
     impl = DistributionInspectImpl(config, deps)
     await impl.initialize()
     return impl
@@ -23,7 +23,7 @@ class ProviderImplConfig(BaseModel):
     run_config: StackRunConfig
 
 
-async def get_provider_impl(config, deps):
+async def get_provider_impl(context, config, deps):
     impl = ProviderImpl(config, deps)
     await impl.initialize()
     return impl
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 import importlib
 import inspect
+import os
+from pathlib import Path
 from typing import Any
 
 from llama_stack.apis.agents import Agents
@@ -42,6 +44,7 @@ from llama_stack.providers.datatypes import (
     BenchmarksProtocolPrivate,
     DatasetsProtocolPrivate,
     ModelsProtocolPrivate,
+    ProviderContext,
     ProviderSpec,
     RemoteProviderConfig,
     RemoteProviderSpec,
@@ -334,7 +337,15 @@ async def instantiate_provider(
 
     config_type = instantiate_class_type(provider_spec.config_class)
     config = config_type(**provider.config)
-    args = [config, deps]
+    # Build ProviderContext for every provider
+    distro_name = (
+        dist_registry.run_config.image_name
+        if dist_registry and hasattr(dist_registry, "run_config")
+        else provider.spec.api.value
+    )
+    storage_dir = resolve_storage_dir(config, distro_name)
+    context = ProviderContext(storage_dir=storage_dir)
+    args = [context, config, deps]
 
     fn = getattr(module, method)
     impl = await fn(*args)
@@ -413,3 +424,31 @@ async def resolve_remote_stack_impls(
     )
 
     return impls
+
+
+def resolve_storage_dir(config, distro_name: str) -> Path:
+    """
+    Resolves the storage directory for a provider in the following order of precedence:
+    1. CLI option (config.storage_dir)
+    2. Environment variable (LLAMA_STACK_STORAGE_DIR)
+    3. Fallback to <DISTRIBS_BASE_DIR>/<distro_name>
+
+    Args:
+        config: Provider configuration object
+        distro_name: Distribution name used for the fallback path
+
+    Returns:
+        Path: Resolved storage directory path
+    """
+    # Import here to avoid circular imports
+    from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+
+    # 1. CLI option
+    storage_dir = getattr(config, "storage_dir", None)
+    # 2. Environment variable
+    if not storage_dir:
+        storage_dir = os.environ.get("LLAMA_STACK_STORAGE_DIR")
+    # 3. Fallback
+    if not storage_dir:
+        storage_dir = str(DISTRIBS_BASE_DIR / distro_name)
+    return Path(storage_dir)
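For reference, the precedence implemented by resolve_storage_dir can be exercised directly. A minimal sketch with hypothetical paths, assuming the function is importable as in the diff:

import os
from pathlib import Path

from llama_stack.distribution.resolver import resolve_storage_dir

class Cfg:
    """Stand-in config object; only the optional storage_dir attribute matters."""
    storage_dir = None

# 2nd precedence level: the environment variable wins when the CLI option is unset
os.environ["LLAMA_STACK_STORAGE_DIR"] = "/srv/llama-state"
assert resolve_storage_dir(Cfg(), "my-distro") == Path("/srv/llama-state")

# 1st precedence level: an explicit CLI value overrides the environment
cfg = Cfg()
cfg.storage_dir = "/opt/override"
assert resolve_storage_dir(cfg, "my-distro") == Path("/opt/override")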
@@ -18,7 +18,6 @@ from importlib.metadata import version as parse_version
 from pathlib import Path
 from typing import Annotated, Any
 
-import rich.pretty
 import yaml
 from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi import Path as FastapiPath
@@ -33,7 +32,7 @@ from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
     request_provider_data_context,
 )
-from llama_stack.distribution.resolver import InvalidProviderError
+from llama_stack.distribution.resolver import InvalidProviderError, resolve_storage_dir
 from llama_stack.distribution.server.endpoints import (
     find_matching_endpoint,
     initialize_endpoint_impls,
@@ -46,7 +45,7 @@ from llama_stack.distribution.stack import (
 from llama_stack.distribution.utils.config import redact_sensitive_fields
 from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
 from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
     TelemetryAdapter,
@@ -188,30 +187,11 @@ async def sse_generator(event_gen_coroutine):
     )
 
 
-async def log_request_pre_validation(request: Request):
-    if request.method in ("POST", "PUT", "PATCH"):
-        try:
-            body_bytes = await request.body()
-            if body_bytes:
-                try:
-                    parsed_body = json.loads(body_bytes.decode())
-                    log_output = rich.pretty.pretty_repr(parsed_body)
-                except (json.JSONDecodeError, UnicodeDecodeError):
-                    log_output = repr(body_bytes)
-                logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
-            else:
-                logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
-        except Exception as e:
-            logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
-
-
 def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
         # Get auth attributes from the request scope
         user_attributes = request.scope.get("user_attributes", {})
 
-        await log_request_pre_validation(request)
-
         # Use context manager with both provider data and auth attributes
         with request_provider_data_context(request.headers, user_attributes):
             is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@@ -442,7 +422,10 @@ def main(args: argparse.Namespace | None = None):
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
     else:
-        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
+        # Resolve storage directory using the same logic as other providers
+        storage_dir = resolve_storage_dir(config, config.image_name)
+        context = ProviderContext(storage_dir=storage_dir)
+        setup_logger(TelemetryAdapter(context, TelemetryConfig(), {}))
 
     all_endpoints = get_all_api_endpoints()
 
@@ -4,7 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from dataclasses import dataclass
 from enum import Enum
+from pathlib import Path
 from typing import Any, Protocol
 from urllib.parse import urlparse
 
@@ -161,7 +163,7 @@ If a provider depends on other providers, the dependencies MUST NOT specify a co
         description="""
 Fully-qualified name of the module to import. The module is expected to have:
 
-- `get_provider_impl(config, deps)`: returns the local implementation
+- `get_provider_impl(context, config, deps)`: returns the local implementation
 """,
     )
     provider_data_validator: str | None = Field(
@@ -232,3 +234,19 @@ class HealthStatus(str, Enum):
 
 
 HealthResponse = dict[str, Any]
+
+
+@dataclass
+class ProviderContext:
+    """
+    Runtime context for provider instantiation.
+
+    This object is constructed by the Llama Stack runtime and injected into every provider.
+    It contains environment- and deployment-specific information that should not be part of the static config file.
+
+    Attributes:
+        storage_dir (Path): Directory for provider state (persistent or ephemeral),
+            resolved from CLI option, environment variable, or default distribution directory.
+    """
+
+    storage_dir: Path
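The dataclass gives every provider one typed handle on its state directory. A minimal consumer sketch (the provider class here is hypothetical; only ProviderContext comes from the diff):

from pathlib import Path

from llama_stack.providers.datatypes import ProviderContext

class ExampleProviderImpl:
    """Hypothetical provider showing the injected-context convention."""

    def __init__(self, context: ProviderContext, config, deps) -> None:
        self.context = context
        # keep all provider state under the resolved directory
        self.state_file = context.storage_dir / "state.json"

impl = ExampleProviderImpl(ProviderContext(storage_dir=Path("/tmp/my-distro")), config=None, deps={})
print(impl.state_file)  # /tmp/my-distro/state.json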
@@ -7,14 +7,16 @@
 from typing import Any
 
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 
 from .config import MetaReferenceAgentsImplConfig
 
 
-async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
     from .agents import MetaReferenceAgentsImpl
 
     impl = MetaReferenceAgentsImpl(
+        context,
         config,
         deps[Api.inference],
         deps[Api.vector_io],
@@ -37,6 +37,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 
@@ -51,6 +52,7 @@ logger = logging.getLogger()
 class MetaReferenceAgentsImpl(Agents):
     def __init__(
         self,
+        context: ProviderContext,
         config: MetaReferenceAgentsImplConfig,
         inference_api: Inference,
         vector_io_api: VectorIO,
@@ -58,6 +60,7 @@ class MetaReferenceAgentsImpl(Agents):
         tool_runtime_api: ToolRuntime,
         tool_groups_api: ToolGroups,
     ):
+        self.context = context
         self.config = config
         self.inference_api = inference_api
         self.vector_io_api = vector_io_api
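The remaining inline providers below receive the same mechanical treatment: get_provider_impl gains a leading context parameter, forwards it to the implementation class, and the class stores it as self.context. A generic sketch of the shape (placeholder implementation class, not any specific provider):

from typing import Any

from llama_stack.providers.datatypes import ProviderContext

class PlaceholderImpl:
    """Hypothetical provider implementation following the new convention."""

    def __init__(self, context: ProviderContext, config: Any, deps: dict[Any, Any]) -> None:
        self.context = context
        self.config = config
        self.deps = deps

    async def initialize(self) -> None:
        pass

async def get_provider_impl(context: ProviderContext, config: Any, deps: dict[Any, Any]):
    impl = PlaceholderImpl(context, config, deps)
    await impl.initialize()
    return impl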
@@ -6,15 +6,18 @@
 
 from typing import Any
 
+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import LocalFSDatasetIOConfig
 
 
 async def get_provider_impl(
+    context: ProviderContext,
     config: LocalFSDatasetIOConfig,
     _deps: dict[str, Any],
 ):
     from .datasetio import LocalFSDatasetIOImpl
 
-    impl = LocalFSDatasetIOImpl(config)
+    impl = LocalFSDatasetIOImpl(context, config)
     await impl.initialize()
     return impl
@@ -10,7 +10,7 @@ import pandas
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset
-from llama_stack.providers.datatypes import DatasetsProtocolPrivate
+from llama_stack.providers.datatypes import DatasetsProtocolPrivate, ProviderContext
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
@@ -53,7 +53,8 @@ class PandasDataframeDataset:
 
 
 class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
-    def __init__(self, config: LocalFSDatasetIOConfig) -> None:
+    def __init__(self, context: ProviderContext, config: LocalFSDatasetIOConfig) -> None:
+        self.context = context
         self.config = config
         # local registry for keeping track of datasets within the provider
         self.dataset_infos = {}
@@ -6,11 +6,13 @@
 from typing import Any
 
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 
 from .config import MetaReferenceEvalConfig
 
 
 async def get_provider_impl(
+    context: ProviderContext,
     config: MetaReferenceEvalConfig,
     deps: dict[Api, Any],
 ):
@@ -6,15 +6,14 @@
 
 from typing import Any
 
+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import MetaReferenceInferenceConfig
 
 
-async def get_provider_impl(
-    config: MetaReferenceInferenceConfig,
-    _deps: dict[str, Any],
-):
+async def get_provider_impl(context: ProviderContext, config: MetaReferenceInferenceConfig, _deps: dict[str, Any]):
     from .inference import MetaReferenceInferenceImpl
 
-    impl = MetaReferenceInferenceImpl(config)
+    impl = MetaReferenceInferenceImpl(context, config)
     await impl.initialize()
     return impl
@@ -50,7 +50,7 @@ from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4Chat
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import ModelFamily
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate, ProviderContext
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
@@ -89,7 +89,8 @@ class MetaReferenceInferenceImpl(
     Inference,
     ModelsProtocolPrivate,
 ):
-    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
+    def __init__(self, context: ProviderContext, config: MetaReferenceInferenceConfig) -> None:
+        self.context = context
         self.config = config
         self.model_id = None
         self.llama_model = None
@@ -6,12 +6,14 @@
 
 from typing import Any
 
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.inline.inference.sentence_transformers.config import (
     SentenceTransformersInferenceConfig,
 )
 
 
 async def get_provider_impl(
+    context: ProviderContext,
     config: SentenceTransformersInferenceConfig,
     _deps: dict[str, Any],
 ):
@@ -6,12 +6,14 @@
 
 from typing import Any
 
+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import VLLMConfig
 
 
-async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: VLLMConfig, deps: dict[str, Any]):
     from .vllm import VLLMInferenceImpl
 
-    impl = VLLMInferenceImpl(config)
+    impl = VLLMInferenceImpl(context, config)
     await impl.initialize()
     return impl
@@ -7,6 +7,7 @@
 from typing import Any
 
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 
 from .config import TorchtunePostTrainingConfig
 
@@ -14,12 +15,14 @@ from .config import TorchtunePostTrainingConfig
 
 
 async def get_provider_impl(
+    context: ProviderContext,
     config: TorchtunePostTrainingConfig,
     deps: dict[Api, Any],
 ):
     from .post_training import TorchtunePostTrainingImpl
 
     impl = TorchtunePostTrainingImpl(
+        context,
         config,
         deps[Api.datasetio],
         deps[Api.datasets],
@@ -20,6 +20,7 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
@@ -42,10 +43,12 @@ _JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
 class TorchtunePostTrainingImpl:
     def __init__(
         self,
+        context: ProviderContext,
         config: TorchtunePostTrainingConfig,
         datasetio_api: DatasetIO,
         datasets: Datasets,
     ) -> None:
+        self.context = context
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets
@@ -6,12 +6,14 @@
 
 from typing import Any
 
+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import CodeScannerConfig
 
 
-async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: CodeScannerConfig, deps: dict[str, Any]):
     from .code_scanner import MetaReferenceCodeScannerSafetyImpl
 
-    impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
+    impl = MetaReferenceCodeScannerSafetyImpl(context, config, deps)
     await impl.initialize()
     return impl
@@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -30,8 +31,10 @@ ALLOWED_CODE_SCANNER_MODEL_IDS = [
 
 
 class MetaReferenceCodeScannerSafetyImpl(Safety):
-    def __init__(self, config: CodeScannerConfig, deps) -> None:
+    def __init__(self, context: ProviderContext, config: CodeScannerConfig, deps) -> None:
+        self.context = context
         self.config = config
+        self.deps = deps
 
     async def initialize(self) -> None:
         pass
@@ -6,14 +6,16 @@
 
 from typing import Any
 
+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import LlamaGuardConfig
 
 
-async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: LlamaGuardConfig, deps: dict[str, Any]):
     from .llama_guard import LlamaGuardSafetyImpl
 
     assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
 
-    impl = LlamaGuardSafetyImpl(config, deps)
+    impl = LlamaGuardSafetyImpl(context, config, deps)
     await impl.initialize()
     return impl
@@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
 from llama_stack.models.llama.datatypes import Role
 from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -130,7 +130,8 @@ PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATIO
 
 
 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: LlamaGuardConfig, deps) -> None:
+    def __init__(self, context: ProviderContext, config: LlamaGuardConfig, deps) -> None:
+        self.context = context
         self.config = config
         self.inference_api = deps[Api.inference]
 
@@ -6,12 +6,14 @@
 
 from typing import Any
 
+from llama_stack.providers.datatypes import ProviderContext
+
 from .config import PromptGuardConfig
 
 
-async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]):
+async def get_provider_impl(context: ProviderContext, config: PromptGuardConfig, deps: dict[str, Any]):
     from .prompt_guard import PromptGuardSafetyImpl
 
-    impl = PromptGuardSafetyImpl(config, deps)
+    impl = PromptGuardSafetyImpl(context, config, deps)
     await impl.initialize()
     return impl
@@ -19,7 +19,7 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
@@ -32,8 +32,10 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 
 
 class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
-    def __init__(self, config: PromptGuardConfig, _deps) -> None:
+    def __init__(self, context: ProviderContext, config: PromptGuardConfig, _deps) -> None:
+        self.context = context
         self.config = config
+        self.deps = _deps
 
     async def initialize(self) -> None:
         model_dir = model_local_dir(PROMPT_GUARD_MODEL)
@@ -6,17 +6,20 @@
 from typing import Any
 
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 
 from .config import BasicScoringConfig
 
 
 async def get_provider_impl(
+    context: ProviderContext,
     config: BasicScoringConfig,
     deps: dict[Api, Any],
 ):
     from .scoring import BasicScoringImpl
 
     impl = BasicScoringImpl(
+        context,
         config,
         deps[Api.datasetio],
         deps[Api.datasets],
@@ -15,7 +15,7 @@ from llama_stack.apis.scoring import (
 )
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.distribution.datatypes import Api
-from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
     get_valid_schemas,
     validate_dataset_schema,
@@ -49,10 +49,12 @@ class BasicScoringImpl(
 ):
     def __init__(
         self,
+        context: ProviderContext,
         config: BasicScoringConfig,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        self.context = context
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import BaseModel
 
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 
 from .config import BraintrustScoringConfig
 
@@ -17,6 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
 
 
 async def get_provider_impl(
+    context: ProviderContext,
     config: BraintrustScoringConfig,
     deps: dict[Api, Any],
 ):
@@ -6,11 +6,13 @@
 from typing import Any
 
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 
 from .config import LlmAsJudgeScoringConfig
 
 
 async def get_provider_impl(
+    context: ProviderContext,
     config: LlmAsJudgeScoringConfig,
     deps: dict[Api, Any],
 ):
@@ -7,15 +7,16 @@
 from typing import Any
 
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 
 from .config import TelemetryConfig, TelemetrySink
 
 __all__ = ["TelemetryConfig", "TelemetrySink"]
 
 
-async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: TelemetryConfig, deps: dict[Api, Any]):
     from .telemetry import TelemetryAdapter
 
-    impl = TelemetryAdapter(config, deps)
+    impl = TelemetryAdapter(context, config, deps)
     await impl.initialize()
     return impl
@@ -36,6 +36,7 @@ from llama_stack.apis.telemetry import (
     UnstructuredLogEvent,
 )
 from llama_stack.distribution.datatypes import Api
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )
@@ -45,7 +46,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
 from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
 from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
 
-from .config import TelemetryConfig, TelemetrySink
+from .config import TelemetrySink
 
 _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "active_spans": {},
@@ -63,8 +64,10 @@ def is_tracing_enabled(tracer):
 
 
 class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
-    def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
+    def __init__(self, context: ProviderContext, config, deps):
+        self.context = context
         self.config = config
+        self.deps = deps
         self.datasetio_api = deps.get(Api.datasetio)
         self.meter = None
 
@@ -6,12 +6,12 @@
 
 from typing import Any
 
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext
 
 from .config import RagToolRuntimeConfig
 
 
-async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: RagToolRuntimeConfig, deps: dict[Api, Any]):
     from .memory import MemoryToolRuntimeImpl
 
     impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
@@ -6,16 +6,17 @@
 
 from typing import Any
 
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext
 
 from .config import ChromaVectorIOConfig
 
 
-async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: ChromaVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.chroma.chroma import (
         ChromaVectorIOAdapter,
     )
 
+    # Pass config directly since ChromaVectorIOAdapter doesn't accept context
     impl = ChromaVectorIOAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -6,16 +6,16 @@
 
 from typing import Any
 
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext
 
 from .config import FaissVectorIOConfig
 
 
-async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: FaissVectorIOConfig, deps: dict[Api, Any]):
     from .faiss import FaissVectorIOAdapter
 
     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
 
-    impl = FaissVectorIOAdapter(config, deps[Api.inference])
+    impl = FaissVectorIOAdapter(context, config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -19,7 +19,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
@@ -114,9 +114,11 @@ class FaissIndex(EmbeddingIndex):
 
 
 class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
+    def __init__(self, context: ProviderContext, config: FaissVectorIOConfig, inference_api: Inference) -> None:
+        self.context = context
         self.config = config
         self.inference_api = inference_api
+        self.storage_dir = context.storage_dir if context else None
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.kvstore: KVStore | None = None
 
@@ -6,14 +6,15 @@
 
 from typing import Any
 
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext
 
 from .config import MilvusVectorIOConfig
 
 
-async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: MilvusVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
 
+    # Pass config directly since MilvusVectorIOAdapter doesn't accept context
     impl = MilvusVectorIOAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -6,15 +6,15 @@
 
 from typing import Any
 
-from llama_stack.providers.datatypes import Api
+from llama_stack.providers.datatypes import Api, ProviderContext
 
 from .config import SQLiteVectorIOConfig
 
 
-async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(context: ProviderContext, config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter
 
     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
-    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference])
+    impl = SQLiteVecVectorIOAdapter(context, config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -10,6 +10,7 @@ import logging
 import sqlite3
 import struct
 import uuid
+from pathlib import Path
 from typing import Any
 
 import numpy as np
@@ -19,7 +20,7 @@ from numpy.typing import NDArray
 from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import ProviderContext, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
 
 logger = logging.getLogger(__name__)
@@ -206,15 +207,23 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
     """
 
-    def __init__(self, config, inference_api: Inference) -> None:
+    def __init__(self, context: ProviderContext, config, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
         self.cache: dict[str, VectorDBWithIndex] = {}
+        self.storage_dir = context.storage_dir
+        self.db_path = self._resolve_path(self.config.db_path)
+
+    def _resolve_path(self, path: str | Path) -> Path:
+        path = Path(path)
+        if path.is_absolute():
+            return path
+        return self.storage_dir / path
 
     async def initialize(self) -> None:
         def _setup_connection():
             # Open a connection to the SQLite database (the file is specified in the config).
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()
             try:
                 # Create a table to persist vector DB registrations.
@@ -237,9 +246,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
             for row in rows:
                 vector_db_data = row[0]
                 vector_db = VectorDB.model_validate_json(vector_db_data)
-                index = await SQLiteVecIndex.create(
-                    vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
-                )
+                index = await SQLiteVecIndex.create(vector_db.embedding_dimension, str(self.db_path), vector_db.identifier)
                 self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
 
     async def shutdown(self) -> None:
@@ -248,7 +255,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
 
     async def register_vector_db(self, vector_db: VectorDB) -> None:
         def _register_db():
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()
             try:
                 cur.execute(
@@ -261,7 +268,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
             connection.close()
 
         await asyncio.to_thread(_register_db)
-        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
+        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, str(self.db_path), vector_db.identifier)
         self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
 
     async def list_vector_dbs(self) -> list[VectorDB]:
@@ -275,7 +282,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         del self.cache[vector_db_id]
 
         def _delete_vector_db_from_registry():
-            connection = _create_sqlite_connection(self.config.db_path)
+            connection = _create_sqlite_connection(self.db_path)
             cur = connection.cursor()
             try:
                 cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
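The _resolve_path helper above anchors relative db paths under the resolved storage directory while leaving absolute paths untouched. A standalone behavior sketch (same logic copied out of the adapter; all paths hypothetical):

from pathlib import Path

def resolve_path(storage_dir: Path, path: str | Path) -> Path:
    # mirrors SQLiteVecVectorIOAdapter._resolve_path
    path = Path(path)
    if path.is_absolute():
        return path
    return storage_dir / path

storage = Path("/var/lib/llama-stack/my-distro")
assert resolve_path(storage, "sqlite_vec.db") == storage / "sqlite_vec.db"
assert resolve_path(storage, "/abs/override.db") == Path("/abs/override.db")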
@@ -7,7 +7,7 @@
 from .config import VLLMInferenceAdapterConfig
 
 
-async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
+async def get_adapter_impl(config: VLLMInferenceAdapterConfig, deps):
     from .vllm import VLLMInferenceAdapter
 
     assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 from datetime import datetime
+from pathlib import Path
 from unittest.mock import AsyncMock
 
 import pytest
@@ -20,6 +21,7 @@ from llama_stack.apis.inference import Inference
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.datatypes import ProviderContext
 from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
 from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
 from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@@ -48,7 +50,9 @@ def config(tmp_path):
 
 @pytest_asyncio.fixture
 async def agents_impl(config, mock_apis):
+    context = ProviderContext(storage_dir=Path("/tmp"))
     impl = MetaReferenceAgentsImpl(
+        context,
         config,
         mock_apis["inference_api"],
         mock_apis["vector_io_api"],
@@ -63,9 +63,13 @@ class MockInferenceAdapterWithSleep:
             # ruff: noqa: N802
             def do_POST(self):
                 time.sleep(sleep_time)
+                response_json = json.dumps(response).encode("utf-8")
                 self.send_response(code=200)
+                self.send_header("Content-Type", "application/json")
+                self.send_header("Content-Length", str(len(response_json)))
                 self.end_headers()
-                self.wfile.write(json.dumps(response).encode("utf-8"))
+                self.wfile.write(response_json)
+                self.wfile.flush()
 
         self.request_handler = DelayedRequestHandler
 
tests/unit/test_state_dir_resolution.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from llama_stack.distribution.resolver import resolve_storage_dir
+
+
+class DummyConfig:
+    pass
+
+
+def test_storage_dir_cli(monkeypatch):
+    config = DummyConfig()
+    config.storage_dir = "/cli/dir"
+    monkeypatch.delenv("LLAMA_STACK_STORAGE_DIR", raising=False)
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/cli/dir")
+
+
+def test_storage_dir_env(monkeypatch):
+    config = DummyConfig()
+    if hasattr(config, "storage_dir"):
+        delattr(config, "storage_dir")
+    monkeypatch.setenv("LLAMA_STACK_STORAGE_DIR", "/env/dir")
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/env/dir")
+
+
+def test_storage_dir_fallback(monkeypatch):
+    # Mock the DISTRIBS_BASE_DIR
+    monkeypatch.setattr("llama_stack.distribution.utils.config_dirs.DISTRIBS_BASE_DIR", Path("/mock/distribs"))
+
+    config = DummyConfig()
+    if hasattr(config, "storage_dir"):
+        delattr(config, "storage_dir")
+    monkeypatch.delenv("LLAMA_STACK_STORAGE_DIR", raising=False)
+
+    result = resolve_storage_dir(config, "distro")
+    assert result == Path("/mock/distribs/distro")
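The new tests are self-contained (they monkeypatch the environment variable and DISTRIBS_BASE_DIR), so they can be run in isolation. A sketch, assuming pytest is installed and the file path above:

import pytest

raise SystemExit(pytest.main(["tests/unit/test_state_dir_resolution.py", "-v"]))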