Merge branch 'main' into allow-dynamic-models-ollama

This commit is contained in:
Matthew Farrellee 2025-07-28 14:16:31 -04:00
commit 56476fa462
247 changed files with 9176 additions and 7177 deletions

View file

@ -4,15 +4,83 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import Enum, EnumMeta
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class DynamicApiMeta(EnumMeta):
def __new__(cls, name, bases, namespace):
# Store the original enum values
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
# Create the enum class
cls = super().__new__(cls, name, bases, namespace)
# Store the original values for reference
cls._original_values = original_values
# Initialize _dynamic_values
cls._dynamic_values = {}
return cls
def __call__(cls, value):
try:
return super().__call__(value)
except ValueError as e:
# If this value was already dynamically added, return it
if value in cls._dynamic_values:
return cls._dynamic_values[value]
# If the value doesn't exist, create a new enum member
# Create a new member name from the value
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return the existing member
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
# register new APIs explicitly
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
def __iter__(cls):
# Allow iteration over both static and dynamic members
yield from super().__iter__()
if hasattr(cls, "_dynamic_values"):
yield from cls._dynamic_values.values()
def add(cls, value):
"""
Add a new API to the enum.
Used to register external APIs.
"""
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return it
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Create a new enum member
member = object.__new__(cls)
member._name_ = member_name
member._value_ = value
# Add it to the enum class
cls._member_map_[member_name] = member
cls._member_names_.append(member_name)
cls._member_type_ = str
# Store it in our dynamic values
cls._dynamic_values[value] = member
return member
@json_schema_type
class Api(Enum):
class Api(Enum, metaclass=DynamicApiMeta):
providers = "providers"
inference = "inference"
safety = "safety"
@ -54,3 +122,12 @@ class Error(BaseModel):
title: str
detail: str
instance: str | None = None
class ExternalApiSpec(BaseModel):
"""Specification for an external API implementation."""
module: str = Field(..., description="Python module containing the API implementation")
name: str = Field(..., description="Name of the API")
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
protocol: str = Field(..., description="Name of the protocol class for the API")
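
A minimal sketch (illustrative, not part of this change) of how the pieces above fit together: an ExternalApiSpec describes an out-of-tree API, and Api.add() registers it as a new enum member at runtime. The module, API name, and package values below are hypothetical.

from llama_stack.apis.datatypes import Api, ExternalApiSpec

spec = ExternalApiSpec(
    module="weather_api",              # hypothetical package implementing the API
    name="weather",
    pip_packages=["weather-api"],      # hypothetical pip package
    protocol="Weather",
)
api = Api.add(spec.name)               # creates Api.weather with value "weather"
assert Api("weather") is api           # DynamicApiMeta.__call__ resolves dynamically added members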

View file

@ -464,6 +464,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
@json_schema_type
class OpenAIUserMessageParam(BaseModel):
@ -489,7 +491,7 @@ class OpenAISystemMessageParam(BaseModel):
"""
role: Literal["system"] = "system"
content: OpenAIChatCompletionMessageContent
content: OpenAIChatCompletionTextOnlyMessageContent
name: str | None = None
@ -518,7 +520,7 @@ class OpenAIAssistantMessageParam(BaseModel):
"""
role: Literal["assistant"] = "assistant"
content: OpenAIChatCompletionMessageContent | None = None
content: OpenAIChatCompletionTextOnlyMessageContent | None = None
name: str | None = None
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
@ -534,7 +536,7 @@ class OpenAIToolMessageParam(BaseModel):
role: Literal["tool"] = "tool"
tool_call_id: str
content: OpenAIChatCompletionMessageContent
content: OpenAIChatCompletionTextOnlyMessageContent
@json_schema_type
@ -547,7 +549,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
"""
role: Literal["developer"] = "developer"
content: OpenAIChatCompletionMessageContent
content: OpenAIChatCompletionTextOnlyMessageContent
name: str | None = None
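
Illustrative only, assuming these Pydantic models are importable from llama_stack.apis.inference: with the narrowed OpenAIChatCompletionTextOnlyMessageContent type above, system, tool, and developer messages accept only text content, while user messages keep the broader content type.

from llama_stack.apis.inference import OpenAISystemMessageParam, OpenAIUserMessageParam

system = OpenAISystemMessageParam(content="You are terse.")              # plain text is still valid
user = OpenAIUserMessageParam(content="Summarize the attached report.")  # user content stays broad
# Passing image content parts in the system message would now fail validation.
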
@ -819,12 +821,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ...
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None: ...
class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.

View file

@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
# Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7
REQUIRED_SCOPE = "telemetry.read"
@json_schema_type
class SpanStatus(Enum):
@ -259,7 +261,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces", method="POST")
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
async def query_traces(
self,
attribute_filters: list[QueryCondition] | None = None,
@ -277,7 +279,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
@ -286,7 +288,9 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
@webmethod(
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
)
async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
@ -296,7 +300,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
async def get_span_tree(
self,
span_id: str,
@ -312,7 +316,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans", method="POST")
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
async def query_spans(
self,
attribute_filters: list[QueryCondition],
@ -345,7 +349,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
async def query_metrics(
self,
metric_name: str,
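
A generic illustration (not the actual @webmethod implementation) of the pattern used above: the decorator attaches a required scope as metadata on the handler, and the serving layer can check it against the caller's token before dispatch.

from collections.abc import Callable

def requires_scope(scope: str) -> Callable:
    def decorator(fn: Callable) -> Callable:
        fn._required_scope = scope   # metadata read by the routing layer
        return fn
    return decorator

@requires_scope("telemetry.read")
async def query_traces() -> list:
    return []

assert query_traces._required_scope == "telemetry.read"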

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field, field_validator
@ -88,7 +88,7 @@ class RAGQueryGenerator(Enum):
@json_schema_type
class RAGSearchMode(Enum):
class RAGSearchMode(StrEnum):
"""
Search modes for RAG query retrieval:
- VECTOR: Uses vector similarity search for semantic matching
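
Illustrative only: switching to StrEnum means RAGSearchMode members are also plain strings, so values read from YAML configs or request parameters compare directly against the enum. A stand-in class is used here instead of importing RAGSearchMode.

from enum import StrEnum

class SearchMode(StrEnum):   # stand-in for RAGSearchMode above
    VECTOR = "vector"
    KEYWORD = "keyword"

assert SearchMode.VECTOR == "vector"                   # holds for StrEnum, not for a plain Enum
assert f"mode={SearchMode.KEYWORD}" == "mode=keyword"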

View file

@ -34,6 +34,7 @@ class VectorDBInput(BaseModel):
vector_db_id: str
embedding_model: str
embedding_dimension: int
provider_id: str | None = None
provider_vector_db_id: str | None = None

View file

@ -338,7 +338,7 @@ class VectorIO(Protocol):
@webmethod(route="/openai/v1/vector_stores", method="POST")
async def openai_create_vector_store(
self,
name: str,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,

View file

@ -31,11 +31,13 @@ from llama_stack.distribution.build import (
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import (
BuildConfig,
BuildProvider,
DistributionSpec,
Provider,
StackRunConfig,
)
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
@ -93,7 +95,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
elif args.providers:
providers_list: dict[str, str | list[str]] = dict()
provider_list: dict[str, list[BuildProvider]] = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
@ -102,7 +104,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
file=sys.stderr,
)
sys.exit(1)
api, provider = api_provider.split("=")
api, provider_type = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None:
cprint(
@ -111,16 +113,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
file=sys.stderr,
)
sys.exit(1)
if provider in providers_for_api:
if api not in providers_list:
providers_list[api] = []
# Use type guarding to ensure we have a list
provider_value = providers_list[api]
if isinstance(provider_value, list):
provider_value.append(provider)
else:
# Convert string to list and append
providers_list[api] = [provider_value, provider]
if provider_type in providers_for_api:
provider = BuildProvider(
provider_type=provider_type,
module=None,
)
provider_list.setdefault(api, []).append(provider)
else:
cprint(
f"{provider} is not a valid provider for the {api} API.",
@ -129,7 +127,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
distribution_spec = DistributionSpec(
providers=providers_list,
providers=provider_list,
description=",".join(args.providers),
)
if not args.image_type:
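
A condensed, standalone sketch of the --providers parsing above, with BuildProvider stubbed as a dict; the input string is only an example:

providers_arg = "inference=remote::ollama,safety=inline::llama-guard"
provider_list: dict[str, list[dict]] = {}
for api_provider in providers_arg.split(","):
    api, provider_type = api_provider.split("=")
    provider_list.setdefault(api, []).append(
        {"provider_type": provider_type, "module": None}   # stands in for BuildProvider(...)
    )
# {'inference': [{'provider_type': 'remote::ollama', 'module': None}],
#  'safety': [{'provider_type': 'inline::llama-guard', 'module': None}]}
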
@ -190,7 +188,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
providers: dict[str, str | list[str]] = dict()
providers: dict[str, list[BuildProvider]] = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
if not available_providers:
@ -205,7 +203,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
),
)
providers[api.value] = api_provider
string_providers = api_provider.split(" ")
for provider in string_providers:
providers.setdefault(api.value, []).append(BuildProvider(provider_type=provider))
description = prompt(
"\n > (Optional) Enter a short description for your Llama Stack: ",
@ -236,11 +237,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}")
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
print(f"uv pip install {' '.join(normal_deps)}")
for special_dep in special_deps:
print(f"uv pip install {special_dep}")
for external_dep in external_provider_dependencies:
print(f"uv pip install {external_dep}")
return
try:
@ -276,8 +279,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
config = parse_and_maybe_upgrade_config(config_dict)
if config.external_providers_dir and not config.external_providers_dir.exists():
config.external_providers_dir.mkdir(exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
run_args = formulate_run_args(args.image_type, args.image_name)
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)])
run_command(run_args)
@ -303,27 +306,25 @@ def _generate_run_config(
provider_registry = get_provider_registry(build_config)
for api in apis:
run_config.providers[api] = []
provider_types = build_config.distribution_spec.providers[api]
if isinstance(provider_types, str):
provider_types = [provider_types]
providers = build_config.distribution_spec.providers[api]
for i, provider_type in enumerate(provider_types):
pid = provider_type.split("::")[-1]
for provider in providers:
pid = provider.provider_type.split("::")[-1]
p = provider_registry[Api(api)][provider_type]
p = provider_registry[Api(api)][provider.provider_type]
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
try:
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
except ModuleNotFoundError:
config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class)
except (ModuleNotFoundError, ValueError) as exc:
# HACK ALERT:
# This code executes after building is done, the import cannot work since the
# package is either available in the venv or container - not available on the host.
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is
# external
cprint(
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
color="yellow",
file=sys.stderr,
)
@ -336,9 +337,10 @@ def _generate_run_config(
config = {}
p_spec = Provider(
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
provider_type=provider_type,
provider_id=pid,
provider_type=provider.provider_type,
config=config,
module=provider.module,
)
run_config.providers[api].append(p_spec)
@ -401,9 +403,32 @@ def _run_stack_build_command_from_build_config(
run_config_file = _generate_run_config(build_config, build_dir, image_name)
with open(build_file_path, "w") as f:
to_write = json.loads(build_config.model_dump_json())
to_write = json.loads(build_config.model_dump_json(exclude_none=True))
f.write(yaml.dump(to_write, sort_keys=False))
# We first install the external APIs so that the build process can use them and discover the
# providers dependencies
if build_config.external_apis_dir:
cprint("Installing external APIs", color="yellow", file=sys.stderr)
external_apis = load_external_apis(build_config)
if external_apis:
# install the external APIs
packages = []
for _, api_spec in external_apis.items():
if api_spec.pip_packages:
packages.extend(api_spec.pip_packages)
cprint(
f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}",
color="yellow",
file=sys.stderr,
)
return_code = run_command(["uv", "pip", "install", *packages])
if return_code != 0:
packages_str = ", ".join(packages)
raise RuntimeError(
f"Failed to install external APIs packages: {packages_str} (return code: {return_code})"
)
return_code = build_image(
build_config,
build_file_path,

View file

@ -82,39 +82,6 @@ class StackRun(Subcommand):
return ImageType.CONDA.value, args.image_name
return args.image_type, args.image_name
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
"""Resolve config file path and template name from args.config"""
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
if not args.config:
return None, None
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
return config_file, template_name
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
@ -125,8 +92,15 @@ class StackRun(Subcommand):
self._start_ui_development_server(args.port)
image_type, image_name = self._get_image_type_and_name(args)
# Resolve config file and template name first
config_file, template_name = self._resolve_config_and_template(args)
if args.config:
try:
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
config_file = resolve_config_or_template(args.config, Mode.RUN)
except ValueError as e:
self.parser.error(str(e))
else:
config_file = None
# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
@ -164,18 +138,14 @@ class StackRun(Subcommand):
if callable(getattr(args, arg)):
continue
if arg == "config":
if template_name:
server_args.template = str(template_name)
else:
# Set the config file path
server_args.config = str(config_file)
server_args.config = str(config_file)
else:
setattr(server_args, arg, getattr(args, arg))
# Run the server
server_main(server_args)
else:
run_args = formulate_run_args(image_type, image_name, config, template_name)
run_args = formulate_run_args(image_type, image_name)
run_args.extend([str(args.port)])
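
A small sketch of the resolution helper used earlier in this command; "starter" stands in for any template name or run.yaml path:

from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template

config_file = resolve_config_or_template("starter", Mode.RUN)   # template name or YAML path -> run config path
print(config_file)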

llama_stack/cli/utils.py (new file, 48 lines)
View file

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="cli")
def add_config_template_args(parser: argparse.ArgumentParser):
"""Add unified config/template arguments with backward compatibility."""
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"config",
nargs="?",
help="Configuration file path or template name",
)
# Backward compatibility arguments (deprecated)
group.add_argument(
"--config",
dest="config_deprecated",
help="(DEPRECATED) Use positional argument [config] instead. Configuration file path",
)
group.add_argument(
"--template",
dest="template_deprecated",
help="(DEPRECATED) Use positional argument [config] instead. Template name",
)
def get_config_from_args(args: argparse.Namespace) -> str | None:
"""Extract config value from parsed arguments, handling both new and deprecated forms."""
if args.config is not None:
return str(args.config)
elif hasattr(args, "config_deprecated") and args.config_deprecated is not None:
logger.warning("Using deprecated --config argument. Use positional argument [config] instead.")
return str(args.config_deprecated)
elif hasattr(args, "template_deprecated") and args.template_deprecated is not None:
logger.warning("Using deprecated --template argument. Use positional argument [config] instead.")
return str(args.template_deprecated)
return None
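
Minimal usage sketch (illustrative) for the two helpers above; "starter" stands in for any template name or config path:

import argparse

from llama_stack.cli.utils import add_config_template_args, get_config_from_args

parser = argparse.ArgumentParser(prog="llama stack run")
add_config_template_args(parser)
args = parser.parse_args(["starter"])     # new positional form
print(get_config_from_args(args))         # -> "starter"
# parser.parse_args(["--config", "starter"]) still works but logs a deprecation warning.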

View file

@ -14,6 +14,7 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
@ -41,7 +42,7 @@ class ApiInput(BaseModel):
def get_provider_dependencies(
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
) -> tuple[list[str], list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
if isinstance(config, DistributionTemplate):
config = config.build_config()
@ -50,6 +51,7 @@ def get_provider_dependencies(
additional_pip_packages = config.additional_pip_packages
deps = []
external_provider_deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)]
@ -64,8 +66,16 @@ def get_provider_dependencies(
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
provider_spec = providers_for_api[provider_type]
deps.extend(provider_spec.pip_packages)
if provider_spec.container_image:
if hasattr(provider_spec, "is_external") and provider_spec.is_external:
# this ensures we install the top level module for our external providers
if provider_spec.module:
if isinstance(provider_spec.module, str):
external_provider_deps.append(provider_spec.module)
else:
external_provider_deps.extend(provider_spec.module)
if hasattr(provider_spec, "pip_packages"):
deps.extend(provider_spec.pip_packages)
if hasattr(provider_spec, "container_image") and provider_spec.container_image:
raise ValueError("A stack's dependencies cannot have a container image")
normal_deps = []
@ -78,7 +88,7 @@ def get_provider_dependencies(
normal_deps.extend(additional_pip_packages or [])
return list(set(normal_deps)), list(set(special_deps))
return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
def print_pip_install_help(config: BuildConfig):
@ -103,41 +113,59 @@ def build_image(
):
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
args = [
script,
"--template-or-config",
template_or_config,
"--image-name",
image_name,
"--container-base",
container_base,
"--normal-deps",
" ".join(normal_deps),
]
# When building from a config file (not a template), include the run config path in the
# build arguments
if run_config is not None:
args.append(run_config)
args.extend(["--run-config", run_config])
elif build_config.image_type == LlamaStackImageType.CONDA.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [
script,
"--env-name",
str(image_name),
"--build-file-path",
str(build_file_path),
"--normal-deps",
" ".join(normal_deps),
]
elif build_config.image_type == LlamaStackImageType.VENV.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
args = [
script,
"--env-name",
str(image_name),
"--normal-deps",
" ".join(normal_deps),
]
# Optional and external dependency groups are '#'-joined and passed as flags only when present
if special_deps:
args.append("#".join(special_deps))
args.extend(["--optional-deps", "#".join(special_deps)])
if external_provider_deps:
args.extend(
["--external-provider-deps", "#".join(external_provider_deps)]
) # the script will install external provider module, get its deps, and install those too.
return_code = run_command(args)

View file

@ -9,10 +9,91 @@
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Usage function
usage() {
echo "Usage: $0 --env-name <conda_env_name> --build-file-path <build_file_path> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
env_name=""
build_file_path=""
normal_deps=""
external_provider_deps=""
optional_deps=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--env-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --env-name requires a string value" >&2
usage
fi
env_name="$2"
shift 2
;;
--build-file-path)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --build-file-path requires a string value" >&2
usage
fi
build_file_path="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then
echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2
usage
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
@ -20,50 +101,18 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <distribution_type> <conda_env_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 <distribution_type> my-conda-env ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
special_pip_deps="$4"
set -euo pipefail
env_name="$1"
build_file_path="$2"
pip_dependencies="$3"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
# this is set if we actually create a new conda in which case we need to clean up
ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
ensure_conda_env_python310() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
# Use only global variables set by flag parser
local python_version="3.12"
# Check if conda command is available
if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
# Check if the environment exists
if conda env list | grep -q "^${env_name} "; then
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
# Check Python version in the environment
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
else
@ -73,37 +122,37 @@ ensure_conda_env_python310() {
else
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
conda create -n "${env_name}" python="${python_version}" -y
ENVNAME="${env_name}"
# setup_cleanup_handlers
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "${env_name}"
"$CONDA_PREFIX"/bin/pip install uv
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
uv pip install fastapi libcst
uv pip install --extra-index-url https://test.pypi.org/simple/ \
llama-stack=="$TEST_PYPI_VERSION" \
"$pip_dependencies"
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
"$normal_deps"
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install "$part"
done
fi
else
# Re-installing llama-stack in the new conda environment
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
@ -115,31 +164,44 @@ ensure_conda_env_python310() {
fi
uv pip install --no-cache-dir "$SPEC_VERSION"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi
# Install pip dependencies
printf "Installing pip dependencies\n"
uv pip install $pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
uv pip install $normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "Getting provider spec for module: $part and installing dependencies"
package_name=$(echo "$part" | sed 's/[<>=!].*//')
python3 -c "
import importlib
import sys
try:
module = importlib.import_module(f'$package_name.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
print('\\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
" | uv pip install -r -
done
fi
fi
mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml
echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml"
}
ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps"
ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps"

View file

@ -19,57 +19,111 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
# Mount command for cache container .cache, can be overridden by the user if needed
MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}
# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml
BUILD_CONTEXT_DIR=$(pwd)
if [ "$#" -lt 4 ]; then
# This only works for templates
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
exit 1
fi
set -euo pipefail
template_or_config="$1"
shift
image_name="$1"
shift
container_base="$1"
shift
pip_dependencies="$1"
shift
# Handle optional arguments
run_config=""
special_pip_deps=""
# Check if there are more arguments
# The logic is becoming cumbersome; we should refactor it if we can do better
if [ $# -gt 0 ]; then
# Check if the argument ends with .yaml
if [[ "$1" == *.yaml ]]; then
run_config="$1"
shift
# If there's another argument after .yaml, it must be special_pip_deps
if [ $# -gt 0 ]; then
special_pip_deps="$1"
fi
else
# If it's not .yaml, it must be special_pip_deps
special_pip_deps="$1"
fi
fi
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
# Usage function
usage() {
echo "Usage: $0 --image-name <image_name> --container-base <container_base> --normal-deps <pip_dependencies> [--run-config <run_config>] [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
image_name=""
container_base=""
normal_deps=""
external_provider_deps=""
optional_deps=""
run_config=""
template_or_config=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--image-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --image-name requires a string value" >&2
usage
fi
image_name="$2"
shift 2
;;
--container-base)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --container-base requires a string value" >&2
usage
fi
container_base="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
--run-config)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --run-config requires a string value" >&2
usage
fi
run_config="$2"
shift 2
;;
--template-or-config)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --template-or-config requires a string value" >&2
usage
fi
template_or_config="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then
echo "Error: --image-name, --container-base, and --normal-deps are required." >&2
usage
fi
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
@ -78,18 +132,15 @@ add_to_container() {
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
else
# If stdin is not a terminal, read from it (heredoc)
cat >>"$output_file"
fi
}
# Check if container command is available
if ! is_command_available "$CONTAINER_BINARY"; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
# Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF
FROM $container_base
@ -125,24 +176,59 @@ RUN pip install uv
EOF
fi
# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
add_to_container << EOF
ENV UV_LINK_MODE=copy
EOF
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.
if [ -n "$pip_dependencies" ]; then
if [ -n "$normal_deps" ]; then
read -ra pip_args <<< "$normal_deps"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container << EOF
RUN uv pip install --no-cache $pip_dependencies
RUN $MOUNT_CACHE uv pip install $quoted_deps
EOF
fi
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF
RUN uv pip install --no-cache $part
RUN $MOUNT_CACHE uv pip install $quoted_deps
EOF
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF
RUN $MOUNT_CACHE uv pip install $quoted_deps
EOF
add_to_container <<EOF
RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r -
import importlib
import sys
try:
package_name = '$part'.split('==')[0].split('>=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0]
module = importlib.import_module(f'{package_name}.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
if isinstance(spec.pip_packages, (list, tuple)):
print('\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr)
PYTHON
EOF
done
fi
# Function to get Python command
get_python_cmd() {
if is_command_available python; then
echo "python"
@ -207,7 +293,7 @@ COPY $dir $mount_point
EOF
fi
add_to_container << EOF
RUN uv pip install --no-cache -e $mount_point
RUN $MOUNT_CACHE uv pip install -e $mount_point
EOF
}
@ -222,10 +308,10 @@ else
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
add_to_container << EOF
RUN uv pip install fastapi libcst
RUN $MOUNT_CACHE uv pip install fastapi libcst
EOF
add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-stack==$TEST_PYPI_VERSION
@ -237,7 +323,7 @@ EOF
SPEC_VERSION="llama-stack"
fi
add_to_container << EOF
RUN uv pip install --no-cache $SPEC_VERSION
RUN $MOUNT_CACHE uv pip install $SPEC_VERSION
EOF
fi
fi
@ -328,7 +414,7 @@ $CONTAINER_BINARY build \
"$BUILD_CONTEXT_DIR"
# clean up tmp/configs
rm -f "$BUILD_CONTEXT_DIR/run.yaml"
rm -rf "$BUILD_CONTEXT_DIR/run.yaml" "$TEMP_DIR"
set +x
echo "Success!"

View file

@ -18,6 +18,76 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Usage function
usage() {
echo "Usage: $0 --env-name <env_name> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
env_name=""
normal_deps=""
external_provider_deps=""
optional_deps=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--env-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --env-name requires a string value" >&2
usage
fi
env_name="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$env_name" || -z "$normal_deps" ]]; then
echo "Error: --env-name and --normal-deps are required." >&2
usage
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
@ -25,29 +95,6 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
special_pip_deps="$3"
set -euo pipefail
env_name="$1"
pip_dependencies="$2"
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
# this is set if we actually create a new conda in which case we need to clean up
ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# pre-run checks to make sure we can proceed with the installation
pre_run_checks() {
local env_name="$1"
@ -71,49 +118,44 @@ pre_run_checks() {
}
run() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
# Use only global variables set by flag parser
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
echo "Installing dependencies in system Python environment"
# if env == __system__, ensure we set UV_SYSTEM_PYTHON
export UV_SYSTEM_PYTHON=1
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
echo "Virtual environment $env_name is already active"
else
echo "Using virtual environment $env_name"
uv venv "$env_name"
# shellcheck source=/dev/null
source "$env_name/bin/activate"
fi
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
uv pip install fastapi libcst
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
llama-stack=="$TEST_PYPI_VERSION" \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
$normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install "$part"
done
fi
else
# Re-installing llama-stack in the new virtual environment
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
@ -125,27 +167,41 @@ run() {
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi
# Install pip dependencies
printf "Installing pip dependencies\n"
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install $pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
uv pip install $normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
echo "Installing special provider module: $part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "Installing external provider module: $part"
uv pip install "$part"
echo "Getting provider spec for module: $part and installing dependencies"
package_name=$(echo "$part" | sed 's/[<>=!].*//')
python3 -c "
import importlib
import sys
try:
module = importlib.import_module(f'$package_name.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
print('\\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
" | uv pip install -r -
done
fi
fi
}
pre_run_checks "$env_name"
run "$env_name" "$pip_dependencies" "$special_pip_deps"
run

View file

@ -91,21 +91,22 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
logger.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider_type in enumerate(plist):
for i, provider in enumerate(plist):
if i >= 1:
others = ", ".join(plist[i:])
others = ", ".join(p.provider_type for p in plist[i:])
logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
logger.info(f"> Configuring provider `({provider_type})`")
logger.info(f"> Configuring provider `({provider.provider_type})`")
pid = provider.provider_type.split("::")[-1]
updated_providers.append(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
provider_type=provider_type,
provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
provider_type=provider.provider_type,
config={},
),
)

View file

@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str]
class RegistryEntrySource(StrEnum):
via_register_api = "via_register_api"
listed_from_provider = "listed_from_provider"
class User(BaseModel):
principal: str
# further attributes that may be used for access control decisions
@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
resource. This can be used to constrain access to the resource."""
owner: User | None = None
source: RegistryEntrySource = RegistryEntrySource.via_register_api
# Use the extended Resource for all routable objects
@ -130,29 +136,54 @@ class RoutingTableProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(default_factory=list)
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any] = {}
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class BuildProvider(BaseModel):
provider_type: str
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class DistributionSpec(BaseModel):
description: str | None = Field(
default="",
description="Description of the distribution",
)
container_image: str | None = None
providers: dict[str, str | list[str]] = Field(
providers: dict[str, list[BuildProvider]] = Field(
default_factory=dict,
description="""
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.""",
Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.
""",
)
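
An illustrative build-time snippet using the models above; the provider type and module name follow the ramalama example in the field docs and are not prescriptive:

from llama_stack.distribution.datatypes import BuildProvider, DistributionSpec

spec = DistributionSpec(
    description="distro with one external inference provider",
    providers={
        "inference": [
            BuildProvider(provider_type="remote::ramalama", module="ramalama_stack"),
        ],
    },
)
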
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any]
class LoggingConfig(BaseModel):
category_levels: dict[str, str] = Field(
default_factory=dict,
@ -381,6 +412,11 @@ a default SQLite store will be used.""",
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
@ -412,6 +448,10 @@ class BuildConfig(BaseModel):
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod

View file

@ -12,6 +12,8 @@ from typing import Any
import yaml
from pydantic import BaseModel
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
from llama_stack.distribution.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
@ -96,12 +98,10 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
return spec
def get_provider_registry(
config=None,
) -> dict[Api, dict[str, ProviderSpec]]:
def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files.
This function loads both built-in providers and external providers from YAML files or from their provided modules.
External providers are loaded from a directory structure like:
providers.d/
@ -122,8 +122,13 @@ def get_provider_registry(
safety/
llama-guard.yaml
This function is called from a variety of places: during build, during run, and during stack construction.
When building external providers from a module, the pip package required to import the module may not be installed yet, so each of these call sites gets special handling.
Args:
config: Optional object containing the external providers directory path
Returns:
A dictionary mapping APIs to their available providers
@ -133,58 +138,140 @@ def get_provider_registry(
ValueError: If any provider spec is invalid
"""
ret: dict[Api, dict[str, ProviderSpec]] = {}
registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {a.provider_type: a for a in module.available_providers()}
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
# Refresh providable APIs with external APIs if any
external_apis = load_external_apis(config)
for api, api_spec in external_apis.items():
name = api_spec.name.lower()
logger.info(f"Importing external API {name} module {api_spec.module}")
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
except (ImportError, AttributeError) as e:
# Populate the registry with an empty dict to avoid breaking the provider registry
# This assumes that the in-tree provider(s) are not available for this API, which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
for api in providable_apis():
api_name = api.name.lower()
# Check if config has external providers
if config:
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
registry = get_external_providers_from_dir(registry, config)
# else lets check for modules in each provider
registry = get_external_providers_from_module(
registry=registry,
config=config,
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
)
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
return registry
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
def get_external_providers_from_dir(
registry: dict[Api, dict[str, ProviderSpec]], config
) -> dict[Api, dict[str, ProviderSpec]]:
logger.warning(
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
for api in providable_apis():
api_name = api.name.lower()
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in ret[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
ret[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return ret
# Process both remote and inline providers
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
def get_external_providers_from_module(
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
) -> dict[Api, dict[str, ProviderSpec]]:
provider_list = None
if isinstance(config, BuildConfig):
provider_list = config.distribution_spec.providers.items()
else:
provider_list = config.providers.items()
if provider_list is None:
logger.warning("Could not get list of providers from config")
return registry
for provider_api, providers in provider_list:
for provider in providers:
if not hasattr(provider, "module") or provider.module is None:
continue
# get provider using module
try:
if not building:
package_name = provider.module.split("==")[0]
module = importlib.import_module(f"{package_name}.provider")
# if config class is wrong you will get an error saying module could not be imported
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
is_external=True,
module=provider.module,
config_class="",
)
provider_type = provider.provider_type
# in the case we are building we CANNOT import this module of course because it has not been installed.
# return a partially filled out spec that the build script will populate.
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
) from exc
except Exception as e:
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
raise e
return registry
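For reference, a minimal sketch of the module layout this loader expects to import at run time; the package name `my_custom_inference`, its provider type, and its config class path are illustrative assumptions, not a real provider.

# my_custom_inference/provider.py (hypothetical external provider package)
from llama_stack.providers.datatypes import Api, ProviderSpec


def get_provider_spec() -> ProviderSpec:
    # returns the full spec used at run time; at build time the loader above
    # only constructs a partial placeholder spec
    return ProviderSpec(
        api=Api.inference,
        provider_type="remote::my-custom-inference",
        is_external=True,
        module="my_custom_inference",
        config_class="my_custom_inference.config.MyCustomInferenceConfig",
    )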

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
"""Load external API specifications from the configured directory.
Args:
config: StackRunConfig or BuildConfig containing the external APIs directory path
Returns:
A dictionary mapping API names to their specifications
"""
if not config or not config.external_apis_dir:
return {}
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
for yaml_path in external_apis_dir.glob("*.yaml"):
try:
with open(yaml_path) as f:
spec_data = yaml.safe_load(f)
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis
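As a sketch, a spec file dropped into external_apis_dir could look like the YAML below; the `weather` API name, module, and protocol class are illustrative stand-ins, not real packages.

import yaml

from llama_stack.apis.datatypes import Api, ExternalApiSpec

example_spec = yaml.safe_load("""
name: weather
module: my_weather_api
protocol: Weather
""")
spec = ExternalApiSpec(**example_spec)
api = Api.add(spec.name)  # registers a dynamic Api member, as the loop above does
assert api.value == "weather"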

View file

@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
VersionInfo,
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config
ret = []
all_endpoints = get_all_api_routes()
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
@ -53,7 +55,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e in endpoints
for e, _ in endpoints
if e.methods is not None
]
)
else:
@ -66,7 +69,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
for e, _ in endpoints
if e.methods is not None
]
)

View file

@ -33,7 +33,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
from llama_stack.distribution.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
return self.loop.run_until_complete(self.async_client.initialize())
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self):
"""
@ -243,15 +249,16 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
file=sys.stderr,
)
if self.config_path_or_template_name.endswith(".yaml"):
# Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
providers: dict[str, list[BuildProvider]] = {}
for api, run_providers in self.config.providers.items():
for provider in run_providers:
providers.setdefault(api, []).append(
BuildProvider(provider_type=provider.provider_type, module=provider.module)
)
providers = dict(providers)
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=provider_types,
providers=providers,
),
external_providers_dir=self.config.external_providers_dir,
)
@ -353,13 +360,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {}
body |= options.json_data or {}
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
await start_trace(route, {"__location__": "library_client"})
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
@ -409,12 +418,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"})
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
async def gen():
try:
@ -445,8 +455,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]]
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
@ -468,7 +479,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
exclude_params = exclude_params or set()
func, _, _ = find_matching_route(method, path, self.route_impls)
func, _, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature

View file

@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None:
if not provider_data:
return None
return provider_data.get("__authenticated_user")
def user_from_scope(scope: dict) -> User | None:
"""Create a User object from ASGI scope data (set by authentication middleware)"""
user_attributes = scope.get("user_attributes", {})
principal = scope.get("principal", "")
# auth not enabled
if not principal and not user_attributes:
return None
return User(principal=principal, attributes=user_attributes)
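A small usage sketch, assuming an ASGI scope already populated by the authentication middleware; the principal and attribute values are illustrative.

from llama_stack.distribution.request_headers import user_from_scope

authenticated_scope = {
    "type": "http",
    "principal": "alice",
    "user_attributes": {"scopes": ["read:models"]},
}
user = user_from_scope(authenticated_scope)
assert user is not None and user.principal == "alice"

# with auth disabled neither key is set, so no User is constructed
assert user_from_scope({"type": "http"}) is None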

View file

@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
pass
def api_protocol_map() -> dict[Api, Any]:
return {
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
"""Get a mapping of API types to their protocol classes.
Args:
external_apis: Optional dictionary of external API specifications
Returns:
Dictionary mapping API types to their protocol classes
"""
protocols = {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
Api.files: Files,
}
if external_apis:
for api, api_spec in external_apis.items():
try:
module = importlib.import_module(api_spec.module)
api_class = getattr(module, api_spec.protocol)
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
protocols[api] = api_class
except (ImportError, AttributeError):
logger.exception(f"Failed to load external API {api_spec.name}")
return protocols
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
external_apis = load_external_apis(config)
return {
**api_protocol_map(),
**api_protocol_map(external_apis),
Api.inference: InferenceProvider,
}
@ -250,7 +273,7 @@ async def instantiate_providers(
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
) -> dict:
) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
@ -322,7 +345,7 @@ async def instantiate_provider(
policy: list[AccessRule],
):
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
if not hasattr(provider_spec, "module") or provider_spec.module is None:
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
@ -360,7 +383,7 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check()
protocols = api_protocol_map_for_compliance_check(run_config)
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol

View file

@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.append_rows(
dataset_id=dataset_id,
rows=rows,
)

View file

@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -97,7 +99,8 @@ class EvalRouter(Eval):
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
@ -110,7 +113,8 @@ class EvalRouter(Eval):
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
@ -123,7 +127,8 @@ class EvalRouter(Eval):
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_status(benchmark_id, job_id)
async def job_cancel(
self,
@ -131,7 +136,8 @@ class EvalRouter(Eval):
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
provider = await self.routing_table.get_provider_impl(benchmark_id)
await provider.job_cancel(
benchmark_id,
job_id,
)
@ -142,7 +148,8 @@ class EvalRouter(Eval):
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_result(
benchmark_id,
job_id,
)

View file

@ -231,7 +231,7 @@ class InferenceRouter(Inference):
logprobs=logprobs,
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
@ -292,7 +292,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
@ -322,7 +322,7 @@ class InferenceRouter(Inference):
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
content=content,
@ -378,7 +378,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
@ -395,7 +395,8 @@ class InferenceRouter(Inference):
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
return await self.routing_table.get_provider_impl(model_id).embeddings(
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.embeddings(
model_id=model_id,
contents=contents,
text_truncation=text_truncation,
@ -458,7 +459,7 @@ class InferenceRouter(Inference):
suffix=suffix,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_completion(**params)
async def openai_chat_completion(
@ -538,7 +539,7 @@ class InferenceRouter(Inference):
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if self.store:
@ -575,7 +576,7 @@ class InferenceRouter(Inference):
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(**params)
async def list_chat_completions(

View file

@ -50,7 +50,8 @@ class SafetyRouter(Safety):
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield(
shield_id=shield_id,
messages=messages,
params=params,

View file

@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_db_ids, query_config)
async def insert(
self,
@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)

View file

@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
provider_vector_db_id=vector_db_id,
vector_db_name=name,
)
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
return await provider.openai_create_vector_store(
name=name,
file_ids=file_ids,
expires_after=expires_after,
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
all_stores = []
for vector_db in vector_dbs:
try:
vector_store = await self.routing_table.get_provider_impl(
vector_db.identifier
).openai_retrieve_vector_store(vector_db.identifier)
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
@ -214,9 +216,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
@ -226,9 +226,7 @@ class VectorIORouter(VectorIO):
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
return await self.routing_table.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
@ -240,12 +238,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
# drop from registry
await self.routing_table.unregister_vector_db(vector_store_id)
return result
return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
self,
@ -258,9 +251,7 @@ class VectorIORouter(VectorIO):
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
return await self.routing_table.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
@ -278,9 +269,7 @@ class VectorIORouter(VectorIO):
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
return await self.routing_table.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -297,9 +286,7 @@ class VectorIORouter(VectorIO):
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
return await self.routing_table.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
@ -314,9 +301,7 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
return await self.routing_table.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -327,9 +312,7 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
return await self.routing_table.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
@ -341,9 +324,7 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
return await self.routing_table.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
@ -355,9 +336,7 @@ class VectorIORouter(VectorIO):
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
return await self.routing_table.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)

View file

@ -6,9 +6,11 @@
from typing import Any
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import Action
from llama_stack.distribution.datatypes import (
AccessRule,
RoutableObject,
@ -115,7 +117,10 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
async def refresh(self) -> None:
pass
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
from .models import ModelsRoutingTable
@ -204,11 +209,24 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def assert_action_allowed(
self,
action: Action,
type: str,
identifier: str,
) -> None:
"""Fetch a registered object by type/identifier and enforce the given action permission."""
obj = await self.get_object_by_identifier(type, identifier)
if obj is None:
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
user = get_authenticated_user()
if not is_action_allowed(self.policy, action, obj, user):
raise AccessDeniedError(action, obj, user)
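A brief sketch of how a caller uses this guard, with a hypothetical identifier; on the happy path it returns None and the caller proceeds to dispatch, which is the pattern the vector-store methods further below follow.

async def guarded_delete(table: CommonRoutingTableImpl, vector_store_id: str) -> None:
    # raises ValueError if no vector_db is registered under this identifier,
    # and AccessDeniedError if the authenticated user lacks "delete" permission
    await table.assert_action_allowed("delete", "vector_db", vector_store_id)
    provider = await table.get_provider_impl(vector_store_id)
    await provider.openai_delete_vector_store(vector_store_id)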
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
@ -220,3 +238,28 @@ class CommonRoutingTableImpl(RoutingTable):
]
return filtered_objs
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
# first try to get the model by identifier
# this works if model_id is an alias or is of the form provider_id/provider_model_id
model = await routing_table.get_object_by_identifier("model", model_id)
if model is not None:
return model
logger.warning(
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
"searching in all providers. This is only for backwards compatibility and will stop working "
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
)
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await routing_table.get_all_with_type("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ValueError(f"Model '{model_id}' not found")
if len(matching_models) > 1:
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
return matching_models[0]
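A sketch of the two resolution paths, wrapped in a helper; the model identifiers and provider are illustrative assumptions.

async def resolve_examples(table: CommonRoutingTableImpl) -> None:
    # a scoped identifier (or alias) resolves directly from the registry
    scoped = await lookup_model(table, "ollama/llama3.2:3b")
    # a bare provider_model_id triggers the deprecated fallback scan and a warning
    unscoped = await lookup_model(table, "llama3.2:3b")
    assert scoped.identifier == unscoped.identifier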

View file

@ -10,15 +10,37 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
listed_providers: set[str] = set()
async def refresh(self) -> None:
for provider_id, provider in self.impls_by_provider_id.items():
refresh = await provider.should_refresh_models()
refresh = refresh or provider_id not in self.listed_providers
if not refresh:
continue
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
if models is None:
continue
await self.update_registered_models(provider_id, models)
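For reference, a sketch of the provider-side contract that refresh() relies on; the class name and the Ollama-style model id are illustrative assumptions.

from llama_stack.apis.models import Model, ModelType


class ExampleInferenceProvider:
    async def should_refresh_models(self) -> bool:
        return True  # ask the routing table to poll this provider every cycle

    async def list_models(self) -> list[Model] | None:
        return [
            Model(
                identifier="llama3.2:3b",
                provider_resource_id="llama3.2:3b",
                provider_id="ollama",
                model_type=ModelType.llm,
            )
        ]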
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
@ -36,10 +58,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
return await lookup_model(self, model_id)
async def get_provider_impl(self, model_id: str) -> Any:
model = await lookup_model(self, model_id)
return self.impls_by_provider_id[model.provider_id]
async def register_model(
self,
@ -49,28 +72,38 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
provider_model_id = provider_model_id or model_id
metadata = metadata or {}
model_type = model_type or ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
# an identifier different from provider_model_id implies it is an alias, so that
# becomes the globally unique identifier. Otherwise provider_model_ids can conflict,
# so as a general rule we must use the provider_id to disambiguate.
if model_id != provider_model_id:
identifier = model_id
else:
identifier = f"{provider_id}/{provider_model_id}"
model = ModelWithOwner(
identifier=model_id,
identifier=identifier,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
source=RegistryEntrySource.via_register_api,
)
registered_model = await self.register_object(model)
return registered_model
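To make the identifier rule concrete, a sketch with assumed names: an alias keeps the caller-supplied id, while a non-aliased registration gets scoped by provider.

async def register_examples(table: ModelsRoutingTable) -> None:
    aliased = await table.register_model(
        model_id="my-alias", provider_model_id="llama3.2:3b", provider_id="ollama"
    )
    assert aliased.identifier == "my-alias"

    scoped = await table.register_model(model_id="llama3.2:3b", provider_id="ollama")
    assert scoped.identifier == "ollama/llama3.2:3b"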
@ -81,7 +114,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
async def update_registered_llm_models(
async def update_registered_models(
self,
provider_id: str,
models: list[Model],
@ -92,18 +125,22 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
# we leave embeddings models alone because often we don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
if model.provider_id != provider_id:
continue
if model.source == RegistryEntrySource.via_register_api:
model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
continue
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id]
# avoid overwriting a non-provider-registered model entry
continue
if model.identifier == model.provider_resource_id:
model.identifier = f"{provider_id}/{model.provider_resource_id}"
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
@ -113,5 +150,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
source=RegistryEntrySource.listed_from_provider,
)
)

View file

@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_to_toolgroup: dict[str, str] = {}
# overridden
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return super().get_provider_impl(routing_key, provider_id)
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
if toolgroup_id:
@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction

View file

@ -4,17 +4,30 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.apis.vector_io.vector_io import (
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.distribution.datatypes import (
VectorDBWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
@ -38,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
@ -49,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
model = await lookup_model(self, embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
@ -74,3 +86,145 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if existing_vector_db is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
await self.unregister_object(existing_vector_db)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
await self.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
order=order,
after=after,
before=before,
filter=filter,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
)

View file

@ -7,9 +7,12 @@
import json
import httpx
from aiohttp import hdrs
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.datatypes import AuthenticationConfig, User
from llama_stack.distribution.request_headers import user_from_scope
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -78,12 +81,14 @@ class AuthenticationMiddleware:
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthenticationConfig):
def __init__(self, app, auth_config: AuthenticationConfig, impls):
self.app = app
self.impls = impls
self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# First, handle authentication
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
@ -121,15 +126,50 @@ class AuthenticationMiddleware:
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)
# Scope-based API access control
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
return await self.app(scope, receive, send)
if webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(
send,
f"Access denied: user does not have required scope: {webmethod.required_scope}",
status=403,
)
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message):
async def _send_auth_error(self, send, message, status=401):
await send(
{
"type": "http.response.start",
"status": 401,
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps({"error": {"message": message}}).encode()
error_key = "message" if status == 401 else "detail"
error_msg = json.dumps({"error": {error_key: message}}).encode()
await send({"type": "http.response.body", "body": error_msg})
def _has_required_scope(required_scope: str, user: User | None) -> bool:
# if no user, assume auth is not enabled
if not user:
return True
if not user.attributes:
return False
user_scopes = user.attributes.get("scopes", [])
return required_scope in user_scopes
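A small sketch of the scope check; the principal and scope strings are illustrative assumptions.

from llama_stack.distribution.datatypes import User

reader = User(principal="alice", attributes={"scopes": ["telemetry:read"]})
assert _has_required_scope("telemetry:read", reader)
assert not _has_required_scope("models:write", reader)

# when auth is disabled there is no user, so every scope requirement passes
assert _has_required_scope("models:write", None)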

View file

@ -12,17 +12,18 @@ from typing import Any
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
from llama_stack.schema_utils import WebMethod
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str]
RouteInfo = tuple[EndpointFunc, str, WebMethod]
PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
@ -31,10 +32,12 @@ def toolgroup_protocol_map():
}
def get_all_api_routes() -> dict[Api, list[Route]]:
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map()
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
@ -65,7 +68,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
else:
http_method = hdrs.METH_POST
routes.append(
Route(path=path, methods=[http_method], name=name, endpoint=None)
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
) # setting endpoint to None since we don't use a Router object
apis[api] = routes
@ -73,8 +76,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
return apis
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
routes = get_all_api_routes()
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
api_to_routes = get_all_api_routes(external_apis)
route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str:
@ -88,10 +91,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
return f"^{pattern}$"
for api, api_routes in routes.items():
for api, api_routes in api_to_routes.items():
if api not in impls:
continue
for route in api_routes:
for route, webmethod in api_routes:
impl = impls[api]
func = getattr(impl, route.name)
# Get the first (and typically only) method from the set, filtering out HEAD
@ -104,6 +107,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
route_impls[method][_convert_path_to_regex(route.path)] = (
func,
route.path,
webmethod,
)
return route_impls
@ -118,7 +122,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
route_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
Raises:
ValueError: If no matching endpoint is found
@ -127,11 +131,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items():
for regex, (func, route_path, webmethod) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name
return func, path_params, route_path, webmethod
raise ValueError(f"No endpoint found for {path}")

View file

@ -32,6 +32,7 @@ from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_template_args, get_config_from_args
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import (
AuthenticationRequiredError,
@ -39,7 +40,12 @@ from llama_stack.distribution.datatypes import (
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
user_from_scope,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import (
find_matching_route,
@ -50,9 +56,11 @@ from llama_stack.distribution.stack import (
cast_image_name_to_string,
construct_stack,
replace_env_vars,
shutdown_stack,
validate_env_pair,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
@ -144,18 +152,7 @@ async def shutdown(app):
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
"""
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
await shutdown_stack(app.__llama_stack_impls__)
@asynccontextmanager
@ -220,9 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal=principal, attributes=user_attributes)
user = user_from_scope(request.scope)
await log_request_pre_validation(request)
@ -280,9 +275,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
class TracingMiddleware:
def __init__(self, app, impls):
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
@ -299,10 +295,12 @@ class TracingMiddleware:
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
_, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
@ -319,6 +317,7 @@ class TracingMiddleware:
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
@ -377,20 +376,8 @@ class ClientVersionMiddleware:
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
"--yaml-config",
dest="config",
help="(Deprecated) Path to YAML configuration file - use --config instead",
)
parser.add_argument(
"--config",
dest="config",
help="Path to YAML configuration file",
)
parser.add_argument(
"--template",
help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)",
)
add_config_template_args(parser)
parser.add_argument(
"--port",
type=int,
@ -409,20 +396,8 @@ def main(args: argparse.Namespace | None = None):
if args is None:
args = parser.parse_args()
log_line = ""
if hasattr(args, "config") and args.config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
log_line = f"Using config file: {config_file}"
elif hasattr(args, "template") and args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
log_line = f"Using template {args.template} config file: {config_file}"
else:
raise ValueError("Either --config or --template must be provided")
config_or_template = get_config_from_args(args)
config_file = resolve_config_or_template(config_or_template, Mode.RUN)
logger_config = None
with open(config_file) as fp:
@ -442,9 +417,6 @@ def main(args: argparse.Namespace | None = None):
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
# now that the logger is initialized, print the line about which type of config we are using.
logger.info(log_line)
_log_run_config(run_config=config)
app = FastAPI(
@ -457,10 +429,21 @@ def main(args: argparse.Namespace | None = None):
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else:
if config.server.quota:
quota = config.server.quota
@ -491,24 +474,14 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_routes = get_all_api_routes()
# Load external APIs if configured
external_apis = load_external_apis(config)
all_routes = get_all_api_routes(external_apis)
if config.apis:
apis_to_serve = set(config.apis)
@ -527,9 +500,12 @@ def main(args: argparse.Namespace | None = None):
api = Api(api_str)
routes = all_routes[api]
impl = impls[api]
try:
impl = impls[api]
except KeyError as e:
raise ValueError(f"Could not find provider implementation for {api} API") from e
for route in routes:
for route, _ in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!")
@ -558,7 +534,7 @@ def main(args: argparse.Namespace | None = None):
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
import uvicorn
@ -592,12 +568,29 @@ def main(args: argparse.Namespace | None = None):
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
if ssl_config:
uvicorn_config.update(ssl_config)
# Run uvicorn in the existing event loop to preserve background tasks
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling, but this is quite intrusive and not worth the effort.
try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
loop.close()
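As an aside on the interrupt handling above, here is a minimal standalone sketch (editor's illustration, not part of this diff) of how a re-raised SIGINT surfaces as KeyboardInterrupt in Python, which is what the try/except around uvicorn's serve() relies on:

import signal

try:
    # Python's default SIGINT handler converts the signal into KeyboardInterrupt.
    signal.raise_signal(signal.SIGINT)
except KeyboardInterrupt:
    print("SIGINT surfaced as KeyboardInterrupt")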
def _log_run_config(run_config: StackRunConfig):
@ -618,11 +611,8 @@ def extract_path_params(route: str) -> list[str]:
def remove_disabled_providers(obj):
if isinstance(obj, dict):
if (
obj.get("provider_id") == "__disabled__"
or obj.get("shield_id") == "__disabled__"
or obj.get("provider_model_id") == "__disabled__"
):
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
return None
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
elif isinstance(obj, list):

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib.resources
import os
import re
@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -90,6 +92,10 @@ RESOURCES = [
]
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
@ -99,23 +105,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue
# In complex templates, like our starter template, we may have dynamic model ids
# given by environment variables. This allows those environment variables to have
# a default value of __disabled__ to skip registration of the model if not set.
if (
hasattr(obj, "provider_model_id")
and obj.provider_model_id is not None
and "__disabled__" in obj.provider_model_id
):
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
continue
if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue
# we want to maintain the type information in arguments to method.
@ -324,9 +317,61 @@ async def construct_stack(
add_internal_implementations(impls, run_config)
await register_resources(run_config, impls)
await refresh_registry_once(impls)
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry_once(impls: dict[Api, Any]):
logger.debug("refreshing registry")
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
for routing_table in routing_tables:
await routing_table.refresh()
async def refresh_registry_task(impls: dict[Api, Any]):
logger.info("starting registry refresh task")
while True:
await refresh_registry_once(impls)
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
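For reference, a self-contained sketch of the done-callback pattern used for REGISTRY_REFRESH_TASK above (editor's illustration; the task body and names are hypothetical), showing how failures in a background task get surfaced instead of being silently swallowed by the event loop:

import asyncio

async def _refresh_forever() -> None:
    while True:
        await asyncio.sleep(300)

def _on_done(task: asyncio.Task) -> None:
    # Log cancellation or failure of the background task.
    if task.cancelled():
        print("refresh task cancelled")
    elif task.exception():
        print(f"refresh task failed: {task.exception()}")

async def main() -> None:
    task = asyncio.create_task(_refresh_forever())
    task.add_done_callback(_on_done)
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)

asyncio.run(main())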
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"

View file

@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="--config $yaml_config"
yaml_config_arg="$yaml_config"
else
yaml_config_arg=""
fi

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import StrEnum
from pathlib import Path
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates"
class Mode(StrEnum):
RUN = "run"
BUILD = "build"
def resolve_config_or_template(
config_or_template: str,
mode: Mode = Mode.RUN,
) -> Path:
"""
Resolve a config/template argument to a concrete config file path.
Args:
config_or_template: User input (file path, template name, or built distribution)
mode: Mode to resolve the config for ("run" or "build")
Returns:
Path to the resolved config file
Raises:
ValueError: If resolution fails
"""
# Strategy 1: Try as file path first
config_path = Path(config_or_template)
if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as template name (if no .yaml extension)
if not config_or_template.endswith(".yaml"):
template_config = _get_template_config_path(config_or_template, mode)
if template_config.exists():
logger.info(f"Using template: {template_config}")
return template_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error
raise ValueError(_format_resolution_error(config_or_template, mode))
def _get_template_config_path(template_name: str, mode: Mode) -> Path:
"""Get the config file path for a template."""
return TEMPLATE_DIR / template_name / f"{mode}.yaml"
def _format_resolution_error(config_or_template: str, mode: Mode) -> str:
"""Format a helpful error message for resolution failures."""
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
template_path = _get_template_config_path(config_or_template, mode)
distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml"
distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml"
available_templates = _get_available_templates()
templates_str = ", ".join(available_templates) if available_templates else "none found"
return f"""Could not resolve config or template '{config_or_template}'.
Tried the following locations:
1. As file path: {Path(config_or_template).resolve()}
2. As template: {template_path}
3. As built distribution: ({distrib_path}, {distrib_path2})
Available templates: {templates_str}
Did you mean one of these templates?
{_format_template_suggestions(available_templates, config_or_template)}
"""
def _get_available_templates() -> list[str]:
"""Get list of available template names."""
if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists():
return []
# Guard each directory individually so a missing one does not raise FileNotFoundError
template_names = [d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] if TEMPLATE_DIR.exists() else []
distrib_names = [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] if DISTRIBS_BASE_DIR.exists() else []
return list(set(template_names + distrib_names))
def _format_template_suggestions(templates: list[str], user_input: str) -> str:
"""Format template suggestions for error messages, showing closest matches first."""
if not templates:
return " (no templates found)"
import difflib
# Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive)
close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3)
display_templates = close_matches if close_matches else templates[:3]
suggestions = [f" - {t}" for t in display_templates]
return "\n".join(suggestions)

View file

@ -21,7 +21,7 @@ from pathlib import Path
from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list:
def formulate_run_args(image_type: str, image_name: str) -> list[str]:
env_name = ""
if image_type == LlamaStackImageType.CONDA.value:

View file

@ -6,6 +6,7 @@
import logging
import os
import re
import sys
from logging.config import dictConfig
@ -30,6 +31,7 @@ CATEGORIES = [
"eval",
"tools",
"client",
"telemetry",
]
# Initialize category levels with default level
@ -113,6 +115,11 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
return category_levels
def strip_rich_markup(text):
"""Remove Rich markup tags like [dim], [bold magenta], etc."""
return re.sub(r"\[/?[a-zA-Z0-9 _#=,]+\]", "", text)
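A quick illustration of the helper above (editor's sketch): Rich-style tags are stripped before a record is written to the plain-text log file.

print(strip_rich_markup("[dim]key[/dim]: value"))  # -> "key: value"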
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=150)
@ -131,6 +138,19 @@ class CustomRichHandler(RichHandler):
self.markup = original_markup
class CustomFileHandler(logging.FileHandler):
def __init__(self, filename, mode="a", encoding=None, delay=False):
super().__init__(filename, mode, encoding, delay)
# Default formatter to match console output
self.default_formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(category)s: %(message)s")
self.setFormatter(self.default_formatter)
def emit(self, record):
if hasattr(record, "msg"):
record.msg = strip_rich_markup(str(record.msg))
super().emit(record)
def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
"""
Configure logging based on the provided category log levels and an optional log file.
@ -167,8 +187,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None
# Add a file handler if log_file is set
if log_file:
handlers["file"] = {
"class": "logging.FileHandler",
"formatter": "rich",
"()": CustomFileHandler,
"filename": log_file,
"mode": "a",
"encoding": "utf-8",

View file

@ -43,10 +43,24 @@ class ModelsProtocolPrivate(Protocol):
-> Provider uses provider-model-id for inference
"""
# this should be called `on_model_register` or something like that.
# the provider should _not_ be able to change the object in this
# callback
async def register_model(self, model: Model) -> Model: ...
async def unregister_model(self, model_id: str) -> None: ...
# the Stack router will query each provider for their list of models
# if a `refresh_interval_seconds` is provided, this method will be called
# periodically to refresh the list of models
#
# NOTE: each model returned will be registered with the model registry. this means
# a callback to the `register_model()` method will be made. this is duplicative and
# may be removed in the future.
async def list_models(self) -> list[Model] | None: ...
async def should_refresh_models(self) -> bool: ...
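A minimal sketch of a provider satisfying this contract (editor's illustration; the provider id and model name are hypothetical, and the imports mirror those used elsewhere in this diff):

from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate

class StaticModelProvider(ModelsProtocolPrivate):
    __provider_id__ = "static-example"  # hypothetical

    async def register_model(self, model: Model) -> Model:
        return model

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def list_models(self) -> list[Model] | None:
        return [
            Model(
                identifier="example-model",            # hypothetical
                provider_resource_id="example-model",  # hypothetical
                provider_id=self.__provider_id__,
                metadata={},
                model_type=ModelType.llm,
            )
        ]

    async def should_refresh_models(self) -> bool:
        return False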
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...
@ -104,6 +118,19 @@ class ProviderSpec(BaseModel):
description="If this provider is deprecated and does NOT work, specify the error message here",
)
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now
deps__: list[str] = Field(default_factory=list)
@ -113,7 +140,7 @@ class ProviderSpec(BaseModel):
class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...
async def get_provider_impl(self, routing_key: str) -> Any: ...
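Since get_provider_impl is now a coroutine, callers have to await it; e.g. (editor's note, hypothetical routing key):

provider = await routing_table.get_provider_impl("my-model-id")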
# TODO: this can now be inlined into RemoteProviderSpec
@ -124,7 +151,7 @@ class AdapterSpec(BaseModel):
description="Unique identifier for this adapter",
)
module: str = Field(
...,
default_factory=str,
description="""
Fully-qualified name of the module to import. The module is expected to have:
@ -162,14 +189,7 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
# module field is inherited from ProviderSpec
provider_data_validator: str | None = Field(
default=None,
)
@ -212,9 +232,7 @@ API responses, specify the adapter here.
def container_image(self) -> str | None:
return None
@property
def module(self) -> str:
return self.adapter.module
# module field is inherited from ProviderSpec
@property
def pip_packages(self) -> list[str]:
@ -226,14 +244,19 @@ API responses, specify the adapter here.
def remote_provider_spec(
api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
module=adapter.module,
adapter=adapter,
api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
)
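A hedged usage sketch of the extended signature (editor's illustration; the adapter fields are hypothetical and only mirror the shape used elsewhere in this diff):

spec = remote_provider_spec(
    api=Api.vector_io,
    adapter=AdapterSpec(
        adapter_type="example-vector-db",                       # hypothetical
        pip_packages=["example-vector-db-client"],              # hypothetical
        module="example_pkg.providers.vector_db",               # hypothetical
        config_class="example_pkg.providers.vector_db.Config",  # hypothetical
        description="Example remote vector store adapter.",
    ),
    api_dependencies=[Api.inference],
    optional_api_dependencies=[Api.files],
)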

View file

@ -10,6 +10,7 @@ import re
import secrets
import string
import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str:
async def get_raw_document_text(document: Document) -> str:
if not document.mime_type.startswith("text/"):
# Handle deprecated text/yaml mime type with warning
if document.mime_type == "text/yaml":
warnings.warn(
"The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.",
DeprecationWarning,
stacklevel=2,
)
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):

View file

@ -128,6 +128,11 @@ class AgentPersistence:
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
# The kvstore does not guarantee order, so we sort by started_at
# to ensure consistent ordering of turns.
turns.sort(key=lambda t: t.started_at)
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:

View file

@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator.stop()
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return None
async def unregister_model(self, model_id: str) -> None:
pass

View file

@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
InferenceProvider,
ModelsProtocolPrivate,
):
__provider_id__: str
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config
@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
async def shutdown(self) -> None:
pass
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return [
Model(
identifier="all-MiniLM-L6-v2",
provider_resource_id="all-MiniLM-L6-v2",
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 384,
},
model_type=ModelType.embedding,
),
]
async def register_model(self, model: Model) -> Model:
return model

View file

@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
# Allow any model to be registered as a shield
# The model will be validated during runtime when making inference calls
pass
model_id = shield.provider_resource_id
if not model_id:
raise ValueError("Llama Guard shield must have a model id")
async def run_shield(
self,

View file

@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
from opentelemetry.trace.status import StatusCode
# Colors for console output
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
from llama_stack.log import get_logger
logger = get_logger(name="console_span_processor", category="telemetry")
class ConsoleSpanProcessor(SpanProcessor):
@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor):
return
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[START]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[END]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
if span.status.status_code == StatusCode.ERROR:
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
span_context += " [bold red][ERROR][/bold red]"
elif span.status.status_code != StatusCode.UNSET:
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
span_context += f" [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
print(span_context)
span_context += f" ({duration_ms:.2f}ms)"
logger.info(span_context)
if self.print_attributes and span.attributes:
for key, value in span.attributes.items():
@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor):
str_value = str(value)
if len(str_value) > 1000:
str_value = str_value[:997] + "..."
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
logger.info(f" [dim]{key}[/dim]: {str_value}")
for event in span.events:
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, dict | list):
if isinstance(message, dict) or isinstance(message, list):
message = json.dumps(message, indent=2)
severity_colors = {
"error": f"{COLORS['bold']}{COLORS['red']}",
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
"info": COLORS["white"],
"debug": COLORS["dim"],
}
msg_color = severity_colors.get(severity, COLORS["white"])
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
severity_color = {
"error": "red",
"warn": "yellow",
"info": "white",
"debug": "dim",
}.get(severity, "white")
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
if event.attributes:
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
logger.info(f"/r[dim]{key}[/dim]: {value}")
def shutdown(self) -> None:
"""Shutdown the processor."""

View file

@ -16,6 +16,6 @@ async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
ChromaVectorIOAdapter,
)
impl = ChromaVectorIOAdapter(config, deps[Api.inference])
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
await impl.initialize()
return impl

View file

@ -6,12 +6,25 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class ChromaVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
@classmethod
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
return {"db_path": db_path}
def sample_run_config(
cls, __distro_dir__: str, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any
) -> dict[str, Any]:
return {
"db_path": db_path,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="chroma_inline_registry.db",
),
}

View file

@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
self.kvstore = kvstore
self.bank_id = bank_id
# A list of chunk ids in the same order as they appear in the index;
# it must be updated whenever chunks are added or removed
self.chunk_id_lock = asyncio.Lock()
self.chunk_ids: list[Any] = []
@classmethod
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
instance = cls(dimension, kvstore, bank_id)
@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
try:
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
except Exception as e:
logger.debug(e, exc_info=True)
raise ValueError(
@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
self.index.add(np.array(embeddings).astype(np.float32))
async with self.chunk_id_lock:
self.index.add(np.array(embeddings).astype(np.float32))
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
# Save updated index
await self._save_index()
async def delete_chunk(self, chunk_id: str) -> None:
if chunk_id not in self.chunk_ids:
return
async with self.chunk_id_lock:
index = self.chunk_ids.index(chunk_id)
self.index.remove_ids(np.array([index]))
new_chunk_by_index = {}
for idx, chunk in self.chunk_by_index.items():
# Shift all chunks after the removed chunk to the left
if idx > index:
new_chunk_by_index[idx - 1] = chunk
else:
new_chunk_by_index[idx] = chunk
self.chunk_by_index = new_chunk_by_index
self.chunk_ids.pop(index)
await self._save_index()
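The positional bookkeeping this relies on, as a standalone sketch (editor's illustration with hypothetical chunk ids): FAISS flat indexes address vectors by position, so removing position i shifts every later chunk down by one.

chunk_ids = ["a", "b", "c", "d"]
chunk_by_index = {0: "A", 1: "B", 2: "C", 3: "D"}

i = chunk_ids.index("b")  # position being deleted
chunk_by_index = {
    (idx - 1 if idx > i else idx): chunk
    for idx, chunk in chunk_by_index.items()
    if idx != i
}
chunk_ids.pop(i)

assert chunk_ids == ["a", "c", "d"]
assert chunk_by_index == {0: "A", 1: "C", 2: "D"}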
async def query_vector(
self,
embedding: NDArray,
@ -261,47 +289,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
return await index.query_chunks(query, params)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file data to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
stored_data = await self.kvstore.get(key)
return json.loads(stored_data) if stored_data else []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store data from kvstore."""
assert self.kvstore is not None
keys_to_delete = [
f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
]
for key in keys_to_delete:
try:
await self.kvstore.delete(key)
except Exception as e:
logger.warning(f"Failed to delete key {key}: {e}")
continue
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a faiss index"""
faiss_index = self.cache[store_id].index
for chunk_id in chunk_ids:
await faiss_index.delete_chunk(chunk_id)

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
import logging
import re
import sqlite3
@ -426,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the SQLite vector store."""
def _delete_chunk():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute("BEGIN TRANSACTION")
# Delete from metadata table
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
# Delete from vector table
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
# Delete from FTS table
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
connection.commit()
except Exception as e:
connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}")
raise
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_chunk)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
@ -506,140 +534,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to SQLite database."""
def _create_or_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
# Create a table to persist OpenAI vector store files.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
store_id TEXT,
file_id TEXT,
metadata TEXT,
PRIMARY KEY (store_id, file_id)
);
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
store_id TEXT,
file_id TEXT,
contents TEXT,
PRIMARY KEY (store_id, file_id)
);
""")
connection.commit()
cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_info)),
)
cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files_contents (store_id, file_id, contents) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_contents)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_create_or_store)
except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"SELECT metadata FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
row = cur.fetchone()
if row is None:
return None
(metadata,) = row
return metadata
finally:
cur.close()
connection.close()
stored_data = await asyncio.to_thread(_load)
return json.loads(stored_data) if stored_data else {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
row = cur.fetchone()
if row is None:
return None
(contents,) = row
return contents
finally:
cur.close()
connection.close()
stored_contents = await asyncio.to_thread(_load)
return json.loads(stored_contents) if stored_contents else []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_store_files SET metadata = ? WHERE store_id = ? AND file_id = ?",
(json.dumps(file_info), store_id, file_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
)
cur.execute(
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
@ -655,3 +549,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -224,17 +224,6 @@ def available_providers() -> list[ProviderSpec]:
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="fireworks-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.fireworks_openai_compat",
config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
description="Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
@ -246,50 +235,6 @@ def available_providers() -> list[ProviderSpec]:
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="together-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.together_openai_compat",
config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
description="Together AI OpenAI-compatible provider for using Together models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.groq_openai_compat",
config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
description="Groq OpenAI-compatible provider for using Groq models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.sambanova_openai_compat",
config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
description="SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.cerebras_openai_compat",
config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
description="Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format.",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval.
To use PGVector in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Faiss.
2. Configure your Llama Stack project to use pgvector (e.g. the remote::pgvector provider).
3. Start storing and querying vectors.
## Installation
@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
remote_provider_spec(
Api.vector_io,

View file

@ -15,6 +15,7 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)

View file

@ -26,7 +26,7 @@ class AnthropicConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -10,9 +10,9 @@ from llama_stack.providers.utils.inference.model_registry import (
)
LLM_MODEL_IDS = [
"anthropic/claude-3-5-sonnet-latest",
"anthropic/claude-3-7-sonnet-latest",
"anthropic/claude-3-5-haiku-latest",
"claude-3-5-sonnet-latest",
"claude-3-7-sonnet-latest",
"claude-3-5-haiku-latest",
]
SAFETY_MODELS_ENTRIES = []
@ -21,17 +21,17 @@ MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="anthropic/voyage-3",
provider_model_id="voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-3-lite",
provider_model_id="voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-code-3",
provider_model_id="voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),

View file

@ -63,18 +63,20 @@ class BedrockInferenceAdapter(
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self._config = config
self._client = create_bedrock_client(config)
self._client = None
@property
def client(self) -> BaseClient:
if self._client is None:
self._client = create_bedrock_client(self._config)
return self._client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
self.client.close()
if self._client is not None:
self._client.close()
async def completion(
self,

View file

@ -65,6 +65,7 @@ class CerebrasInferenceAdapter(
)
self.config = config
# TODO: make this use provider data, etc. like other providers
self.client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),

View file

@ -26,7 +26,7 @@ class CerebrasImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"base_url": DEFAULT_BASE_URL,
"api_key": api_key,

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import CerebrasCompatConfig
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .cerebras import CerebrasCompatInferenceAdapter
adapter = CerebrasCompatInferenceAdapter(config)
return adapter

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..cerebras.models import MODEL_ENTRIES
class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: CerebrasCompatConfig
def __init__(self, config: CerebrasCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="cerebras_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class CerebrasProviderDataValidator(BaseModel):
cerebras_api_key: str | None = Field(
default=None,
description="API key for Cerebras models",
)
@json_schema_type
class CerebrasCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Cerebras API key",
)
openai_compat_api_base: str = Field(
default="https://api.cerebras.ai/v1",
description="The URL for the Cerebras API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.cerebras.ai/v1",
"api_key": api_key,
}

View file

@ -25,8 +25,8 @@ class DatabricksImplConfig(BaseModel):
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL}",
api_token: str = "${env.DATABRICKS_API_TOKEN}",
url: str = "${env.DATABRICKS_URL:=}",
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {

View file

@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class FireworksImplConfig(BaseModel):
class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
@ -23,7 +24,7 @@ class FireworksImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,

View file

@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
async def initialize(self) -> None:

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import FireworksCompatConfig
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .fireworks import FireworksCompatInferenceAdapter
adapter = FireworksCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class FireworksProviderDataValidator(BaseModel):
fireworks_api_key: str | None = Field(
default=None,
description="API key for Fireworks models",
)
@json_schema_type
class FireworksCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Fireworks API key",
)
openai_compat_api_base: str = Field(
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..fireworks.models import MODEL_ENTRIES
class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: FireworksCompatConfig
def __init__(self, config: FireworksCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="fireworks_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -26,7 +26,7 @@ class GeminiConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"api_key": api_key,
}

View file

@ -15,6 +15,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="gemini",
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)

View file

@ -10,11 +10,11 @@ from llama_stack.providers.utils.inference.model_registry import (
)
LLM_MODEL_IDS = [
"gemini/gemini-1.5-flash",
"gemini/gemini-1.5-pro",
"gemini/gemini-2.0-flash",
"gemini/gemini-2.5-flash",
"gemini/gemini-2.5-pro",
"gemini-1.5-flash",
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.5-flash",
"gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = []
@ -23,7 +23,7 @@ MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="gemini/text-embedding-004",
provider_model_id="text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),

View file

@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.groq.com",
"api_key": api_key,

View file

@ -34,6 +34,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="groq",
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
)
@ -96,7 +97,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
tool_choice = "required"
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id.replace("groq/", ""),
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,

View file

@ -14,19 +14,19 @@ SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"groq/llama3-8b-8192",
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"groq/llama-3.1-8b-instant",
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama3-70b-8192",
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-3.3-70b-versatile",
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
@ -34,23 +34,15 @@ MODEL_ENTRIES = [
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_hf_repo_model_entry(
"groq/llama-3.2-3b-preview",
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-4-scout-17b-16e-instruct",
"meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
"meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import GroqCompatConfig
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .groq import GroqCompatInferenceAdapter
adapter = GroqCompatInferenceAdapter(config)
return adapter

View file

@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class GroqProviderDataValidator(BaseModel):
groq_api_key: str | None = Field(
default=None,
description="API key for Groq models",
)
@json_schema_type
class GroqCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Groq API key",
)
openai_compat_api_base: str = Field(
default="https://api.groq.com/openai/v1",
description="The URL for the Groq API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.groq.com/openai/v1",
"api_key": api_key,
}

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..groq.models import MODEL_ENTRIES
class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: GroqCompatConfig
def __init__(self, config: GroqCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -5,55 +5,53 @@
# the root directory of this source tree.
import logging
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
Llama API Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="meta_llama",
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def check_model_availability(self, model: str) -> bool:
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Check if a specific model is available from Llama API.
Get the base URL for OpenAI mixin.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
:return: The Llama API base URL
"""
try:
llama_api_client = self._get_llama_api_client()
retrieved_model = await llama_api_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from Llama API")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from Llama API")
return False
except Exception as e:
logger.error(f"Failed to check model availability from Llama API: {e}")
return False
return self.config.openai_compat_api_base
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
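The docstring's point about inheritance order comes down to Python's MRO; a minimal illustration (editor's sketch, unrelated to the Llama API client itself):

class A:
    def check(self) -> str:
        return "A"

class B:
    def check(self) -> str:
        return "B"

class C(A, B):  # A listed first, so A.check wins, mirroring OpenAIMixin before LiteLLMOpenAIMixin
    pass

assert C().check() == "A"
print([k.__name__ for k in C.__mro__])  # ['C', 'A', 'B', 'object']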

View file

@ -7,9 +7,8 @@
import logging
import warnings
from collections.abc import AsyncIterator
from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
from openai import APIConnectionError, BadRequestError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -28,12 +27,6 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig
@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
"""
NVIDIA Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability(). It also
must come before Inference to ensure that OpenAIMixin methods are available
in the Inference interface.
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self._config = config
async def check_model_availability(self, model: str) -> bool:
def get_api_key(self) -> str:
"""
Check if a specific model is available.
Get the API key for OpenAI mixin.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
:return: The NVIDIA API key
"""
try:
await self._client.models.retrieve(model)
return True
except NotFoundError:
logger.error(f"Model {model} is not available")
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
@property
def _client(self) -> AsyncOpenAI:
def get_base_url(self) -> str:
"""
Returns an OpenAI client for the configured NVIDIA API endpoint.
Get the base URL for OpenAI mixin.
:return: An OpenAI client
:return: The NVIDIA API base URL
"""
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
raise RuntimeError("Model store is not set")
model = await self.model_store.get_model(model_id)
if model is None:
raise ValueError(f"Model {model_id} is unknown")
return model.provider_model_id
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def completion(
self,
@@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._client.completions.create(**request)
response = await self.client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._client.embeddings.create(
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
@@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._client.chat.completions.create(**request)
response = await self.client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
try:
return await self._client.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
try:
return await self._client.chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
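
The adapter above now only has to answer get_api_key() and get_base_url(); the shared mixin is what turns those two hooks into a configured client. A rough sketch of that contract, assuming only that the openai package is installed — this is not the real OpenAIMixin, just the shape the adapter appears to rely on:

from openai import AsyncOpenAI


class ClientHooksSketch:
    """Hypothetical base showing the two hooks an adapter supplies."""

    def get_api_key(self) -> str:
        raise NotImplementedError

    def get_base_url(self) -> str:
        raise NotImplementedError

    @property
    def client(self) -> AsyncOpenAI:
        # The mixin owns client construction; adapters only provide credentials and endpoint.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())


class ExampleAdapter(ClientHooksSketch):
    def get_api_key(self) -> str:
        return "EXAMPLE-KEY"  # placeholder value, for illustration only

    def get_base_url(self) -> str:
        return "https://example.invalid/v1"  # placeholder endpoint, for illustration only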


@@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:


@@ -96,14 +96,16 @@ class OllamaInferenceAdapter(
def __init__(self, config: OllamaImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
self._client = None
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
self._openai_client = None
@property
def client(self) -> AsyncClient:
if self._client is None:
self._client = AsyncClient(host=self.config.url)
return self._client
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncClient(host=self.config.url)
return self._clients[loop]
@property
def openai_client(self) -> AsyncOpenAI:
@@ -119,59 +121,61 @@ class OllamaInferenceAdapter(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
)
if self.config.refresh_models:
logger.debug("ollama starting background model refresh task")
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
if task.cancelled():
import traceback
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
logger.error(f"ollama background refresh task died: {task.exception()}")
else:
logger.error("ollama background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
# Wait for model store to be available (with timeout)
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
while True:
try:
response = await self.client.list()
except Exception as e:
logger.warning(f"Failed to list models: {str(e)}")
await asyncio.sleep(self.config.refresh_models_interval)
response = await self.client.list()
# always add the two embedding models which can be pulled on demand
models = [
Model(
identifier="all-minilm:l6-v2",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
# add all-minilm alias
Model(
identifier="all-minilm",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
Model(
identifier="nomic-embed-text",
provider_resource_id="nomic-embed-text",
provider_id=provider_id,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
model_type=ModelType.embedding,
),
]
for m in response.models:
# kill embedding models since we don't know dimensions for them
if "bert" in m.details.family:
continue
models = []
for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
if model_type == ModelType.embedding:
continue
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=ModelType.llm,
)
await self.model_store.update_registered_llm_models(provider_id, models)
logger.debug(f"ollama refreshed model list ({len(models)} models)")
await asyncio.sleep(self.config.refresh_models_interval)
)
return models
async def health(self) -> HealthResponse:
"""
@@ -223,12 +227,7 @@ class OllamaInferenceAdapter(
return available_models
async def shutdown(self) -> None:
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
logger.debug("ollama cancelling background refresh task")
self._refresh_task.cancel()
self._client = None
self._openai_client = None
self._clients.clear()
async def unregister_model(self, model_id: str) -> None:
pass
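
The per-event-loop cache above exists because the Ollama AsyncClient binds itself to the loop it was created on, so a client built under one loop must not be reused under another. A generic, self-contained version of the pattern (FakeClient stands in for ollama.AsyncClient so the sketch runs without Ollama installed):

import asyncio


class FakeClient:
    def __init__(self, host: str) -> None:
        self.host = host


class LoopScopedClients:
    def __init__(self, host: str) -> None:
        self._host = host
        self._clients: dict[asyncio.AbstractEventLoop, FakeClient] = {}

    @property
    def client(self) -> FakeClient:
        # Key the cache by the running loop; each loop gets its own client.
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = FakeClient(self._host)
        return self._clients[loop]


async def main() -> None:
    holder = LoopScopedClients("http://localhost:11434")
    print(holder.client is holder.client)  # True: cached for this loop


asyncio.run(main())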


@@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
default=None,
description="API key for OpenAI models",
)
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(
cls,
api_key: str = "${env.OPENAI_API_KEY:=}",
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"base_url": base_url,
}
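
The ${env.NAME:=default} placeholders used here (and in the SambaNova, TGI and Together samples below) are resolved when the run config is loaded, with := supplying a default, possibly empty, value so a stack can start without the variable set. A stand-in resolver written only to show how the syntax reads — the real substitution logic lives elsewhere in the stack and is not part of this diff:

import os
import re

_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z0-9_]+)(?::=([^}]*))?\}")


def resolve_env_placeholders(value: str) -> str:
    # "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}" -> env value if set, else the default.
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default if default is not None else "")

    return _PLACEHOLDER.sub(_sub, value)


print(resolve_env_placeholders("${env.OPENAI_BASE_URL:=https://api.openai.com/v1}"))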


@@ -12,11 +12,6 @@ from llama_stack.providers.utils.inference.model_registry import (
)
LLM_MODEL_IDS = [
# the models w/ "openai/" prefix are the litellm specific model names.
# they should be deprecated in favor of the canonical openai model names.
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openai/chatgpt-4o-latest",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
@@ -43,8 +38,6 @@ class EmbeddingModelInfo:
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}


@@ -5,23 +5,9 @@
# the root directory of this source tree.
import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
@@ -30,7 +16,7 @@ logger = logging.getLogger(__name__)
#
# This OpenAI adapter implements Inference methods using two clients -
# This OpenAI adapter implements Inference methods using two mixins -
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
@@ -39,15 +25,27 @@ logger = logging.getLogger(__name__)
# | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI |
# | openai_chat_completion | AsyncOpenAI |
# | openai_embeddings | AsyncOpenAI |
# | openai_completion | OpenAIMixin |
# | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | OpenAIMixin |
#
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="openai",
api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key",
)
@@ -60,191 +58,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
async def check_model_availability(self, model: str) -> bool:
# Delegate get_api_key to LiteLLMOpenAIMixin, which resolves client-supplied keys from provider data and falls back to the configured key
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Check if a specific model is available from OpenAI.
Get the OpenAI API base URL.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
Returns the OpenAI API base URL from the configuration.
"""
try:
openai_client = self._get_openai_client()
retrieved_model = await openai_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from OpenAI")
return False
except Exception as e:
logger.error(f"Failed to check model availability from OpenAI: {e}")
return False
return self.config.base_url
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._get_openai_client().chat.completions.create(**params)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
# Prepare parameters for OpenAI embeddings API
params = {
"model": model_id,
"input": input,
}
if encoding_format is not None:
params["encoding_format"] = encoding_format
if dimensions is not None:
params["dimensions"] = dimensions
if user is not None:
params["user"] = user
# Call OpenAI embeddings API
response = await self._get_openai_client().embeddings.create(**params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)


@@ -30,7 +30,7 @@ class SambaNovaImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,


@@ -9,49 +9,20 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-8B-Instruct",
"Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-405B-Instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.2-1B-Instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.2-3B-Instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.3-70B-Instruct",
"Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-3.2-11B-Vision-Instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-3.2-90B-Vision-Instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
"Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES


@@ -182,6 +182,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="sambanova",
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key",
)


@@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import SambaNovaCompatConfig
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .sambanova import SambaNovaCompatInferenceAdapter
adapter = SambaNovaCompatInferenceAdapter(config)
return adapter


@@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="API key for SambaNova models",
)
@json_schema_type
class SambaNovaCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The SambaNova API key",
)
openai_compat_api_base: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.sambanova.ai/v1",
"api_key": api_key,
}


@@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.sambanova_openai_compat.config import SambaNovaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..sambanova.models import MODEL_ENTRIES
class SambaNovaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: SambaNovaCompatConfig
def __init__(self, config: SambaNovaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="sambanova_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()


@@ -19,7 +19,7 @@ class TGIImplConfig(BaseModel):
@classmethod
def sample_run_config(
cls,
url: str = "${env.TGI_URL}",
url: str = "${env.TGI_URL:=}",
**kwargs,
):
return {


@@ -305,6 +305,8 @@ class _HfAdapter(
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(
model=config.url,


@@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TogetherImplConfig(BaseModel):
class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
@@ -26,5 +27,5 @@ class TogetherImplConfig(BaseModel):
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY}",
"api_key": "${env.TOGETHER_API_KEY:=}",
}


@@ -69,15 +69,9 @@ MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
additional_aliases=[
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
],
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
additional_aliases=[
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
],
),
] + SAFETY_MODELS_ENTRIES


@@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
async def initialize(self) -> None:
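
Passing config.allowed_models through to the registry helper reads as an allow-list: when the field is set, only the listed models should be surfaced for this provider. A small standalone illustration of that idea, under that assumption — this is not the ModelRegistryHelper code itself:

def filter_allowed(model_ids: list[str], allowed_models: list[str] | None) -> list[str]:
    # None means "no restriction"; an explicit list restricts what gets registered.
    if allowed_models is None:
        return model_ids
    allowed = set(allowed_models)
    return [m for m in model_ids if m in allowed]


print(filter_allowed(
    ["meta-llama/Llama-4-Scout-17B-16E-Instruct", "meta-llama/Llama-Guard-3-8B"],
    ["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
))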


@@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import TogetherCompatConfig
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
# import dynamically so the import is used only when it is needed
from .together import TogetherCompatInferenceAdapter
adapter = TogetherCompatInferenceAdapter(config)
return adapter


@@ -1,38 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class TogetherProviderDataValidator(BaseModel):
together_api_key: str | None = Field(
default=None,
description="API key for Together models",
)
@json_schema_type
class TogetherCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Together API key",
)
openai_compat_api_base: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"openai_compat_api_base": "https://api.together.xyz/v1",
"api_key": api_key,
}


@@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.together_openai_compat.config import TogetherCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from ..together.models import MODEL_ENTRIES
class TogetherCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: TogetherCompatConfig
def __init__(self, config: TogetherCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="together_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()


@@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=False,
description="Whether to refresh models periodically",
)
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify")
@classmethod


@@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@@ -302,64 +300,32 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
if not self.config.url:
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
raise ValueError(
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)
if self.config.refresh_models:
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
self._lazy_initialize_client()
assert self.client is not None # mypy
while True:
try:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=self.__provider_id__,
metadata={},
model_type=model_type,
)
)
return models
async def shutdown(self) -> None:
if self._refresh_task:
self._refresh_task.cancel()
self._refresh_task = None
pass
async def unregister_model(self, model_id: str) -> None:
pass
@@ -374,9 +340,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
HealthResponse: A dictionary containing the health status.
"""
try:
if not self.config.url:
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
@@ -392,11 +355,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if self.client is not None:
return
if not self.config.url:
raise ValueError(
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
)
log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()
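
With the adapter-local refresh task removed here (and in the Ollama adapter earlier), providers only report should_refresh_models() and answer list_models(); the periodic registration presumably happens in a central loop that is not part of this diff. A hedged sketch of what such a caller might look like, with the adapter reduced to a stub:

import asyncio


class StubAdapter:
    # Stand-in for an inference adapter after this refactor.
    async def should_refresh_models(self) -> bool:
        return True

    async def list_models(self) -> list[str] | None:
        return ["served-model-a", "served-model-b"]  # illustrative identifiers only


async def refresh_once(adapter: StubAdapter) -> None:
    if not await adapter.should_refresh_models():
        return
    models = await adapter.list_models()
    if models is not None:
        # In the real stack this would update the model registry for the provider.
        print(f"registering {len(models)} models")


asyncio.run(refresh_once(StubAdapter()))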


@@ -30,7 +30,7 @@ class SambaNovaSafetyConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,
