Merge branch 'main' into add-mcp-streamable-http-support

Calum Murray 2025-07-18 14:38:54 -04:00 committed by GitHub
commit c715f30e65
247 changed files with 9685 additions and 5249 deletions

@ -19,8 +19,10 @@ class PostTrainingMetric(BaseModel):
perplexity: float
@json_schema_type(schema={"description": "Checkpoint created during training runs"})
@json_schema_type
class Checkpoint(BaseModel):
"""Checkpoint created during training runs"""
identifier: str
created_at: datetime
epoch: int

@ -7,7 +7,7 @@
from enum import StrEnum
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -36,13 +36,21 @@ class Model(CommonModelFields, Resource):
return self.identifier
@property
def provider_model_id(self) -> str | None:
def provider_model_id(self) -> str:
assert self.provider_resource_id is not None, "Provider resource ID must be set"
return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm)
@field_validator("provider_resource_id")
@classmethod
def validate_provider_resource_id(cls, v):
if v is None:
raise ValueError("provider_resource_id cannot be None")
return v
class ModelInput(CommonModelFields):
model_id: str
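As an illustration of the tightened Model contract above (not part of the diff): provider_resource_id is now rejected at construction time, so provider_model_id can safely return a plain str. Field names other than provider_resource_id are assumptions about the Resource base class.

from pydantic import ValidationError

try:
    # Hypothetical construction; identifier/provider_id come from the Resource base class.
    Model(identifier="my-model", provider_id="ollama", provider_resource_id=None)
except ValidationError as err:
    # The new field_validator raises "provider_resource_id cannot be None",
    # which pydantic surfaces as a ValidationError.
    print(err)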

@ -87,6 +87,20 @@ class RAGQueryGenerator(Enum):
custom = "custom"
@json_schema_type
class RAGSearchMode(Enum):
"""
Search modes for RAG query retrieval:
- VECTOR: Uses vector similarity search for semantic matching
- KEYWORD: Uses keyword-based search for exact matching
- HYBRID: Combines both vector and keyword search for better results
"""
VECTOR = "vector"
KEYWORD = "keyword"
HYBRID = "hybrid"
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
type: Literal["default"] = "default"
@ -128,7 +142,7 @@ class RAGQueryConfig(BaseModel):
max_tokens_in_context: int = 4096
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: str | None = None
mode: RAGSearchMode | None = RAGSearchMode.VECTOR
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
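For illustration (not part of the diff), the query mode is now an enum rather than a free-form string, and the default moves from None to vector search. A minimal sketch, with the import path assumed:

# Import path assumed; RAGQueryConfig and RAGSearchMode live in the RAG tool API module.
from llama_stack.apis.tools.rag_tool import RAGQueryConfig, RAGSearchMode

default_config = RAGQueryConfig()  # mode now defaults to RAGSearchMode.VECTOR
hybrid_config = RAGQueryConfig(mode=RAGSearchMode.HYBRID, max_chunks=5)  # ranker applies only in hybrid mode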

@ -19,6 +19,7 @@ class VectorDB(Resource):
embedding_model: str
embedding_dimension: int
vector_db_name: str | None = None
@property
def vector_db_id(self) -> str:
@ -70,6 +71,7 @@ class VectorDBs(Protocol):
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
"""Register a vector database.
@ -78,6 +80,7 @@ class VectorDBs(Protocol):
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param vector_db_name: The name of the vector database.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""

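A hedged usage sketch of the extended protocol method (not part of the diff; `client` stands in for any object implementing the VectorDBs protocol, and the embedding model name is a placeholder):

# Inside an async function.
vector_db = await client.register_vector_db(
    vector_db_id="my-vector-db",
    embedding_model="all-MiniLM-L6-v2",   # placeholder model name
    embedding_dimension=384,
    vector_db_name="product-docs",        # new optional display name added by this diff
)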

@ -346,7 +346,6 @@ class VectorIO(Protocol):
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store.
@ -358,7 +357,6 @@ class VectorIO(Protocol):
:param embedding_model: The embedding model to use for this vector store.
:param embedding_dimension: The dimension of the embedding vectors (default: 384).
:param provider_id: The ID of the provider to use for this vector store.
:param provider_vector_db_id: The provider-specific vector database ID.
:returns: A VectorStoreObject representing the created vector store.
"""
...

@ -93,7 +93,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
elif args.providers:
providers = dict()
providers_list: dict[str, str | list[str]] = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
@ -112,7 +112,15 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
if provider in providers_for_api:
providers.setdefault(api, []).append(provider)
if api not in providers_list:
providers_list[api] = []
# Use type guarding to ensure we have a list
provider_value = providers_list[api]
if isinstance(provider_value, list):
provider_value.append(provider)
else:
# Convert string to list and append
providers_list[api] = [provider_value, provider]
else:
cprint(
f"{provider} is not a valid provider for the {api} API.",
@ -121,7 +129,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
distribution_spec = DistributionSpec(
providers=providers,
providers=providers_list,
description=",".join(args.providers),
)
if not args.image_type:
@ -182,7 +190,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
providers = dict()
providers: dict[str, str | list[str]] = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
if not available_providers:
@ -371,10 +379,16 @@ def _run_stack_build_command_from_build_config(
if not image_name:
raise ValueError("Please specify an image name when building a venv image")
# At this point, image_name should be guaranteed to be a string
if image_name is None:
raise ValueError("image_name should not be None after validation")
if template_name:
build_dir = DISTRIBS_BASE_DIR / template_name
build_file_path = build_dir / f"{template_name}-build.yaml"
else:
if image_name is None:
raise ValueError("image_name cannot be None")
build_dir = DISTRIBS_BASE_DIR / image_name
build_file_path = build_dir / f"{image_name}-build.yaml"
@ -395,7 +409,7 @@ def _run_stack_build_command_from_build_config(
build_file_path,
image_name,
template_or_config=template_name or config_path or str(build_file_path),
run_config=run_config_file,
run_config=run_config_file.as_posix() if run_config_file else None,
)
if return_code != 0:
raise RuntimeError(f"Failed to build image {image_name}")
@ -403,15 +417,16 @@ def _run_stack_build_command_from_build_config(
if template_name:
# copy run.yaml from template to build_dir instead of generating it again
template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml"
run_config_file = build_dir / f"{template_name}-run.yaml"
with importlib.resources.as_file(template_path) as path:
run_config_file = build_dir / f"{template_name}-run.yaml"
shutil.copy(path, run_config_file)
cprint("Build Successful!", color="green", file=sys.stderr)
cprint(f"You can find the newly-built template here: {template_path}", color="blue", file=sys.stderr)
cprint(f"You can find the newly-built template here: {run_config_file}", color="blue", file=sys.stderr)
cprint(
"You can run the new Llama Stack distro via: "
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "blue"),
+ colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"),
color="green",
file=sys.stderr,
)

@ -47,8 +47,7 @@ class StackRun(Subcommand):
self.parser.add_argument(
"--image-name",
type=str,
default=os.environ.get("CONDA_DEFAULT_ENV"),
help="Name of the image to run. Defaults to the current environment",
help="Name of the image to run.",
)
self.parser.add_argument(
"--env",
@ -83,46 +82,57 @@ class StackRun(Subcommand):
return ImageType.CONDA.value, args.image_name
return args.image_type, args.image_name
def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
"""Resolve config file path and template name from args.config"""
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
if not args.config:
return None, None
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
return config_file, template_name
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
if args.enable_ui:
self._start_ui_development_server(args.port)
image_type, image_name = self._get_image_type_and_name(args)
# Resolve config file and template name first
config_file, template_name = self._resolve_config_and_template(args)
# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
self.parser.error("Config file is required for venv and conda environments")
if args.config:
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
template_name = None
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
if not config_file.exists():
self.parser.error(
f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
)
if not config_file.is_file():
self.parser.error(
f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
)
if config_file:
logger.info(f"Using run configuration: {config_file}")
try:
@ -138,8 +148,6 @@ class StackRun(Subcommand):
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
else:
config = None
config_file = None
template_name = None
# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
@ -155,8 +163,12 @@ class StackRun(Subcommand):
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
if callable(getattr(args, arg)):
continue
if arg == "config" and template_name:
server_args.config = str(config_file)
if arg == "config":
if template_name:
server_args.template = str(template_name)
else:
# Set the config file path
server_args.config = str(config_file)
else:
setattr(server_args, arg, getattr(args, arg))

@ -81,7 +81,7 @@ def is_action_allowed(
if not len(policy):
policy = default_policy()
qualified_resource_id = resource.type + "::" + resource.identifier
qualified_resource_id = f"{resource.type}::{resource.identifier}"
for rule in policy:
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
if rule.when:

@ -96,7 +96,7 @@ FROM $container_base
WORKDIR /app
# We install the Python 3.12 dev headers and build tools so that any
# Cextension wheels (e.g. polyleven, faisscpu) can compile successfully.
# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
RUN dnf -y update && dnf install -y iputils git net-tools wget \
vim-minimal python3.12 python3.12-pip python3.12-wheel \
@ -169,7 +169,7 @@ if [ -n "$run_config" ]; then
echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF
COPY --chmod=g+w providers.d /.llama/providers.d
COPY providers.d /.llama/providers.d
EOF
fi

@ -17,7 +17,7 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
@ -164,7 +164,8 @@ def upgrade_from_routing_table(
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**replace_env_vars(config_dict))
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
@ -175,4 +176,5 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
return StackRunConfig(**replace_env_vars(config_dict))
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

@ -6,9 +6,9 @@
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated, Any, Literal, Self
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
@ -161,23 +161,113 @@ class LoggingConfig(BaseModel):
)
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class AuthProviderType(StrEnum):
"""Supported authentication provider types."""
OAUTH2_TOKEN = "oauth2_token"
GITHUB_TOKEN = "github_token"
CUSTOM = "custom"
class OAuth2TokenAuthConfig(BaseModel):
"""Configuration for OAuth2 token authentication."""
type: Literal[AuthProviderType.OAUTH2_TOKEN] = AuthProviderType.OAUTH2_TOKEN
audience: str = Field(default="llama-stack")
verify_tls: bool = Field(default=True)
tls_cafile: Path | None = Field(default=None)
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None = Field(default=None, description="JWKS configuration")
introspection: OAuth2IntrospectionConfig | None = Field(
default=None, description="OAuth2 introspection configuration"
)
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
class CustomAuthConfig(BaseModel):
"""Configuration for custom authentication."""
type: Literal[AuthProviderType.CUSTOM] = AuthProviderType.CUSTOM
endpoint: str = Field(
...,
description="Custom authentication endpoint URL",
)
class GitHubTokenAuthConfig(BaseModel):
"""Configuration for GitHub token authentication."""
type: Literal[AuthProviderType.GITHUB_TOKEN] = AuthProviderType.GITHUB_TOKEN
github_api_base_url: str = Field(
default="https://api.github.com",
description="Base URL for GitHub API (use https://api.github.com for public GitHub)",
)
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"login": "roles",
"organizations": "teams",
},
description="Mapping from GitHub user fields to access attributes",
)
AuthProviderConfig = Annotated[
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
Field(discriminator="type"),
]
class AuthenticationConfig(BaseModel):
provider_type: AuthProviderType = Field(
"""Top-level authentication configuration."""
provider_config: AuthProviderConfig = Field(
...,
description="Type of authentication provider",
description="Authentication provider configuration",
)
config: dict[str, Any] = Field(
...,
description="Provider-specific configuration",
access_policy: list[AccessRule] = Field(
default=[],
description="Rules for determining access to resources",
)
access_policy: list[AccessRule] = Field(default=[], description="Rules for determining access to resources")
class AuthenticationRequiredError(Exception):
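To make the new shape concrete (illustrative, not from the diff): server.auth now carries a typed, discriminated provider_config instead of a provider_type plus an untyped config dict. A minimal sketch using the classes defined above; URLs are placeholders:

# GitHub token auth: every field has a default, so no arguments are required.
github_auth = AuthenticationConfig(provider_config=GitHubTokenAuthConfig())

# OAuth2 token auth: exactly one of jwks or introspection must be set
# (enforced by the model_validator above).
oauth_auth = AuthenticationConfig(
    provider_config=OAuth2TokenAuthConfig(
        issuer="https://issuer.example.com",
        jwks=OAuth2JWKSConfig(uri="https://issuer.example.com/keys"),
    ),
)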

@ -200,7 +200,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import uuid
from typing import Any
from llama_stack.apis.common.content_types import (
@ -81,6 +82,7 @@ class VectorIORouter(VectorIO):
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
@ -89,6 +91,7 @@ class VectorIORouter(VectorIO):
embedding_model,
embedding_dimension,
provider_id,
vector_db_name,
provider_vector_db_id,
)
@ -123,7 +126,6 @@ class VectorIORouter(VectorIO):
embedding_model: str | None = None,
embedding_dimension: int | None = None,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
@ -135,17 +137,17 @@ class VectorIORouter(VectorIO):
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
vector_db_id = name
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_vector_db_id,
vector_db_id=vector_db_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=provider_id,
provider_vector_db_id=vector_db_id,
vector_db_name=name,
)
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
vector_db_id,
name=name,
file_ids=file_ids,
expires_after=expires_after,
chunking_strategy=chunking_strategy,

@ -36,6 +36,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
@ -62,6 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)

@ -87,8 +87,12 @@ class AuthenticationMiddleware:
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header or not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
if not auth_header:
error_msg = self.auth_provider.get_auth_error_message(scope)
return await self._send_auth_error(send, error_msg)
if not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Invalid Authorization header format")
token = auth_header.split("Bearer ", 1)[1]

@ -8,15 +8,19 @@ import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from pathlib import Path
from typing import Self
from urllib.parse import parse_qs
from urllib.parse import parse_qs, urlparse
import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, User
from llama_stack.distribution.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
GitHubTokenAuthConfig,
OAuth2TokenAuthConfig,
User,
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -38,9 +42,7 @@ class AuthRequestContext(BaseModel):
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: dict[str, list[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")
class AuthRequest(BaseModel):
@ -62,6 +64,10 @@ class AuthProvider(ABC):
"""Clean up any resources."""
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return provider-specific authentication error message."""
return "Authentication required"
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
attributes: dict[str, list[str]] = {}
@ -81,56 +87,6 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
return attributes
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class OAuth2TokenAuthProviderConfig(BaseModel):
audience: str = "llama-stack"
verify_tls: bool = True
tls_cafile: Path | None = None
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None
introspection: OAuth2IntrospectionConfig | None = None
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
@ -138,7 +94,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: OAuth2TokenAuthProviderConfig):
def __init__(self, config: OAuth2TokenAuthConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
@ -170,7 +126,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
issuer=self.config.issuer,
)
except Exception as exc:
raise ValueError(f"Invalid JWT token: {token}") from exc
raise ValueError("Invalid JWT token") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
@ -232,6 +188,17 @@ class OAuth2TokenAuthProvider(AuthProvider):
async def close(self):
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return OAuth2-specific authentication error message."""
if self.config.issuer:
return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
elif self.config.introspection:
# Extract domain from introspection URL for a cleaner message
domain = urlparse(self.config.introspection.url).netloc
return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
else:
return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
@ -264,14 +231,10 @@ class OAuth2TokenAuthProvider(AuthProvider):
self._jwks_at = time.time()
class CustomAuthProviderConfig(BaseModel):
endpoint: str
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: CustomAuthProviderConfig):
def __init__(self, config: CustomAuthConfig):
self.config = config
self._client = None
@ -317,7 +280,7 @@ class CustomAuthProvider(AuthProvider):
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
return User(auth_response.principal, auth_response.attributes)
return User(principal=auth_response.principal, attributes=auth_response.attributes)
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
@ -338,15 +301,88 @@ class CustomAuthProvider(AuthProvider):
await self._client.aclose()
self._client = None
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return custom auth provider-specific authentication error message."""
domain = urlparse(self.config.endpoint).netloc
if domain:
return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
else:
return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
class GitHubTokenAuthProvider(AuthProvider):
"""
GitHub token authentication provider that validates GitHub access tokens directly.
This provider accepts GitHub personal access tokens or OAuth tokens and verifies
them against the GitHub API to get user information.
"""
def __init__(self, config: GitHubTokenAuthConfig):
self.config = config
async def validate_token(self, token: str, scope: dict | None = None) -> User:
"""Validate a GitHub token by calling the GitHub API.
This validates tokens issued by GitHub (personal access tokens or OAuth tokens).
"""
try:
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
except httpx.HTTPStatusError as e:
logger.warning(f"GitHub token validation failed: {e}")
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
principal = user_info["user"]["login"]
github_data = {
"login": user_info["user"]["login"],
"id": str(user_info["user"]["id"]),
"organizations": user_info.get("organizations", []),
}
access_attributes = get_attributes_from_claims(github_data, self.config.claims_mapping)
return User(
principal=principal,
attributes=access_attributes,
)
async def close(self):
"""Clean up any resources."""
pass
def get_auth_error_message(self, scope: dict | None = None) -> str:
"""Return GitHub-specific authentication error message."""
return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
"""Fetch user info and organizations from GitHub API."""
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "application/vnd.github.v3+json",
"User-Agent": "llama-stack",
}
async with httpx.AsyncClient() as client:
user_response = await client.get(f"{github_api_base_url}/user", headers=headers, timeout=10.0)
user_response.raise_for_status()
user_data = user_response.json()
return {
"user": user_data,
}
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
provider_config = config.provider_config
if provider_type == "custom":
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
if isinstance(provider_config, CustomAuthConfig):
return CustomAuthProvider(provider_config)
elif isinstance(provider_config, OAuth2TokenAuthConfig):
return OAuth2TokenAuthProvider(provider_config)
elif isinstance(provider_config, GitHubTokenAuthConfig):
return GitHubTokenAuthProvider(provider_config)
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
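An illustrative call into the rewritten factory (not part of the diff); dispatch is now on the config class rather than a lowercased provider_type string:

provider = create_auth_provider(AuthenticationConfig(provider_config=GitHubTokenAuthConfig()))
assert isinstance(provider, GitHubTokenAuthProvider)

# Inside an async context; the token value is a placeholder.
user = await provider.validate_token("<github-personal-access-token>")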

@ -33,7 +33,11 @@ from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.resolver import InvalidProviderError
@ -43,6 +47,7 @@ from llama_stack.distribution.server.routes import (
initialize_route_impls,
)
from llama_stack.distribution.stack import (
cast_image_name_to_string,
construct_stack,
replace_env_vars,
validate_env_pair,
@ -217,7 +222,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal, user_attributes)
user = User(principal=principal, attributes=user_attributes)
await log_request_pre_validation(request)
@ -405,13 +410,13 @@ def main(args: argparse.Namespace | None = None):
args = parser.parse_args()
log_line = ""
if args.config:
if hasattr(args, "config") and args.config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
log_line = f"Using config file: {config_file}"
elif args.template:
elif hasattr(args, "template") and args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
@ -435,14 +440,12 @@ def main(args: argparse.Namespace | None = None):
logger.error(f"Error: {str(e)}")
sys.exit(1)
config = replace_env_vars(config_contents)
config = StackRunConfig(**config)
config = StackRunConfig(**cast_image_name_to_string(config))
# now that the logger is initialized, print the line about which type of config we are using.
logger.info(log_line)
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
logger.info(yaml.dump(safe_config, indent=2))
_log_run_config(run_config=config)
app = FastAPI(
lifespan=lifespan,
@ -450,12 +453,13 @@ def main(args: argparse.Namespace | None = None):
redoc_url="/redoc",
openapi_url="/openapi.json",
)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
else:
if config.server.quota:
@ -488,7 +492,13 @@ def main(args: argparse.Namespace | None = None):
)
try:
impls = asyncio.run(construct_stack(config))
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
@ -586,7 +596,16 @@ def main(args: argparse.Namespace | None = None):
if ssl_config:
uvicorn_config.update(ssl_config)
uvicorn.run(**uvicorn_config)
# Run uvicorn in the existing event loop to preserve background tasks
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
clean_config = remove_disabled_providers(safe_config)
logger.info(yaml.dump(clean_config, indent=2))
def extract_path_params(route: str) -> list[str]:
@ -597,5 +616,20 @@ def extract_path_params(route: str) -> list[str]:
return params
def remove_disabled_providers(obj):
if isinstance(obj, dict):
if (
obj.get("provider_id") == "__disabled__"
or obj.get("shield_id") == "__disabled__"
or obj.get("provider_model_id") == "__disabled__"
):
return None
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
elif isinstance(obj, list):
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
else:
return obj
if __name__ == "__main__":
main()
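For illustration (not part of the diff), remove_disabled_providers prunes any mapping whose provider_id, shield_id, or provider_model_id is the __disabled__ sentinel before the run config is logged:

raw = {
    "providers": {
        "inference": [
            {"provider_id": "ollama", "provider_type": "remote::ollama"},
            {"provider_id": "__disabled__", "provider_type": "remote::vllm"},
        ],
    },
}
assert remove_disabled_providers(raw) == {
    "providers": {
        "inference": [{"provider_id": "ollama", "provider_type": "remote::ollama"}],
    },
}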

@ -98,6 +98,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
@ -112,6 +113,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
):
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
continue
if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
continue
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
@ -166,7 +172,6 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
# Create a copy with resolved provider_id but original config
disabled_provider = v.copy()
disabled_provider["provider_id"] = resolved_provider_id
result.append(disabled_provider)
continue
except EnvVarError:
# If we can't resolve the provider_id, continue with normal processing
@ -261,6 +266,13 @@ def _convert_string_to_proper_type(value: str) -> Any:
return value
def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Ensure that any value for a key 'image_name' in a config_dict is a string"""
if "image_name" in config_dict and config_dict["image_name"] is not None:
config_dict["image_name"] = str(config_dict["image_name"])
return config_dict
def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:
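For illustration (not from the diff), the new helper simply coerces a non-null image_name to a string so that numeric YAML values do not fail StackRunConfig validation:

assert cast_image_name_to_string({"image_name": 123}) == {"image_name": "123"}
assert cast_image_name_to_string({"image_name": None}) == {"image_name": None}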

@ -6,12 +6,9 @@
from collections.abc import AsyncGenerator
from contextvars import ContextVar
from typing import TypeVar
T = TypeVar("T")
def preserve_contexts_async_generator(
def preserve_contexts_async_generator[T](
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]:
"""

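The change above swaps an explicit TypeVar for Python 3.12's inline type-parameter syntax (PEP 695); the two spellings are equivalent. A tiny standalone illustration of the new syntax (not from the diff):

def first_item[T](items: list[T]) -> T:
    # T is declared inline; no module-level TypeVar("T") is needed.
    return items[0]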

@ -8,6 +8,7 @@ import io
import json
import uuid
from dataclasses import dataclass
from typing import Any
from PIL import Image as PIL_Image
@ -184,16 +185,26 @@ class ChatFormat:
content = content[: -len("<|eom_id|>")]
stop_reason = StopReason.end_of_message
tool_name = None
tool_arguments = {}
tool_name: str | BuiltinTool | None = None
tool_arguments: dict[str, Any] = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info
# Type guard: ensure custom_tool_info is a tuple of correct types
if isinstance(custom_tool_info, tuple) and len(custom_tool_info) == 2:
extracted_tool_name, extracted_tool_arguments = custom_tool_info
# Handle both dict and str return types from the function
if isinstance(extracted_tool_arguments, dict):
tool_name, tool_arguments = extracted_tool_name, extracted_tool_arguments
else:
# If it's a string, treat it as a query parameter
tool_name, tool_arguments = extracted_tool_name, {"query": extracted_tool_arguments}
else:
tool_name, tool_arguments = None, {}
# Sometimes when agent has custom tools alongside builin tools
# Agent responds for builtin tool calls in the format of the custom tools
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
if tool_name is not None and tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
if isinstance(tool_arguments, dict):
tool_arguments = {

@ -178,6 +178,7 @@ def usecases() -> list[UseCase | str]:
),
RawMessage(role="user", content="What is the 100th decimal of pi?"),
RawMessage(
role="assistant",
content="",
stop_reason=StopReason.end_of_message,
tool_calls=[

@ -24,8 +24,8 @@ class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
input_shields: list[str] = None,
output_shields: list[str] = None,
input_shields: list[str] | None = None,
output_shields: list[str] | None = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
@ -37,6 +37,7 @@ class ShieldRunnerMixin:
return await self.safety_api.run_shield(
shield_id=identifier,
messages=messages,
params={},
)
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])

@ -51,6 +51,9 @@ class LocalfsFilesImpl(Files):
},
)
async def shutdown(self) -> None:
pass
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}"

@ -39,7 +39,7 @@ class MetaReferenceInferenceConfig(BaseModel):
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")

@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel):
def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0
return bool(get_model_parallel_rank() == 0)
def encode_msg(msg: ProcessingMessage) -> bytes:
@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str):
reply_socket.send_multipart([client_id, encode_msg(obj)])
while True:
tasks = [None]
tasks: list[ProcessingMessage | None] = [None]
if mp_rank_0():
client_id, maybe_task_json = maybe_get_work(reply_socket)
if maybe_task_json is not None:
@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str):
break
for obj in out:
updates = [None]
updates: list[ProcessingMessage | None] = [None]
if mp_rank_0():
_, update_json = maybe_get_work(reply_socket)
update = maybe_parse_message(update_json)

@ -123,7 +123,8 @@ class TorchtunePostTrainingImpl:
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob: ...
) -> PostTrainingJob:
raise NotImplementedError()
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(

@ -146,10 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
raise ValueError(
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
)
# Allow any model to be registered as a shield
# The model will be validated during runtime when making inference calls
pass
async def run_shield(
self,
@ -167,11 +166,25 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
# Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
model_id = shield.provider_resource_id
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
if model_id in LLAMA_GUARD_MODEL_IDS:
# Use the mapped model for categories but the original model_id for inference
mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
else:
# For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
impl = LlamaGuardShield(
model=model,
model=model_id,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
safety_categories=safety_categories,
)
return await impl.run(messages)
@ -183,20 +196,21 @@ class LlamaGuardShield:
model: str,
inference_api: Inference,
excluded_categories: list[str] | None = None,
safety_categories: list[str] | None = None,
):
if excluded_categories is None:
excluded_categories = []
if safety_categories is None:
safety_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
raise ValueError(f"Unsupported model: {model}")
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.safety_categories = safety_categories
def check_unsafe_response(self, response: str) -> str | None:
match = re.match(r"^unsafe\n(.*)$", response)
@ -214,7 +228,7 @@ class LlamaGuardShield:
final_categories = []
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
all_categories = self.safety_categories
for cat in all_categories:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:

@ -181,8 +181,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
)
self.cache[vector_db.identifier] = index
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# Cleanup if needed
@ -261,42 +261,10 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
return await index.query_chunks(query, params)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from kvstore."""
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
stores = {}
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to kvstore."""
"""Save vector store file data to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.set(key=key, value=json.dumps(file_info))
@ -324,7 +292,16 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
await self.kvstore.set(key=key, value=json.dumps(file_info))
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from kvstore."""
"""Delete vector store data from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
await self.kvstore.delete(key)
keys_to_delete = [
f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
]
for key in keys_to_delete:
try:
await self.kvstore.delete(key)
except Exception as e:
logger.warning(f"Failed to delete key {key}: {e}")
continue

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
@ -18,7 +18,8 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class MilvusVectorIOConfig(BaseModel):
db_path: str
kvstore: KVStoreConfig
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:

@ -6,14 +6,24 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
class SQLiteVectorIOConfig(BaseModel):
db_path: str
db_path: str = Field(description="Path to the SQLite database file")
kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="sqlite_vec_registry.db",
),
}
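For illustration (not from the diff), a config instance now requires a kvstore backend alongside db_path; a hedged sketch, assuming SqliteKVStoreConfig accepts a db_path argument:

config = SQLiteVectorIOConfig(
    db_path="/tmp/sqlite_vec.db",
    kvstore=SqliteKVStoreConfig(db_path="/tmp/sqlite_vec_registry.db"),  # argument name assumed
)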

@ -7,6 +7,7 @@
import asyncio
import json
import logging
import re
import sqlite3
import struct
from typing import Any
@ -24,6 +25,8 @@ from llama_stack.apis.vector_io import (
VectorIO,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
@ -40,6 +43,13 @@ KEYWORD_SEARCH = "keyword"
HYBRID_SEARCH = "hybrid"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:sqlite_vec:{VERSION}::"
def serialize_vector(vector: list[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation."""
@ -108,6 +118,10 @@ def _rrf_rerank(
return rrf_scores
def _make_sql_identifier(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
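# Illustration (not part of the changed file): the helper replaces every character
# outside [a-zA-Z0-9_] with "_" so arbitrary bank ids become safe SQL identifiers, e.g.
#   _make_sql_identifier("chunks_vs_1234-abcd!") == "chunks_vs_1234_abcd_"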
class SQLiteVecIndex(EmbeddingIndex):
"""
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@ -117,13 +131,14 @@ class SQLiteVecIndex(EmbeddingIndex):
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
"""
def __init__(self, dimension: int, db_path: str, bank_id: str):
def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None):
self.dimension = dimension
self.db_path = db_path
self.bank_id = bank_id
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
self.kvstore = kvstore
@classmethod
async def create(cls, dimension: int, db_path: str, bank_id: str):
@ -138,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex):
try:
# Create the table to store chunk metadata.
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
CREATE TABLE IF NOT EXISTS [{self.metadata_table}] (
id TEXT PRIMARY KEY,
chunk TEXT
);
""")
# Create the virtual table for embeddings.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""")
connection.commit()
@ -153,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# based on query. Implementation of the change on client side will allow passing the search_mode option
# during initialization to make it easier to create the table that is required.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}]
USING fts5(id, content);
""")
connection.commit()
@ -168,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex):
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];")
cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];")
cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];")
connection.commit()
finally:
cur.close()
@ -202,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex):
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
INSERT INTO [{self.metadata_table}] (id, chunk)
VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
""",
@ -220,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex):
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
]
cur.executemany(
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
embedding_data,
)
@ -228,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex):
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;",
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
[(row[0],) for row in fts_data],
)
# INSERT new entries
cur.executemany(
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
fts_data,
)
@ -270,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex):
emb_blob = serialize_vector(emb_list)
query_sql = f"""
SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v
JOIN {self.metadata_table} AS m ON m.id = v.id
FROM [{self.vector_table}] AS v
JOIN [{self.metadata_table}] AS m ON m.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance;
"""
@ -312,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex):
cur = connection.cursor()
try:
query_sql = f"""
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
FROM {self.fts_table} AS f
JOIN {self.metadata_table} AS m ON m.id = f.id
SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score
FROM [{self.fts_table}] AS f
JOIN [{self.metadata_table}] AS m ON m.id = f.id
WHERE f.content MATCH ?
ORDER BY score ASC
LIMIT ?;
@ -425,27 +440,81 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:
def _setup_connection():
# Open a connection to the SQLite database (the file is specified in the config).
self.kvstore = await kvstore_impl(self.config.kvstore)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for db_json in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(db_json)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def register_vector_db(self, vector_db: VectorDB) -> None:
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
vector_db = self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
index = VectorDBWithIndex(
vector_db=vector_db,
index=SQLiteVecIndex(
dimension=vector_db.embedding_dimension,
db_path=self.config.db_path,
bank_id=vector_db.identifier,
kvstore=self.kvstore,
),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
return index
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to SQLite database."""
def _create_or_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
# Create a table to persist vector DB registrations.
cur.execute("""
CREATE TABLE IF NOT EXISTS vector_dbs (
id TEXT PRIMARY KEY,
metadata TEXT
);
""")
# Create a table to persist OpenAI vector stores.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_stores (
id TEXT PRIMARY KEY,
metadata TEXT
);
""")
# Create a table to persist OpenAI vector store files.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
@ -464,168 +533,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
);
""")
connection.commit()
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
vector_db_rows = cur.fetchall()
return vector_db_rows
finally:
cur.close()
connection.close()
vector_db_rows = await asyncio.to_thread(_setup_connection)
# Load existing vector DBs
for row in vector_db_rows:
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
async def register_vector_db(self, vector_db: VectorDB) -> None:
def _register_db():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
(vector_db.identifier, vector_db.model_dump_json()),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_register_db)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
def _delete_vector_db_from_registry():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_vector_db_from_registry)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_store)
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("SELECT metadata FROM openai_vector_stores")
rows = cur.fetchall()
return rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_load)
stores = {}
for row in rows:
store_data = row[0]
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_info)),
@ -643,7 +550,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
connection.close()
try:
await asyncio.to_thread(_store)
await asyncio.to_thread(_create_or_store)
except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise
@ -722,6 +629,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
)
cur.execute(
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
connection.commit()
finally:
cur.close()
@ -730,15 +641,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
# and then call our index's add_chunks.
await self.cache[vector_db_id].insert_chunks(chunks)
await index.insert_chunks(chunks)
async def query_chunks(
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
if vector_db_id not in self.cache:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
return await index.query_chunks(query, params)
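A minimal sketch of the two SQLite details the hunks above rely on: table names derived from vector DB identifiers are wrapped in [brackets] because those identifiers may contain hyphens, and FTS5 tables get a delete-then-insert in place of ON CONFLICT. Table and chunk names below are placeholders, and the snippet assumes the bundled SQLite was built with FTS5 enabled.

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

# Vector DB identifiers like "test-vector-db-1234" contain hyphens, so the derived
# table names must be quoted; SQLite accepts [name] for identifier quoting.
fts_table = "fts_chunks_test-vector-db-1234"
cur.execute(f"CREATE VIRTUAL TABLE [{fts_table}] USING fts5(id, content);")

# FTS5 virtual tables do not support ON CONFLICT, so upserts are emulated by
# deleting any existing rows for the ids before re-inserting them.
rows = [("chunk-1", "hello vector world"), ("chunk-2", "keyword search with bm25")]
cur.executemany(f"DELETE FROM [{fts_table}] WHERE id = ?;", [(r[0],) for r in rows])
cur.executemany(f"INSERT INTO [{fts_table}] (id, content) VALUES (?, ?);", rows)

# Smaller bm25() values indicate better matches, hence ORDER BY score ASC above.
print(cur.execute(
    f"SELECT id, bm25([{fts_table}]) AS score FROM [{fts_table}] "
    f"WHERE [{fts_table}] MATCH 'bm25' ORDER BY score ASC;"
).fetchall())
conn.close()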

View file

@ -15,21 +15,26 @@ LLM_MODEL_IDS = [
"anthropic/claude-3-5-haiku-latest",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="anthropic/voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="anthropic/voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -9,6 +9,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta.llama3-1-8b-instruct-v1:0",
@ -22,4 +26,4 @@ MODEL_ENTRIES = [
"meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value,
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -9,6 +9,9 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://inference-docs.cerebras.ai/models
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1-8b",
@ -18,4 +21,8 @@ MODEL_ENTRIES = [
"llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value,
),
]
build_hf_repo_model_entry(
"llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -47,7 +47,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig
model_entries = [
SAFETY_MODELS_ENTRIES = []
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
@ -56,7 +59,7 @@ model_entries = [
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
]
] + SAFETY_MODELS_ENTRIES
class DatabricksInferenceAdapter(
@ -66,7 +69,7 @@ class DatabricksInferenceAdapter(
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=model_entries)
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
async def initialize(self) -> None:

View file

@ -11,6 +11,17 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-8b-instruct",
@ -40,14 +51,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama4-scout-instruct-basic",
CoreModelId.llama4_scout_17b_16e_instruct.value,
@ -64,4 +67,4 @@ MODEL_ENTRIES = [
"context_length": 8192,
},
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -17,11 +17,16 @@ LLM_MODEL_IDS = [
"gemini/gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="gemini/text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),
]
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="gemini/text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -38,24 +38,18 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="groq_api_key",
)
self.config = config
self._openai_client = None
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self.config.url}/openai/v1",
api_key=self.config.api_key,
)
return self._openai_client
return AsyncOpenAI(
base_url=f"{self.config.url}/openai/v1",
api_key=self.get_api_key(),
)
async def openai_chat_completion(
self,

View file

@ -10,6 +10,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry,
)
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"groq/llama3-8b-8192",
@ -51,4 +53,4 @@ MODEL_ENTRIES = [
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -3,16 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: LlamaCompatConfig
@ -27,8 +28,32 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
)
self.config = config
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from Llama API.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
llama_api_client = self._get_llama_api_client()
retrieved_model = await llama_api_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from Llama API")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from Llama API")
return False
except Exception as e:
logger.error(f"Failed to check model availability from Llama API: {e}")
return False
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
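A hedged usage sketch of the availability check added above: `adapter` is assumed to be a constructed and initialized LlamaCompatInferenceAdapter, and the model id is a placeholder rather than a guaranteed catalogue entry.

async def ensure_llama_api_model(adapter, model_id: str = "Llama-4-Maverick-17B-128E-Instruct-FP8") -> None:
    # check_model_availability() returns False for both unknown models and transport
    # errors, so a failure here only means the model cannot be used right now.
    if not await adapter.check_model_availability(model_id):
        raise RuntimeError(f"{model_id} is not served by the Llama API")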

View file

@ -11,6 +11,9 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama3-8b-instruct",
@ -99,4 +102,4 @@ MODEL_ENTRIES = [
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]
] + SAFETY_MODELS_ENTRIES

View file

@ -7,10 +7,9 @@
import logging
import warnings
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -41,11 +40,7 @@ from llama_stack.apis.inference import (
ToolChoice,
ToolConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -93,41 +88,37 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self._config = config
@lru_cache # noqa: B019
def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
async def check_model_availability(self, model: str) -> bool:
"""
For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
some models are hosted on different URLs. This function returns the appropriate client
for the given provider_model_id.
Check if a specific model is available.
This relies on lru_cache and self._default_client to avoid creating a new client for each request
or for each model that is hosted on https://integrate.api.nvidia.com/v1.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
await self._client.models.retrieve(model)
return True
except NotFoundError:
logger.error(f"Model {model} is not available")
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
@property
def _client(self) -> AsyncOpenAI:
"""
Returns an OpenAI client for the configured NVIDIA API endpoint.
:param provider_model_id: The provider model ID
:return: An OpenAI client
"""
@lru_cache # noqa: B019
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
"""
Maintain a single OpenAI client per base_url.
"""
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
special_model_urls = {
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
@ -169,7 +160,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._get_client(provider_model_id).completions.create(**request)
response = await self._client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -222,7 +213,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._get_client(provider_model_id).embeddings.create(
response = await self._client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
@ -283,7 +274,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._get_client(provider_model_id).chat.completions.create(**request)
response = await self._client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -339,7 +330,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
return await self._get_client(provider_model_id).completions.create(**params)
return await self._client.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -398,47 +389,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
return await self._get_client(provider_model_id).chat.completions.create(**params)
return await self._client.chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def register_model(self, model: Model) -> Model:
"""
Allow non-llama model registration.
Non-llama model registration: API Catalogue models, post-training models, etc.
client = LlamaStackAsLibraryClient("nvidia")
client.models.register(
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
model_type=ModelType.llm,
provider_id="nvidia",
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
)
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
"""
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
llama_model = model.metadata.get("llama_model")
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
# not llama model
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
else:
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
return model
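For reference, a tiny sketch of how the single base_url used by the new `_client` property is derived from the config; the field names come from the hunk above, and the localhost URL is a placeholder for a self-hosted NIM.

def nim_base_url(url: str, append_api_version: bool) -> str:
    # Mirrors: f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
    return f"{url}/v1" if append_api_version else url

assert nim_base_url("https://integrate.api.nvidia.com", True) == "https://integrate.api.nvidia.com/v1"
assert nim_base_url("http://localhost:8000", False) == "http://localhost:8000"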

View file

@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry,
)
SAFETY_MODELS_ENTRIES = [
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16",
@ -73,18 +86,8 @@ MODEL_ENTRIES = [
"llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value,
),
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
ProviderModelEntry(
provider_model_id="all-minilm:latest",
provider_model_id="all-minilm:l6-v2",
aliases=["all-minilm"],
model_type=ModelType.embedding,
metadata={
@ -100,4 +103,4 @@ MODEL_ENTRIES = [
"context_length": 8192,
},
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -48,16 +48,20 @@ EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -8,7 +8,7 @@ import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI
from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import (
OpenAIChatCompletion,
@ -59,9 +59,27 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
self._openai_client = AsyncOpenAI(
api_key=self.config.api_key,
)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from OpenAI.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
openai_client = self._get_openai_client()
retrieved_model = await openai_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from OpenAI")
return False
except Exception as e:
logger.error(f"Failed to check model availability from OpenAI: {e}")
return False
async def initialize(self) -> None:
await super().initialize()
@ -69,6 +87,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
async def shutdown(self) -> None:
await super().shutdown()
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
)
async def openai_completion(
self,
model: str,
@ -120,7 +143,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
user=user,
suffix=suffix,
)
return await self._openai_client.completions.create(**params)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
@ -176,7 +199,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
top_p=top_p,
user=user,
)
return await self._openai_client.chat.completions.create(**params)
return await self._get_openai_client().chat.completions.create(**params)
async def openai_embeddings(
self,
@ -204,7 +227,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
params["user"] = user
# Call OpenAI embeddings API
response = await self._openai_client.embeddings.create(**params)
response = await self._get_openai_client().embeddings.create(**params)
data = []
for i, embedding_data in enumerate(response.data):

View file

@ -11,7 +11,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
@ -25,6 +25,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import RunpodImplConfig
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
RUNPOD_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
@ -40,6 +42,14 @@ RUNPOD_SUPPORTED_MODELS = {
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}
SAFETY_MODELS_ENTRIES = []
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
class RunpodInferenceAdapter(
ModelRegistryHelper,
@ -61,25 +71,25 @@ class RunpodInferenceAdapter(
self,
model: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -129,10 +139,10 @@ class RunpodInferenceAdapter(
async def embeddings(
self,
model: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -9,6 +9,14 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-8B-Instruct",
@ -46,8 +54,4 @@ MODEL_ENTRIES = [
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -7,6 +7,7 @@
import json
from collections.abc import Iterable
import requests
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
)
@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@ -176,10 +178,11 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: SambaNovaImplConfig):
self.config = config
self.environment_available_models = []
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=self.config.api_key,
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key",
)
@ -246,6 +249,22 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
**get_sampling_options(request.sampling_params),
}
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
list_models_url = self.config.url + "/models"
if len(self.environment_available_models) == 0:
try:
response = requests.get(list_models_url)
response.raise_for_status()
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Request to {list_models_url} failed") from e
self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
if model_id.split("sambanova/")[-1] not in self.environment_available_models:
logger.warning(f"Model {model_id} not available in {list_models_url}")
return model
async def initialize(self):
await super().initialize()
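A small sketch of the membership check performed in register_model above: ids are compared after stripping the "sambanova/" namespace. The model id and the listing below are placeholders.

model_id = "sambanova/Meta-Llama-3.1-8B-Instruct"
environment_available_models = ["Meta-Llama-3.1-8B-Instruct", "Meta-Llama-Guard-3-8B"]
# str.split returns the original string when the separator is absent, so bare ids pass through unchanged.
assert model_id.split("sambanova/")[-1] in environment_available_models
assert "Meta-Llama-Guard-3-8B".split("sambanova/")[-1] == "Meta-Llama-Guard-3-8B"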

View file

@ -11,6 +11,16 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
@ -40,14 +50,6 @@ MODEL_ENTRIES = [
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding,
@ -78,4 +80,4 @@ MODEL_ENTRIES = [
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
],
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -68,19 +68,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
self._client = None
self._openai_client = None
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self._client:
# Together client has no close method, so just set to None
self._client = None
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
pass
async def completion(
self,
@ -108,29 +101,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
return await self._nonstream_completion(request)
def _get_client(self) -> AsyncTogether:
if not self._client:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
together_api_key = config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
self._client = AsyncTogether(api_key=together_api_key)
return self._client
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
together_api_key = config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
return AsyncTogether(api_key=together_api_key)
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
together_client = self._get_client().client
self._openai_client = AsyncOpenAI(
base_url=together_client.base_url,
api_key=together_client.api_key,
)
return self._openai_client
together_client = self._get_client().client
return AsyncOpenAI(
base_url=together_client.base_url,
api_key=together_client.api_key,
)
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)

View file

@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
def __init__(self, config: SambaNovaSafetyConfig) -> None:
self.config = config
self.environment_available_models = []
async def initialize(self) -> None:
pass
@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
async def register_shield(self, shield: Shield) -> None:
list_models_url = self.config.url + "/models"
try:
response = requests.get(list_models_url)
response.raise_for_status()
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Request to {list_models_url} failed") from e
available_models = [model.get("id") for model in response.json().get("data", {})]
if len(self.environment_available_models) == 0:
try:
response = requests.get(list_models_url)
response.raise_for_status()
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Request to {list_models_url} failed") from e
self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
if (
len(available_models) == 0
or "guard" not in shield.provider_resource_id.lower()
or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
"guard" not in shield.provider_resource_id.lower()
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
):
raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None

View file

@ -217,7 +217,6 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@ -14,6 +14,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
await impl.initialize()
return impl

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@ -16,6 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@ -23,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
),
}
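A hedged sketch of the dict the updated sample_run_config is expected to produce; the sqlite kvstore shape and the registry path are assumptions, not output copied from the code.

expected_milvus_run_config = {
    "uri": "${env.MILVUS_ENDPOINT}",
    "token": "${env.MILVUS_TOKEN}",
    # Assumed shape of SqliteKVStoreConfig.sample_run_config(...); the kvstore lets the
    # adapter persist vector DB and OpenAI vector store metadata across restarts.
    "kvstore": {
        "type": "sqlite",
        "db_path": "~/.llama/distributions/starter/milvus_remote_registry.db",
    },
}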

View file

@ -12,7 +12,7 @@ import re
from typing import Any
from numpy.typing import NDArray
from pymilvus import DataType, MilvusClient
from pymilvus import DataType, Function, FunctionType, MilvusClient
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
@ -61,6 +61,11 @@ class MilvusIndex(EmbeddingIndex):
self.consistency_level = consistency_level
self.kvstore = kvstore
async def initialize(self):
# MilvusIndex does not require explicit initialization
# TODO: could move collection creation into initialization but it is not really necessary
pass
async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
@ -69,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
# Add sparse vector field for BM25 (required by the function)
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
# Add BM25 function for full-text search
bm25_function = Function(
name="text_bm25_emb",
input_field_names=["content"],
output_field_names=["sparse"],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
@ -83,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
data.append(
{
"chunk_id": chunk.chunk_id,
"content": chunk.content,
"vector": embedding,
"chunk_content": chunk.model_dump(),
# sparse field will be handled by BM25 function automatically
}
)
try:
@ -102,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
self.client.search,
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
@ -116,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
"""
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
output_fields=["chunk_content"], # Output the chunk content
limit=k,
search_params={
"params": {
"drop_ratio_search": 0.2, # Ignore low-importance terms
}
},
)
chunks = []
scores = []
for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk)
scores.append(res["distance"]) # BM25 score from Milvus
# Filter by score threshold
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
filtered_scores = [score for score in scores if score >= score_threshold]
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
except Exception as e:
logger.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
output_fields=["*"],
limit=k,
)
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
scores = [1.0] * len(chunks) # Simple binary score for text search
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
@ -154,10 +273,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.mdel_validate_json(vector_db_data)
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
index=await MilvusIndex(
index=MilvusIndex(
client=self.client,
collection_name=vector_db.identifier,
consistency_level=self.config.consistency_level,
@ -174,7 +293,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
self.openai_vector_stores = await self._load_openai_vector_stores()
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
self.client.close()
@ -199,6 +319,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if vector_db_id in self.cache:
return self.cache[vector_db_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
@ -238,36 +361,16 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
if params and params.get("mode") == "keyword":
# Check if this is inline Milvus (Milvus-Lite)
if hasattr(self.config, "db_path"):
raise NotImplementedError(
"Keyword search is not supported in Milvus-Lite. "
"Please use a remote Milvus server for keyword search functionality."
)
return await index.query_chunks(query, params)
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
self.openai_vector_stores[store_id] = store_info
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
self.openai_vector_stores[store_id] = store_info
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from persistent storage."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from persistent storage."""
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored = await self.kvstore.values_in_range(start_key, end_key)
return {json.loads(s)["id"]: json.loads(s) for s in stored}
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
@ -377,6 +480,29 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
logger.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
return {}
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
return
file_data = [
{
"store_file_id": f"{store_id}_{file_id}",
"store_id": store_id,
"file_id": file_id,
"file_info": json.dumps(file_info),
}
]
await asyncio.to_thread(
self.client.upsert,
collection_name="openai_vector_store_files",
data=file_data,
)
except Exception as e:
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
raise
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from Milvus database."""
try:
@ -405,29 +531,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
logger.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
return []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
return
file_data = [
{
"store_file_id": f"{store_id}_{file_id}",
"store_id": store_id,
"file_id": file_id,
"file_info": json.dumps(file_info),
}
]
await asyncio.to_thread(
self.client.upsert,
collection_name="openai_vector_store_files",
data=file_data,
)
except Exception as e:
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
raise
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from Milvus database."""
try:

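A hedged usage sketch of the new keyword path end to end: `vector_io` is assumed to be a MilvusVectorIOAdapter backed by a remote Milvus server (Milvus-Lite still rejects keyword mode per the db_path check above), and the db id and query text are placeholders.

async def keyword_search(vector_io, vector_db_id: str = "docs-milvus"):
    # "mode": "keyword" routes to MilvusIndex.query_keyword, which runs a BM25 search
    # over the sparse field and falls back to a plain content query on error.
    return await vector_io.query_chunks(
        vector_db_id=vector_db_id,
        query="sparse bm25 full-text search",
        params={"mode": "keyword"},
    )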
View file

@ -8,6 +8,10 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from llama_stack.schema_utils import json_schema_type
@ -18,10 +22,12 @@ class PGVectorVectorIOConfig(BaseModel):
db: str | None = Field(default="postgres")
user: str | None = Field(default="postgres")
password: str | None = Field(default="mysecretpassword")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
@classmethod
def sample_run_config(
cls,
__distro_dir__: str,
host: str = "${env.PGVECTOR_HOST:=localhost}",
port: int = "${env.PGVECTOR_PORT:=5432}",
db: str = "${env.PGVECTOR_DB}",
@ -29,4 +35,14 @@ class PGVectorVectorIOConfig(BaseModel):
password: str = "${env.PGVECTOR_PASSWORD}",
**kwargs: Any,
) -> dict[str, Any]:
return {"host": host, "port": port, "db": db, "user": user, "password": password}
return {
"host": host,
"port": port,
"db": db,
"user": user,
"password": password,
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="pgvector_registry.db",
),
}

View file

@ -13,24 +13,18 @@ from psycopg2 import sql
from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListFilesResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@ -40,6 +34,13 @@ from .config import PGVectorVectorIOConfig
log = logging.getLogger(__name__)
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@ -69,7 +70,7 @@ def load_models(cur, cls):
class PGVectorIndex(EmbeddingIndex):
def __init__(self, vector_db: VectorDB, dimension: int, conn):
def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
self.conn = conn
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Sanitize the table name by replacing hyphens with underscores
@ -77,6 +78,7 @@ class PGVectorIndex(EmbeddingIndex):
# when created with patterns like "test-vector-db-{uuid4()}"
sanitized_identifier = vector_db.identifier.replace("-", "_")
self.table_name = f"vector_store_{sanitized_identifier}"
self.kvstore = kvstore
cur.execute(
f"""
@ -158,15 +160,28 @@ class PGVectorIndex(EmbeddingIndex):
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: PGVectorVectorIOConfig,
inference_api: Api.inference,
files_api: Files | None = None,
) -> None:
self.config = config
self.inference_api = inference_api
self.conn = None
self.cache = {}
self.files_api = files_api
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
self.kvstore = await kvstore_impl(self.config.kvstore)
await self.initialize_openai_vector_stores()
try:
self.conn = psycopg2.connect(
host=self.config.host,
@ -201,14 +216,28 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
log.info("Connection to PGVector database server closed")
async def register_vector_db(self, vector_db: VectorDB) -> None:
# Persist vector DB metadata in the KV store
assert self.kvstore is not None
# Upsert model metadata in Postgres
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Create and cache the PGVector index table for the vector DB
index = VectorDBWithIndex(
vector_db,
index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
# Remove provider index and cache
if vector_db_id in self.cache:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
# Delete vector DB metadata from KV store
assert self.kvstore is not None
await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
async def insert_chunks(
self,
@ -237,107 +266,124 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
# OpenAI Vector Stores File operations are not supported in PGVector
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"""
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
store_id TEXT,
file_id TEXT,
metadata JSONB,
PRIMARY KEY (store_id, file_id)
)
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
store_id TEXT,
file_id TEXT,
contents JSONB,
PRIMARY KEY (store_id, file_id)
)
"""
)
# Insert file metadata
files_query = sql.SQL(
"""
INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
"""
)
files_values = [(store_id, file_id, Json(file_info))]
execute_values(cur, files_query, files_values, template="(%s, %s, %s)")
# Insert file contents
contents_query = sql.SQL(
"""
INSERT INTO openai_vector_store_files_contents (store_id, file_id, contents)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET contents = EXCLUDED.contents
"""
)
contents_values = [(store_id, file_id, Json(file_contents))]
execute_values(cur, contents_query, contents_values, template="(%s, %s, %s)")
except Exception as e:
log.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
raise
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"SELECT metadata FROM openai_vector_store_files WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
row = cur.fetchone()
return row[0] if row and row[0] is not None else {}
except Exception as e:
log.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
return {}
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
row = cur.fetchone()
return row[0] if row and row[0] is not None else []
except Exception as e:
log.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
return []
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
query = sql.SQL(
"""
INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
"""
)
values = [(store_id, file_id, Json(file_info))]
execute_values(cur, query, values, template="(%s, %s, %s)")
except Exception as e:
log.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
raise
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> VectorStoreListFilesResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def openai_update_vector_store_file(
self,
vector_store_id: str,
file_id: str,
attributes: dict[str, Any] | None = None,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
cur.execute(
"DELETE FROM openai_vector_store_files_contents WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
except Exception as e:
log.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
raise

View file

@ -214,7 +214,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -6,15 +6,26 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
class WeaviateRequestProviderData(BaseModel):
weaviate_api_key: str
weaviate_cluster_url: str
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
class WeaviateVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {}
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="weaviate_registry.db",
),
}

View file

@ -14,10 +14,13 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@ -27,11 +30,19 @@ from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig
log = logging.getLogger(__name__)
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection_name: str):
def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None):
self.client = client
self.collection_name = collection_name
self.kvstore = kvstore
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
@ -109,11 +120,21 @@ class WeaviateVectorIOAdapter(
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
def __init__(
self,
config: WeaviateVectorIOConfig,
inference_api: Api.inference,
files_api: Files | None,
) -> None:
self.config = config
self.inference_api = inference_api
self.client_cache = {}
self.cache = {}
self.files_api = files_api
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
def _get_client(self) -> weaviate.Client:
provider_data = self.get_request_provider_data()
@ -132,7 +153,26 @@ class WeaviateVectorIOAdapter(
return client
async def initialize(self) -> None:
pass
"""Set up KV store and load existing vector DBs and OpenAI vector stores."""
# Initialize KV store for metadata
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing vector DB definitions
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored = await self.kvstore.values_in_range(start_key, end_key)
for raw in stored:
vector_db = VectorDB.model_validate_json(raw)
client = self._get_client()
idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db=vector_db,
index=idx,
inference_api=self.inference_api,
)
# Load OpenAI vector stores metadata into cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
for client in self.client_cache.values():
@ -206,3 +246,21 @@ class WeaviateVectorIOAdapter(
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
# OpenAI Vector Stores File operations are not supported in Weaviate
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")

View file

@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
async def shutdown(self):
pass
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
if model_id is None:
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
return model
def get_litellm_model_name(self, model_id: str) -> str:
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility.

View file

@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
]
if additional_aliases:
aliases.extend(additional_aliases)
aliases = [alias for alias in aliases if alias is not None]
return ProviderModelEntry(
provider_model_id=provider_model_id,
aliases=aliases,
@ -82,15 +83,43 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
def get_llama_model(self, provider_model_id: str) -> str | None:
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider (non-static check).
This is for subclassing purposes, so providers can check if a specific
model is currently available for use through dynamic means (e.g., API calls).
This method should NOT check statically configured model entries in
`self.alias_to_provider_id_map` - that is handled separately in register_model.
Default implementation returns False (no dynamic models available).
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
return False
async def register_model(self, model: Model) -> Model:
if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
# Check if model is supported in static configuration
supported_model_id = self.get_provider_model_id(model.provider_resource_id)
# If not found in static config, check if it's available dynamically from provider
if not supported_model_id:
if await self.check_model_availability(model.provider_resource_id):
supported_model_id = model.provider_resource_id
else:
# note: we cannot provide a complete list of supported models without
# getting a complete list from the provider, so we return "..."
all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
provider_resource_id = self.get_provider_model_id(model.model_id)
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and do not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
if provider_resource_id:
if provider_resource_id != supported_model_id: # be idemopotent, only reject differences
if provider_resource_id != supported_model_id: # be idempotent, only reject differences
raise ValueError(
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
)
@ -113,6 +142,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
# Register the model alias, ensuring it maps to the correct provider model id
self.alias_to_provider_id_map[model.model_id] = supported_model_id
return model
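The `check_model_availability` hook above is the extension point for providers that can discover models at runtime. The sketch below shows how a subclass might implement it; it assumes `ModelRegistryHelper` is constructed with the static model entries, and `client.list_models()` is a hypothetical placeholder for whatever discovery call the provider actually exposes.

```python
# Hypothetical sketch of the dynamic check described in the docstring above.
# `client.list_models()` is a placeholder, not a real llama-stack API.
class MyProviderAdapter(ModelRegistryHelper):
    def __init__(self, client, model_entries):
        super().__init__(model_entries)
        self._client = client

    async def check_model_availability(self, model: str) -> bool:
        # Any model the provider currently serves may be registered, even if
        # it does not appear in the statically configured entries.
        available = await self._client.list_models()
        return model in available
```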

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import json
import logging
import mimetypes
import time
@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
logger = logging.getLogger(__name__)
@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC):
# These should be provided by the implementing class
openai_vector_stores: dict[str, dict[str, Any]]
files_api: Files | None
# KV store for persisting OpenAI vector store metadata
kvstore: KVStore | None
@abstractmethod
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
pass
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# update in-memory cache
self.openai_vector_stores[store_id] = store_info
@abstractmethod
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from persistent storage."""
pass
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_data = await self.kvstore.values_in_range(start_key, end_key)
stores: dict[str, dict[str, Any]] = {}
for item in stored_data:
info = json.loads(item)
stores[info["id"]] = info
return stores
@abstractmethod
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
pass
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# update in-memory cache
self.openai_vector_stores[store_id] = store_info
@abstractmethod
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from persistent storage."""
pass
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
# remove from in-memory cache
self.openai_vector_stores.pop(store_id, None)
@abstractmethod
async def _save_openai_vector_store_file(
@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC):
"""Unregister a vector database (provider-specific implementation)."""
pass
async def initialize_openai_vector_stores(self) -> None:
"""Load existing OpenAI vector stores into the in-memory cache."""
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def insert_chunks(
self,
@ -147,8 +172,9 @@ class OpenAIVectorStoreMixin(ABC):
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
# Derive the canonical vector_db_id (allow override, else generate)
vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
if provider_id is None:
raise ValueError("Provider ID is required")
@ -156,19 +182,19 @@ class OpenAIVectorStoreMixin(ABC):
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
# Embedding dimension is required (defaulted to 384 if not provided)
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
# Register the VectorDB backing this vector store
vector_db = VectorDB(
identifier=store_id,
identifier=vector_db_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
provider_resource_id=vector_db_id,
vector_db_name=name,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
@ -182,11 +208,11 @@ class OpenAIVectorStoreMixin(ABC):
in_progress=0,
total=0,
)
store_info = {
"id": store_id,
store_info: dict[str, Any] = {
"id": vector_db_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"name": name,
"usage_bytes": 0,
"file_counts": file_counts.model_dump(),
"status": status,
@ -206,18 +232,18 @@ class OpenAIVectorStoreMixin(ABC):
store_info["metadata"] = metadata
# Save to persistent storage (provider-specific)
await self._save_openai_vector_store(store_id, store_info)
await self._save_openai_vector_store(vector_db_id, store_info)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
self.openai_vector_stores[vector_db_id] = store_info
# Now that our vector store is created, attach any files that were provided
file_ids = file_ids or []
tasks = [self.openai_attach_file_to_vector_store(store_id, file_id) for file_id in file_ids]
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
await asyncio.gather(*tasks)
# Get the updated store info and return it
store_info = self.openai_vector_stores[store_id]
store_info = self.openai_vector_stores[vector_db_id]
return VectorStoreObject.model_validate(store_info)
async def openai_list_vector_stores(
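With the mixin now carrying default KV-backed implementations of the metadata helpers, an adapter mostly needs to supply `kvstore`, `files_api`, and the `openai_vector_stores` cache, and call `initialize_openai_vector_stores()` during startup. The sketch below shows that wiring only; it assumes the imports used by the adapters in this change (`kvstore_impl`, `KVStore`) and omits the remaining abstract file-operation hooks, so it is illustrative rather than a working provider.

```python
# Minimal wiring sketch (not a real provider): the attributes an adapter
# supplies so the mixin's default KV-backed persistence above can work.
class ExampleVectorIOAdapter(OpenAIVectorStoreMixin):
    def __init__(self, config, inference_api, files_api=None):
        self.config = config
        self.inference_api = inference_api
        self.files_api = files_api
        self.kvstore: KVStore | None = None
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

    async def initialize(self) -> None:
        # Open the provider's KV store, then hydrate the in-memory cache from
        # whatever vector-store metadata was persisted previously.
        self.kvstore = await kvstore_impl(self.config.kvstore)
        await self.initialize_openai_vector_stores()
```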

View file

@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="authorized_sqlstore")
@ -38,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [
class SqlRecord(ProtectedResource):
"""Simple ProtectedResource implementation for SQL records."""
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
def __init__(self, record_id: str, table_name: str, owner: User):
self.type = f"sql_record::{table_name}"
self.identifier = record_id
if access_attributes:
self.owner = User(
principal="system",
attributes=access_attributes,
)
else:
self.owner = User(
principal="system_public",
attributes=None,
)
self.owner = owner
class AuthorizedSqlStore:
@ -71,9 +60,18 @@ class AuthorizedSqlStore:
:param sql_store: Base SqlStore implementation to wrap
"""
self.sql_store = sql_store
self._detect_database_type()
self._validate_sql_optimized_policy()
def _detect_database_type(self) -> None:
"""Detect the database type from the underlying SQL store."""
if not hasattr(self.sql_store, "config"):
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
self.database_type = self.sql_store.config.type
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _validate_sql_optimized_policy(self) -> None:
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
@ -91,22 +89,27 @@ class AuthorizedSqlStore:
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support."""
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON
if "owner_principal" not in enhanced_schema:
enhanced_schema["owner_principal"] = ColumnType.STRING
await self.sql_store.create_table(table, enhanced_schema)
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
"""Insert a row with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user()
if current_user and current_user.attributes:
if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data)
@ -136,9 +139,12 @@ class AuthorizedSqlStore:
for row in rows.data:
stored_access_attrs = row.get("access_attributes")
stored_owner_principal = row.get("owner_principal") or ""
record_id = row.get("id", "unknown")
sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
sql_record = SqlRecord(
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
@ -176,43 +182,90 @@ class AuthorizedSqlStore:
Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access.
"""
current_user = get_authenticated_user()
if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause()
return self._build_default_policy_where_clause(current_user)
else:
return self._build_conservative_where_clause()
def _build_default_policy_where_clause(self) -> str:
def _json_extract(self, column: str, path: str) -> str:
"""Extract JSON value (keeping JSON type).
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value
"""
if self.database_type == SqlStoreType.postgres:
return f"{column}->'{path}'"
elif self.database_type == SqlStoreType.sqlite:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _json_extract_text(self, column: str, path: str) -> str:
"""Extract JSON value as text.
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value as text
"""
if self.database_type == SqlStoreType.postgres:
return f"{column}->>'{path}'"
elif self.database_type == SqlStoreType.sqlite:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
conditions = ["owner_principal = ''"]
if self.database_type == SqlStoreType.postgres:
# Postgres stores JSON null as 'null'
conditions.append("access_attributes::text = 'null'")
elif self.database_type == SqlStoreType.sqlite:
conditions.append("access_attributes = 'null'")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
"""Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource.
"""
current_user = get_authenticated_user()
if not current_user or not current_user.attributes:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
else:
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
user_attr_conditions = []
base_conditions = self._get_public_access_conditions()
user_attr_conditions = []
if current_user and current_user.attributes:
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
# Check if JSON array contains the value
escaped_value = value.replace("'", "''")
json_text = self._json_extract_text("access_attributes", attr_key)
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
if value_conditions:
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
# Check if the category is missing (NULL)
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
user_matches_category = f"({' OR '.join(value_conditions)})"
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
if user_attr_conditions:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})"
else:
return f"({' OR '.join(base_conditions)})"
return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies.
@ -222,5 +275,8 @@ class AuthorizedSqlStore:
current_user = get_authenticated_user()
if not current_user:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
# Only allow public records
base_conditions = self._get_public_access_conditions()
return f"({' OR '.join(base_conditions)})"
return "1=1"

View file

@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore):
engine = create_async_engine(self.config.engine_str)
try:
inspector = inspect(engine)
table_names = inspector.get_table_names()
if table not in table_names:
return
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
if column_name in column_names:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
async with engine.begin() as conn:
def check_column_exists(sync_conn):
inspector = inspect(sync_conn)
table_names = inspector.get_table_names()
if table not in table_names:
return False, False # table doesn't exist, column doesn't exist
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
return True, column_name in column_names # table exists, column exists or not
table_exists, column_exists = await conn.run_sync(check_column_exists)
if not table_exists or column_exists:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
await conn.execute(add_column_sql)
except Exception:
except Exception as e:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass
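The migration above now performs the table and column inspection inside the async connection via `conn.run_sync`, since SQLAlchemy's `inspect()` needs a synchronous connection object. Below is a standalone sketch of that pattern, assuming the `sqlalchemy[asyncio]` and `aiosqlite` packages listed for this store are installed.

```python
# Standalone sketch of the run_sync inspection pattern adopted above: the
# schema check runs inside conn.run_sync() because inspect() requires a
# synchronous connection.
import asyncio

from sqlalchemy import inspect, text
from sqlalchemy.ext.asyncio import create_async_engine


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.execute(text("CREATE TABLE demo (id INTEGER)"))

        def existing_columns(sync_conn):
            return [col["name"] for col in inspect(sync_conn).get_columns("demo")]

        print(await conn.run_sync(existing_columns))  # ['id']


asyncio.run(main())
```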

View file

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import abstractmethod
from enum import Enum
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Literal
@ -19,7 +18,7 @@ from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
class SqlStoreType(Enum):
class SqlStoreType(StrEnum):
sqlite = "sqlite"
postgres = "postgres"
@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["sqlite"] = SqlStoreType.sqlite.value
type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
db_path: str = Field(
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["postgres"] = SqlStoreType.postgres.value
type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
impl = SqlAlchemySqlStoreImpl(config)
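The move from `Enum` to `StrEnum` is what lets the comparisons above drop the `.value` suffix: `StrEnum` members are real strings, so values deserialized from YAML or JSON compare equal to the members directly. A standalone illustration (Python 3.11+):

```python
# Standalone illustration of the Enum -> StrEnum change above.
from enum import Enum, StrEnum


class OldType(Enum):
    sqlite = "sqlite"


class NewType(StrEnum):
    sqlite = "sqlite"


print(OldType.sqlite == "sqlite")       # False: plain Enum members are not strings
print(NewType.sqlite == "sqlite")       # True: StrEnum members subclass str
print(f"db type is {NewType.sqlite}")   # "db type is sqlite"
```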

View file

@ -9,14 +9,12 @@ import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, TypeVar
from typing import Any
from pydantic import BaseModel
from llama_stack.models.llama.datatypes import Primitive
T = TypeVar("T")
def serialize_value(value: Any) -> Primitive:
return str(_prepare_for_json(value))
@ -44,7 +42,7 @@ def _prepare_for_json(value: Any) -> str:
return str(value)
def trace_protocol(cls: type[T]) -> type[T]:
def trace_protocol[T](cls: type[T]) -> type[T]:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.
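The decorator now uses the PEP 695 inline type-parameter syntax (Python 3.12+), which removes the module-level `TypeVar`. A tiny standalone illustration of the syntax:

```python
# Standalone illustration of the PEP 695 generic syntax used above: the type
# parameter is declared inline on the function instead of via a TypeVar.
def identity[T](value: T) -> T:
    return value


print(identity(42), identity("hello"))  # 42 hello
```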

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .nvidia import get_distribution_template # noqa: F401

View file

@ -0,0 +1,29 @@
version: 2
distribution_spec:
description: Use NVIDIA NIM for running LLM inference, evaluation and safety
providers:
inference:
- remote::nvidia
vector_io:
- inline::faiss
safety:
- remote::nvidia
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- remote::nvidia
post_training:
- remote::nvidia
datasetio:
- inline::localfs
- remote::nvidia
scoring:
- inline::basic
tool_runtime:
- inline::rag-runtime
image_type: conda
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@ -0,0 +1,149 @@
# NVIDIA Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Prerequisites
### NVIDIA API Keys
Make sure you have access to an NVIDIA API key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
### Deploy NeMo Microservices Platform
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
## Supported Services
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
### Inference: NVIDIA NIM
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
The deployed platform includes the NIM Proxy microservice, which is the service that provides access to your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
### Datasetio API: NeMo Data Store
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with the Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
### Eval API: NeMo Evaluator
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
### Post-Training API: NeMo Customizer
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
### Safety API: NeMo Guardrails
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
## Deploying models
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
Note: For improved inference speeds, we need to use NIM with the `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
```sh
# URL to NeMo NIM Proxy service
export NEMO_URL="http://nemo.test"
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"name": "llama-3.2-1b-instruct",
"namespace": "meta",
"config": {
"model": "meta/llama-3.2-1b-instruct",
"nim_deployment": {
"image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
"image_tag": "1.8.3",
"pvc_size": "25Gi",
"gpu": 1,
"additional_envs": {
"NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
}
}
}
}'
```
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
You can also remove a deployed NIM to free up GPU resources, if needed.
```sh
export NEMO_URL="http://nemo.test"
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
```
## Running Llama Stack with NVIDIA
You can do this via Conda or venv (building the distribution from source), or via Docker, which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
```
### Via Conda
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type venv
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
## Example Notebooks
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.

View file

@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::nvidia"],
"vector_io": ["inline::faiss"],
"safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["remote::nvidia"],
"post_training": ["remote::nvidia"],
"datasetio": ["inline::localfs", "remote::nvidia"],
"scoring": ["inline::basic"],
"tool_runtime": ["inline::rag-runtime"],
}
inference_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIAConfig.sample_run_config(),
)
safety_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig.sample_run_config(),
)
datasetio_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NvidiaDatasetIOConfig.sample_run_config(),
)
eval_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIAEvalConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="nvidia",
)
available_models = {
"nvidia": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
]
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name="nvidia",
distro_type="self_hosted",
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"datasetio": [datasetio_provider],
"eval": [eval_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
safety_provider,
],
"eval": [eval_provider],
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"NVIDIA_API_KEY": (
"",
"NVIDIA API Key",
),
"NVIDIA_APPEND_API_VERSION": (
"True",
"Whether to append the API version to the base_url",
),
## Nemo Customizer related variables
"NVIDIA_DATASET_NAMESPACE": (
"default",
"NVIDIA Dataset Namespace",
),
"NVIDIA_PROJECT_ID": (
"test-project",
"NVIDIA Project ID",
),
"NVIDIA_CUSTOMIZER_URL": (
"https://customizer.api.nvidia.com",
"NVIDIA Customizer URL",
),
"NVIDIA_OUTPUT_MODEL_DIR": (
"test-example-model@v1",
"NVIDIA Output Model Directory",
),
"GUARDRAILS_SERVICE_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"NVIDIA_GUARDRAILS_CONFIG_ID": (
"self-check",
"NVIDIA Guardrail Configuration ID",
),
"NVIDIA_EVALUATOR_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Evaluator Service",
),
"INFERENCE_MODEL": (
"Llama3.1-8B-Instruct",
"Inference model",
),
"SAFETY_MODEL": (
"meta/llama-3.1-8b-instruct",
"Name of the model to use for safety",
),
},
)

View file

@ -0,0 +1,119 @@
version: 2
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: nvidia
provider_type: remote::nvidia
config:
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
post_training:
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
datasetio:
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/localfs_datasetio.db
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
datasets_url: ${env.NVIDIA_DATASETS_URL:=http://nemo.test}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -0,0 +1,226 @@
version: 2
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}
config_id: ${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
- provider_id: nvidia
provider_type: remote::nvidia
config:
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
post_training:
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
datasetio:
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
datasets_url: ${env.NVIDIA_DATASETS_URL:=http://nemo.test}
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models:
- metadata: {}
model_id: meta/llama3-8b-instruct
provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-8B-Instruct
provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama3-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-8b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-405b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: nvidia
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-1b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-3b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-11b-vision-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-90b-vision-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.3-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata:
embedding_dimension: 2048
context_length: 8192
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
provider_id: nvidia
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: nvidia/nv-embedqa-e5-v5
provider_id: nvidia
provider_model_id: nvidia/nv-embedqa-e5-v5
model_type: embedding
- metadata:
embedding_dimension: 4096
context_length: 512
model_id: nvidia/nv-embedqa-mistral-7b-v2
provider_id: nvidia
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: snowflake/arctic-embed-l
provider_id: nvidia
provider_model_id: snowflake/arctic-embed-l
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -128,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="${env.ENABLE_PGVECTOR:+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
@ -146,7 +147,8 @@ def get_distribution_template() -> DistributionTemplate:
),
]
default_models = get_model_registry(available_models) + [
models, _ = get_model_registry(available_models)
default_models = models + [
ModelInput(
model_id="meta-llama/Llama-3.3-70B-Instruct",
provider_id="groq",

View file

@ -39,6 +39,9 @@ providers:
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec_registry.db
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
config:
@ -51,6 +54,9 @@ providers:
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -144,6 +144,9 @@ providers:
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
- provider_id: ${env.ENABLE_MILVUS:=__disabled__}
provider_type: inline::milvus
config:
@ -163,6 +166,9 @@ providers:
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -256,11 +262,51 @@ inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
model_type: embedding
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama3.1-8b
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama3.1-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama3.1-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-3.3-70b
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-3.3-70b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-3.3-70b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-4-scout-17b-16e-instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
provider_model_id: ${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
provider_model_id: ${env.SAFETY_MODEL:=__disabled__}
model_type: llm
- metadata:
embedding_dimension: ${env.OLLAMA_EMBEDDING_DIMENSION:=384}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}
@ -342,26 +388,6 @@ models:
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama4-scout-instruct-basic
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
@ -389,6 +415,26 @@ models:
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
@ -459,26 +505,6 @@ models:
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata:
embedding_dimension: 768
context_length: 8192
@ -523,6 +549,264 @@ models:
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-8b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-8b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-8b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-70b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-70b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-70b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-405b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-405b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-405b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-70b-instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-405b-instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-8b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-8B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-8b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-405b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-1b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-1B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-3b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-11b-vision-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-90b-vision-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.3-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata:
embedding_dimension: 2048
context_length: 8192
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/llama-3.2-nv-embedqa-1b-v2
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-e5-v5
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/nv-embedqa-e5-v5
model_type: embedding
- metadata:
embedding_dimension: 4096
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-mistral-7b-v2
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/snowflake/arctic-embed-l
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: snowflake/arctic-embed-l
model_type: embedding
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-70B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp8
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B:bf16-mp8
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp16
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B:bf16-mp16
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-8B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-70B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp8
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct:bf16-mp8
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp16
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct:bf16-mp16
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-1B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.2-1B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-3B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.2-3B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
@ -889,12 +1173,9 @@ models:
provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
provider_model_id: sambanova/Meta-Llama-Guard-3-8B
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
model_type: embedding
shields: []
shields:
- shield_id: ${env.SAFETY_MODEL:=__disabled__}
provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
vector_dbs: []
datasets: []
scoring_fns: []
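
The shields entry the starter config now emits can be exercised from the Python client once `ENABLE_OLLAMA` and `SAFETY_MODEL` are set at startup. A hedged sketch; the actual shield id depends on whatever `SAFETY_MODEL` resolved to:

```python
# Hedged sketch: run the shield registered by the starter config above.
# Assumes ENABLE_OLLAMA and SAFETY_MODEL were set when the stack started,
# e.g. SAFETY_MODEL=llama-guard3:1b, so the shield id resolves accordingly.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

result = client.safety.run_shield(
    shield_id="llama-guard3:1b",  # assumption: whatever SAFETY_MODEL resolved to
    messages=[{"role": "user", "content": "How do I fold a paper plane?"}],
    params={},
)
print(result.violation)  # None when the message passes the shield
```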

View file

@ -31,6 +31,15 @@ from llama_stack.providers.registry.inference import available_providers
from llama_stack.providers.remote.inference.anthropic.models import (
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.bedrock.models import (
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.cerebras.models import (
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.databricks.databricks import (
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.fireworks.models import (
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
)
@ -40,9 +49,15 @@ from llama_stack.providers.remote.inference.gemini.models import (
from llama_stack.providers.remote.inference.groq.models import (
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.nvidia.models import (
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.openai.models import (
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.runpod.runpod import (
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.sambanova.models import (
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
)
@ -59,6 +74,7 @@ from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
get_shield_registry,
)
@ -72,6 +88,11 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
"gemini": GEMINI_MODEL_ENTRIES,
"groq": GROQ_MODEL_ENTRIES,
"sambanova": SAMBANOVA_MODEL_ENTRIES,
"cerebras": CEREBRAS_MODEL_ENTRIES,
"bedrock": BEDROCK_MODEL_ENTRIES,
"databricks": DATABRICKS_MODEL_ENTRIES,
"nvidia": NVIDIA_MODEL_ENTRIES,
"runpod": RUNPOD_MODEL_ENTRIES,
}
# Special handling for providers with dynamic model entries
@ -81,6 +102,10 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
model_type=ModelType.embedding,
@ -100,6 +125,20 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
return model_entries_map.get(provider_type, [])
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
"""Get model entries for a specific provider type."""
safety_model_entries_map = {
"ollama": [
ProviderModelEntry(
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
],
}
return safety_model_entries_map.get(provider_type, [])
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
"""Get configuration for a provider using its adapter's config class."""
config_class = instantiate_class_type(provider_spec.config_class)
@ -155,6 +194,23 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
return inference_providers, available_models
# build a list of shields for all possible providers
def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
available_models = {}
for provider in providers:
provider_type = provider.provider_type.split("::")[1]
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
if len(safety_model_entries) == 0:
continue
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
provider_id = f"${{env.{env_var}:=__disabled__}}"
available_models[provider_id] = safety_model_entries
return available_models
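
The dictionary key built above is derived purely from the provider type string (after the `remote::` prefix has been split off). A small, self-contained illustration of that derivation, mirroring the string handling in the function rather than importing it:

```python
# Self-contained illustration of the provider-id key built in
# get_safety_models_for_providers above (string handling only).
def provider_env_id(provider_type: str) -> str:
    env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
    return f"${{env.{env_var}:=__disabled__}}"

assert provider_env_id("ollama") == "${env.ENABLE_OLLAMA:=__disabled__}"
assert provider_env_id("fireworks") == "${env.ENABLE_FIREWORKS:=__disabled__}"
```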
def get_distribution_template() -> DistributionTemplate:
remote_inference_providers, available_models = get_remote_inference_providers()
@ -185,6 +241,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
@ -244,7 +301,10 @@ def get_distribution_template() -> DistributionTemplate:
},
)
default_models = get_model_registry(available_models)
default_models, ids_conflict_in_models = get_model_registry(available_models)
available_safety_models = get_safety_models_for_providers(remote_inference_providers)
shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
return DistributionTemplate(
name=name,
@ -263,12 +323,10 @@ def get_distribution_template() -> DistributionTemplate:
"files": [files_provider],
"post_training": [post_training_provider],
},
default_models=default_models + [embedding_model],
default_models=[embedding_model] + default_models,
default_tool_groups=default_tool_groups,
# TODO: add a way to enable/disable shields on the fly
# default_shields=[
# ShieldInput(provider_id="llama-guard", shield_id="${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}")
# ],
default_shields=shields,
),
},
run_config_env_vars={

View file

@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
def get_model_registry(
available_models: dict[str, list[ProviderModelEntry]],
) -> list[ModelInput]:
) -> tuple[list[ModelInput], bool]:
models = []
# check for conflicts in model ids
@ -74,7 +74,50 @@ def get_model_registry(
metadata=entry.metadata,
)
)
return models
return models, ids_conflict
def get_shield_registry(
available_safety_models: dict[str, list[ProviderModelEntry]],
ids_conflict_in_models: bool,
) -> list[ShieldInput]:
shields = []
# check for conflicts in shield ids
all_ids = set()
ids_conflict = False
for _, entries in available_safety_models.items():
for entry in entries:
ids = [entry.provider_model_id] + entry.aliases
for model_id in ids:
if model_id in all_ids:
ids_conflict = True
rich.print(
f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
)
break
all_ids.update(ids)
if ids_conflict:
break
if ids_conflict:
break
for provider_id, entries in available_safety_models.items():
for entry in entries:
ids = [entry.provider_model_id] + entry.aliases
for model_id in ids:
identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
shields.append(
ShieldInput(
shield_id=identifier,
provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
if ids_conflict_in_models
else entry.provider_model_id,
)
)
return shields
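
The identifier rule in get_shield_registry is easiest to see on a concrete pair of inputs. A stand-alone sketch of just that conditional; the ids below are hypothetical and only the conditional mirrors the code:

```python
# Stand-alone sketch of the shield-id disambiguation rule used above.
def shield_identifier(provider_id: str, model_id: str, ids_conflict: bool) -> str:
    return f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id

# No conflict across providers: the bare model id is kept.
assert shield_identifier("${env.ENABLE_OLLAMA:=__disabled__}", "llama-guard3:1b", False) == "llama-guard3:1b"
# Two providers expose the same id: prefix it with the provider id.
assert (
    shield_identifier("${env.ENABLE_OLLAMA:=__disabled__}", "llama-guard3:1b", True)
    == "${env.ENABLE_OLLAMA:=__disabled__}/llama-guard3:1b"
)
```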
class DefaultModel(BaseModel):

View file

@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
},
)
default_models = get_model_registry(available_models)
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name="watsonx",
distro_type="remote_hosted",

View file

@ -0,0 +1,6 @@
import NextAuth from "next-auth";
import { authOptions } from "@/lib/auth";
const handler = NextAuth(authOptions);
export { handler as GET, handler as POST };

View file

@ -0,0 +1,118 @@
"use client";
import { signIn, signOut, useSession } from "next-auth/react";
import { Button } from "@/components/ui/button";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { Copy, Check, Home, Github } from "lucide-react";
import { useState } from "react";
import { useRouter } from "next/navigation";
export default function SignInPage() {
const { data: session, status } = useSession();
const [copied, setCopied] = useState(false);
const router = useRouter();
const handleCopyToken = async () => {
if (session?.accessToken) {
await navigator.clipboard.writeText(session.accessToken);
setCopied(true);
setTimeout(() => setCopied(false), 2000);
}
};
if (status === "loading") {
return (
<div className="flex items-center justify-center min-h-screen">
<div className="text-muted-foreground">Loading...</div>
</div>
);
}
return (
<div className="flex items-center justify-center min-h-screen">
<Card className="w-[400px]">
<CardHeader>
<CardTitle>Authentication</CardTitle>
<CardDescription>
{session
? "You are successfully authenticated!"
: "Sign in with GitHub to use your access token as an API key"}
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
{!session ? (
<Button
onClick={() => {
console.log("Signing in with GitHub...");
signIn("github", { callbackUrl: "/auth/signin" }).catch(
(error) => {
console.error("Sign in error:", error);
},
);
}}
className="w-full"
variant="default"
>
<Github className="mr-2 h-4 w-4" />
Sign in with GitHub
</Button>
) : (
<div className="space-y-4">
<div className="text-sm text-muted-foreground">
Signed in as {session.user?.email}
</div>
{session.accessToken && (
<div className="space-y-2">
<div className="text-sm font-medium">
GitHub Access Token:
</div>
<div className="flex gap-2">
<code className="flex-1 p-2 bg-muted rounded text-xs break-all">
{session.accessToken}
</code>
<Button
size="sm"
variant="outline"
onClick={handleCopyToken}
>
{copied ? (
<Check className="h-4 w-4" />
) : (
<Copy className="h-4 w-4" />
)}
</Button>
</div>
<div className="text-xs text-muted-foreground">
This GitHub token will be used as your API key for
authenticated Llama Stack requests.
</div>
</div>
)}
<div className="flex gap-2">
<Button onClick={() => router.push("/")} className="flex-1">
<Home className="mr-2 h-4 w-4" />
Go to Dashboard
</Button>
<Button
onClick={() => signOut()}
variant="outline"
className="flex-1"
>
Sign out
</Button>
</div>
</div>
)}
</CardContent>
</Card>
</div>
);
}
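
The token copied from this page is intended to be used as a bearer API key against the stack. A hedged Python sketch, assuming the Python client accepts an `api_key` the same way the TypeScript client in this PR passes `apiKey`, and that the server was started with GitHub token auth enabled:

```python
# Hedged sketch: use the GitHub access token copied from the sign-in page
# above as the API key for the stack. Assumes the server enforces GitHub
# token auth and that the Python client forwards api_key as a bearer token,
# mirroring the TypeScript client's `apiKey` option shown in this PR.
import os

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:8321",
    api_key=os.environ["GITHUB_ACCESS_TOKEN"],  # token copied from the page
)
print([m.identifier for m in client.models.list()])
```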

View file

@ -1,5 +1,6 @@
import type { Metadata } from "next";
import { ThemeProvider } from "@/components/ui/theme-provider";
import { SessionProvider } from "@/components/providers/session-provider";
import { Geist, Geist_Mono } from "next/font/google";
import { ModeToggle } from "@/components/ui/mode-toggle";
import "./globals.css";
@ -21,34 +22,38 @@ export const metadata: Metadata = {
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "@/components/layout/app-sidebar";
import { SignInButton } from "@/components/ui/sign-in-button";
export default function Layout({ children }: { children: React.ReactNode }) {
return (
<html lang="en" suppressHydrationWarning>
<body className={`${geistSans.variable} ${geistMono.variable} font-sans`}>
<ThemeProvider
attribute="class"
defaultTheme="system"
enableSystem
disableTransitionOnChange
>
<SidebarProvider>
<AppSidebar />
<main className="flex flex-col flex-1">
{/* Header with aligned elements */}
<div className="flex items-center p-4 border-b">
<div className="flex-none">
<SidebarTrigger />
<SessionProvider>
<ThemeProvider
attribute="class"
defaultTheme="system"
enableSystem
disableTransitionOnChange
>
<SidebarProvider>
<AppSidebar />
<main className="flex flex-col flex-1">
{/* Header with aligned elements */}
<div className="flex items-center p-4 border-b">
<div className="flex-none">
<SidebarTrigger />
</div>
<div className="flex-1 text-center"></div>
<div className="flex-none flex items-center gap-2">
<SignInButton />
<ModeToggle />
</div>
</div>
<div className="flex-1 text-center"></div>
<div className="flex-none">
<ModeToggle />
</div>
</div>
<div className="flex flex-col flex-1 p-4">{children}</div>
</main>
</SidebarProvider>
</ThemeProvider>
<div className="flex flex-col flex-1 p-4">{children}</div>
</main>
</SidebarProvider>
</ThemeProvider>
</SessionProvider>
</body>
</html>
);

View file

@ -4,11 +4,12 @@ import { useEffect, useState } from "react";
import { useParams } from "next/navigation";
import { ChatCompletion } from "@/lib/types";
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
import { client } from "@/lib/client";
import { useAuthClient } from "@/hooks/use-auth-client";
export default function ChatCompletionDetailPage() {
const params = useParams();
const id = params.id as string;
const client = useAuthClient();
const [completionDetail, setCompletionDetail] =
useState<ChatCompletion | null>(null);
@ -45,7 +46,7 @@ export default function ChatCompletionDetailPage() {
};
fetchCompletionDetail();
}, [id]);
}, [id, client]);
return (
<ChatCompletionDetailView

View file

@ -5,11 +5,12 @@ import { useParams } from "next/navigation";
import type { ResponseObject } from "llama-stack-client/resources/responses/responses";
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
import { ResponseDetailView } from "@/components/responses/responses-detail";
import { client } from "@/lib/client";
import { useAuthClient } from "@/hooks/use-auth-client";
export default function ResponseDetailPage() {
const params = useParams();
const id = params.id as string;
const client = useAuthClient();
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
null,
@ -109,7 +110,7 @@ export default function ResponseDetailPage() {
};
fetchResponseDetail();
}, [id]);
}, [id, client]);
return (
<ResponseDetailView

View file

@ -0,0 +1,82 @@
"use client";
import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { VectorStoreDetailView } from "@/components/vector-stores/vector-store-detail";
export default function VectorStoreDetailPage() {
const params = useParams();
const id = params.id as string;
const client = useAuthClient();
const router = useRouter();
const [store, setStore] = useState<VectorStore | null>(null);
const [files, setFiles] = useState<VectorStoreFile[]>([]);
const [isLoadingStore, setIsLoadingStore] = useState(true);
const [isLoadingFiles, setIsLoadingFiles] = useState(true);
const [errorStore, setErrorStore] = useState<Error | null>(null);
const [errorFiles, setErrorFiles] = useState<Error | null>(null);
useEffect(() => {
if (!id) {
setErrorStore(new Error("Vector Store ID is missing."));
setIsLoadingStore(false);
return;
}
const fetchStore = async () => {
setIsLoadingStore(true);
setErrorStore(null);
try {
const response = await client.vectorStores.retrieve(id);
setStore(response as VectorStore);
} catch (err) {
setErrorStore(
err instanceof Error
? err
: new Error("Failed to load vector store."),
);
} finally {
setIsLoadingStore(false);
}
};
fetchStore();
}, [id, client]);
useEffect(() => {
if (!id) {
setErrorFiles(new Error("Vector Store ID is missing."));
setIsLoadingFiles(false);
return;
}
const fetchFiles = async () => {
setIsLoadingFiles(true);
setErrorFiles(null);
try {
const result = await client.vectorStores.files.list(id as any);
setFiles((result as any).data);
} catch (err) {
setErrorFiles(
err instanceof Error ? err : new Error("Failed to load files."),
);
} finally {
setIsLoadingFiles(false);
}
};
fetchFiles();
}, [id, client]);
return (
<VectorStoreDetailView
store={store}
files={files}
isLoadingStore={isLoadingStore}
isLoadingFiles={isLoadingFiles}
errorStore={errorStore}
errorFiles={errorFiles}
id={id}
/>
);
}

View file

@ -0,0 +1,16 @@
"use client";
import React from "react";
import LogsLayout from "@/components/layout/logs-layout";
export default function VectorStoresLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<LogsLayout sectionLabel="Vector Stores" basePath="/logs/vector-stores">
{children}
</LogsLayout>
);
}

View file

@ -0,0 +1,121 @@
"use client";
import React from "react";
import { useAuthClient } from "@/hooks/use-auth-client";
import type {
ListVectorStoresResponse,
VectorStore,
} from "llama-stack-client/resources/vector-stores/vector-stores";
import { useRouter } from "next/navigation";
import { usePagination } from "@/hooks/use-pagination";
import {
Table,
TableBody,
TableCaption,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Skeleton } from "@/components/ui/skeleton";
export default function VectorStoresPage() {
const client = useAuthClient();
const router = useRouter();
const {
data: stores,
status,
hasMore,
error,
loadMore,
} = usePagination<VectorStore>({
limit: 20,
order: "desc",
fetchFunction: async (client, params) => {
const response = await client.vectorStores.list({
after: params.after,
limit: params.limit,
order: params.order,
} as any);
return response as ListVectorStoresResponse;
},
errorMessagePrefix: "vector stores",
});
// Auto-load all pages for infinite scroll behavior (like Responses)
React.useEffect(() => {
if (status === "idle" && hasMore) {
loadMore();
}
}, [status, hasMore, loadMore]);
if (status === "loading") {
return (
<div className="space-y-2">
<Skeleton className="h-8 w-full" />
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-full" />
</div>
);
}
if (status === "error") {
return <div className="text-destructive">Error: {error?.message}</div>;
}
if (!stores || stores.length === 0) {
return <p>No vector stores found.</p>;
}
return (
<div className="overflow-auto flex-1 min-h-0">
<Table>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead>Name</TableHead>
<TableHead>Created</TableHead>
<TableHead>Completed</TableHead>
<TableHead>Cancelled</TableHead>
<TableHead>Failed</TableHead>
<TableHead>In Progress</TableHead>
<TableHead>Total</TableHead>
<TableHead>Usage Bytes</TableHead>
<TableHead>Provider ID</TableHead>
<TableHead>Provider Vector DB ID</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{stores.map((store) => {
const fileCounts = store.file_counts;
const metadata = store.metadata || {};
const providerId = metadata.provider_id ?? "";
const providerDbId = metadata.provider_vector_db_id ?? "";
return (
<TableRow
key={store.id}
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
className="cursor-pointer hover:bg-muted/50"
>
<TableCell>{store.id}</TableCell>
<TableCell>{store.name}</TableCell>
<TableCell>
{new Date(store.created_at * 1000).toLocaleString()}
</TableCell>
<TableCell>{fileCounts.completed}</TableCell>
<TableCell>{fileCounts.cancelled}</TableCell>
<TableCell>{fileCounts.failed}</TableCell>
<TableCell>{fileCounts.in_progress}</TableCell>
<TableCell>{fileCounts.total}</TableCell>
<TableCell>{store.usage_bytes}</TableCell>
<TableCell>{providerId}</TableCell>
<TableCell>{providerDbId}</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
);
}
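
The table above is populated from the vector-stores list endpoint. The same data can be pulled server-side with the Python client, assuming it exposes a vector_stores.list(...) surface mirroring the TypeScript client used here (a sketch, not taken from this PR):

```python
# Hedged sketch of the same listing from the Python client, assuming it
# mirrors the vectorStores.list(...) surface used in the page above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

resp = client.vector_stores.list(limit=20, order="desc")
for store in resp.data:
    counts = store.file_counts
    print(store.id, store.name, counts.total, store.usage_bytes)
```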

View file

@ -12,24 +12,34 @@ jest.mock("next/navigation", () => ({
}),
}));
// Mock next-auth
jest.mock("next-auth/react", () => ({
useSession: () => ({
status: "authenticated",
data: { accessToken: "mock-token" },
}),
}));
// Mock helper functions
jest.mock("@/lib/truncate-text");
jest.mock("@/lib/format-message-content");
// Mock the client
jest.mock("@/lib/client", () => ({
client: {
chat: {
completions: {
list: jest.fn(),
},
// Mock the auth client hook
const mockClient = {
chat: {
completions: {
list: jest.fn(),
},
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: () => mockClient,
}));
// Mock the usePagination hook
const mockLoadMore = jest.fn();
jest.mock("@/hooks/usePagination", () => ({
jest.mock("@/hooks/use-pagination", () => ({
usePagination: jest.fn(() => ({
data: [],
status: "idle",
@ -47,7 +57,7 @@ import {
} from "@/lib/format-message-content";
// Import the mocked hook
import { usePagination } from "@/hooks/usePagination";
import { usePagination } from "@/hooks/use-pagination";
const mockedUsePagination = usePagination as jest.MockedFunction<
typeof usePagination
>;

View file

@ -10,8 +10,7 @@ import {
extractTextFromContentPart,
extractDisplayableText,
} from "@/lib/format-message-content";
import { usePagination } from "@/hooks/usePagination";
import { client } from "@/lib/client";
import { usePagination } from "@/hooks/use-pagination";
interface ChatCompletionsTableProps {
/** Optional pagination configuration */
@ -32,12 +31,15 @@ function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
export function ChatCompletionsTable({
paginationOptions,
}: ChatCompletionsTableProps) {
const fetchFunction = async (params: {
after?: string;
limit: number;
model?: string;
order?: string;
}) => {
const fetchFunction = async (
client: ReturnType<typeof import("@/hooks/use-auth-client").useAuthClient>,
params: {
after?: string;
limit: number;
model?: string;
order?: string;
},
) => {
const response = await client.chat.completions.list({
after: params.after,
limit: params.limit,

View file

@ -1,6 +1,11 @@
"use client";
import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react";
import {
MessageSquareText,
MessagesSquare,
MoveUpRight,
Database,
} from "lucide-react";
import Link from "next/link";
import { usePathname } from "next/navigation";
import { cn } from "@/lib/utils";
@ -28,6 +33,11 @@ const logItems = [
url: "/logs/responses",
icon: MessagesSquare,
},
{
title: "Vector Stores",
url: "/logs/vector-stores",
icon: Database,
},
{
title: "Documentation",
url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
@ -57,13 +67,13 @@ export function AppSidebar() {
className={cn(
"justify-start",
isActive &&
"bg-gray-200 hover:bg-gray-200 text-primary hover:text-primary",
"bg-gray-200 dark:bg-gray-700 hover:bg-gray-200 dark:hover:bg-gray-700 text-gray-900 dark:text-gray-100",
)}
>
<Link href={item.url}>
<item.icon
className={cn(
isActive && "text-primary",
isActive && "text-gray-900 dark:text-gray-100",
"mr-2 h-4 w-4",
)}
/>

View file

@ -93,7 +93,9 @@ export function PropertyItem({
>
<strong>{label}:</strong>{" "}
{typeof value === "string" || typeof value === "number" ? (
<span className="text-gray-900 font-medium">{value}</span>
<span className="text-gray-900 dark:text-gray-100 font-medium">
{value}
</span>
) : (
value
)}
@ -112,7 +114,9 @@ export function PropertiesCard({ children }: PropertiesCardProps) {
<CardTitle>Properties</CardTitle>
</CardHeader>
<CardContent>
<ul className="space-y-2 text-sm text-gray-600">{children}</ul>
<ul className="space-y-2 text-sm text-gray-600 dark:text-gray-400">
{children}
</ul>
</CardContent>
</Card>
);

View file

@ -12,7 +12,7 @@ jest.mock("next/navigation", () => ({
}));
// Mock the useInfiniteScroll hook
jest.mock("@/hooks/useInfiniteScroll", () => ({
jest.mock("@/hooks/use-infinite-scroll", () => ({
useInfiniteScroll: jest.fn((onLoadMore, options) => {
const ref = React.useRef(null);

View file

@ -4,7 +4,7 @@ import { useRouter } from "next/navigation";
import { useRef } from "react";
import { truncateText } from "@/lib/truncate-text";
import { PaginationStatus } from "@/lib/types";
import { useInfiniteScroll } from "@/hooks/useInfiniteScroll";
import { useInfiniteScroll } from "@/hooks/use-infinite-scroll";
import {
Table,
TableBody,

View file

@ -0,0 +1,7 @@
"use client";
import { SessionProvider as NextAuthSessionProvider } from "next-auth/react";
export function SessionProvider({ children }: { children: React.ReactNode }) {
return <NextAuthSessionProvider>{children}</NextAuthSessionProvider>;
}

View file

@ -12,21 +12,31 @@ jest.mock("next/navigation", () => ({
}),
}));
// Mock next-auth
jest.mock("next-auth/react", () => ({
useSession: () => ({
status: "authenticated",
data: { accessToken: "mock-token" },
}),
}));
// Mock helper functions
jest.mock("@/lib/truncate-text");
// Mock the client
jest.mock("@/lib/client", () => ({
client: {
responses: {
list: jest.fn(),
},
// Mock the auth client hook
const mockClient = {
responses: {
list: jest.fn(),
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: () => mockClient,
}));
// Mock the usePagination hook
const mockLoadMore = jest.fn();
jest.mock("@/hooks/usePagination", () => ({
jest.mock("@/hooks/use-pagination", () => ({
usePagination: jest.fn(() => ({
data: [],
status: "idle",
@ -40,7 +50,7 @@ jest.mock("@/hooks/usePagination", () => ({
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
// Import the mocked hook
import { usePagination } from "@/hooks/usePagination";
import { usePagination } from "@/hooks/use-pagination";
const mockedUsePagination = usePagination as jest.MockedFunction<
typeof usePagination
>;

View file

@ -6,8 +6,7 @@ import {
UsePaginationOptions,
} from "@/lib/types";
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
import { usePagination } from "@/hooks/usePagination";
import { client } from "@/lib/client";
import { usePagination } from "@/hooks/use-pagination";
import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
import {
isMessageInput,
@ -125,12 +124,15 @@ function formatResponseToRow(response: OpenAIResponse): LogTableRow {
}
export function ResponsesTable({ paginationOptions }: ResponsesTableProps) {
const fetchFunction = async (params: {
after?: string;
limit: number;
model?: string;
order?: string;
}) => {
const fetchFunction = async (
client: ReturnType<typeof import("@/hooks/use-auth-client").useAuthClient>,
params: {
after?: string;
limit: number;
model?: string;
order?: string;
},
) => {
const response = await client.responses.list({
after: params.after,
limit: params.limit,

View file

@ -17,10 +17,10 @@ export const MessageBlock: React.FC<MessageBlockProps> = ({
}) => {
return (
<div className={`mb-4 ${className}`}>
<p className="py-1 font-semibold text-gray-800 mb-1">
<p className="py-1 font-semibold text-muted-foreground mb-1">
{label}
{labelDetail && (
<span className="text-xs text-gray-500 font-normal ml-1">
<span className="text-xs text-muted-foreground font-normal ml-1">
{labelDetail}
</span>
)}

View file

@ -0,0 +1,25 @@
"use client";
import { User } from "lucide-react";
import Link from "next/link";
import { useSession } from "next-auth/react";
import { Button } from "./button";
export function SignInButton() {
const { data: session, status } = useSession();
return (
<Button variant="ghost" size="sm" asChild>
<Link href="/auth/signin" className="flex items-center">
<User className="mr-2 h-4 w-4" />
<span>
{status === "loading"
? "Loading..."
: session
? session.user?.email || "Signed In"
: "Sign In"}
</span>
</Link>
</Button>
);
}

View file

@ -0,0 +1,128 @@
"use client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import {
DetailLoadingView,
DetailErrorView,
DetailNotFoundView,
DetailLayout,
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
import {
Table,
TableBody,
TableCaption,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
interface VectorStoreDetailViewProps {
store: VectorStore | null;
files: VectorStoreFile[];
isLoadingStore: boolean;
isLoadingFiles: boolean;
errorStore: Error | null;
errorFiles: Error | null;
id: string;
}
export function VectorStoreDetailView({
store,
files,
isLoadingStore,
isLoadingFiles,
errorStore,
errorFiles,
id,
}: VectorStoreDetailViewProps) {
const title = "Vector Store Details";
if (errorStore) {
return <DetailErrorView title={title} id={id} error={errorStore} />;
}
if (isLoadingStore) {
return <DetailLoadingView title={title} />;
}
if (!store) {
return <DetailNotFoundView title={title} id={id} />;
}
const mainContent = (
<>
<Card>
<CardHeader>
<CardTitle>Files</CardTitle>
</CardHeader>
<CardContent>
{isLoadingFiles ? (
<Skeleton className="h-4 w-full" />
) : errorFiles ? (
<div className="text-destructive text-sm">
Error loading files: {errorFiles.message}
</div>
) : files.length > 0 ? (
<Table>
<TableCaption>Files in this vector store</TableCaption>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead>Status</TableHead>
<TableHead>Created</TableHead>
<TableHead>Usage Bytes</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{files.map((file) => (
<TableRow key={file.id}>
<TableCell>{file.id}</TableCell>
<TableCell>{file.status}</TableCell>
<TableCell>
{new Date(file.created_at * 1000).toLocaleString()}
</TableCell>
<TableCell>{file.usage_bytes}</TableCell>
</TableRow>
))}
</TableBody>
</Table>
) : (
<p className="text-gray-500 italic text-sm">
No files in this vector store.
</p>
)}
</CardContent>
</Card>
</>
);
const sidebar = (
<PropertiesCard>
<PropertyItem label="ID" value={store.id} />
<PropertyItem label="Name" value={store.name || ""} />
<PropertyItem
label="Created"
value={new Date(store.created_at * 1000).toLocaleString()}
/>
<PropertyItem label="Status" value={store.status} />
<PropertyItem label="Total Files" value={store.file_counts.total} />
<PropertyItem label="Usage Bytes" value={store.usage_bytes} />
<PropertyItem
label="Provider ID"
value={(store.metadata.provider_id as string) || ""}
/>
<PropertyItem
label="Provider DB ID"
value={(store.metadata.provider_vector_db_id as string) || ""}
/>
</PropertiesCard>
);
return (
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
);
}

View file

@ -0,0 +1,24 @@
import { useSession } from "next-auth/react";
import { useMemo } from "react";
import LlamaStackClient from "llama-stack-client";
export function useAuthClient() {
const { data: session } = useSession();
const client = useMemo(() => {
const clientHostname =
typeof window !== "undefined" ? window.location.origin : "";
const options: any = {
baseURL: `${clientHostname}/api`,
};
if (session?.accessToken) {
options.apiKey = session.accessToken;
}
return new LlamaStackClient(options);
}, [session?.accessToken]);
return client;
}

Some files were not shown because too many files have changed in this diff.