Merge remote-tracking branch 'upstream/main' into add-file-processor-skeleton

Alina Ryan 2025-11-25 14:37:56 -05:00
commit 479e627b0d
645 changed files with 96136 additions and 35829 deletions

View file

@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batches import Batches, BatchObject, ListBatchesResponse
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .benchmarks import *

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
Metadata,
)
__all__ = [
"Conversation",
"ConversationDeletedResource",
"ConversationItem",
"ConversationItemCreateRequest",
"ConversationItemDeletedResource",
"ConversationItemList",
"Conversations",
"Metadata",
]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasets import *

View file

@ -1,159 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, EnumMeta
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class DynamicApiMeta(EnumMeta):
def __new__(cls, name, bases, namespace):
# Store the original enum values
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
# Create the enum class
cls = super().__new__(cls, name, bases, namespace)
# Store the original values for reference
cls._original_values = original_values
# Initialize _dynamic_values
cls._dynamic_values = {}
return cls
def __call__(cls, value):
try:
return super().__call__(value)
except ValueError as e:
# If this value was already dynamically added, return it
if value in cls._dynamic_values:
return cls._dynamic_values[value]
# If the value doesn't exist, create a new enum member
# Create a new member name from the value
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return the existing member
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
# register new APIs explicitly
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
def __iter__(cls):
# Allow iteration over both static and dynamic members
yield from super().__iter__()
if hasattr(cls, "_dynamic_values"):
yield from cls._dynamic_values.values()
def add(cls, value):
"""
Add a new API to the enum.
Used to register external APIs.
"""
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return it
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Create a new enum member
member = object.__new__(cls)
member._name_ = member_name
member._value_ = value
# Add it to the enum class
cls._member_map_[member_name] = member
cls._member_names_.append(member_name)
cls._member_type_ = str
# Store it in our dynamic values
cls._dynamic_values[value] = member
return member
@json_schema_type
class Api(Enum, metaclass=DynamicApiMeta):
"""Enumeration of all available APIs in the Llama Stack system.
:cvar providers: Provider management and configuration
:cvar inference: Text generation, chat completions, and embeddings
:cvar safety: Content moderation and safety shields
:cvar agents: Agent orchestration and execution
:cvar batches: Batch processing for asynchronous API requests
:cvar vector_io: Vector database operations and queries
:cvar datasetio: Dataset input/output operations
:cvar scoring: Model output evaluation and scoring
:cvar eval: Model evaluation and benchmarking framework
:cvar post_training: Fine-tuning and model training
:cvar tool_runtime: Tool execution and management
:cvar telemetry: Observability and system monitoring
:cvar models: Model metadata and management
:cvar shields: Safety shield implementations
:cvar datasets: Dataset creation and management
:cvar scoring_functions: Scoring function definitions
:cvar benchmarks: Benchmark suite management
:cvar tool_groups: Tool group organization
:cvar files: File storage and management
:cvar prompts: Prompt versions and management
:cvar inspect: Built-in system inspection and introspection
"""
providers = "providers"
inference = "inference"
safety = "safety"
agents = "agents"
batches = "batches"
vector_io = "vector_io"
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
post_training = "post_training"
tool_runtime = "tool_runtime"
models = "models"
shields = "shields"
vector_stores = "vector_stores" # only used for routing table
datasets = "datasets"
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"
tool_groups = "tool_groups"
files = "files"
prompts = "prompts"
conversations = "conversations"
file_processor = "file_processor"
# built-in API
inspect = "inspect"
@json_schema_type
class Error(BaseModel):
"""
Error response from the API. Roughly follows RFC 7807.
:param status: HTTP status code
:param title: Error title, a short summary of the error which is invariant for an error type
:param detail: Error detail, a longer human-readable description of the error
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
"""
status: int
title: str
detail: str
instance: str | None = None
class ExternalApiSpec(BaseModel):
"""Specification for an external API implementation."""
module: str = Field(..., description="Python module containing the API implementation")
name: str = Field(..., description="Name of the API")
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
protocol: str = Field(..., description="Name of the protocol class for the API")
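
For orientation, a hedged usage sketch of the dynamic Api enum defined above (not part of the diff). It assumes the post-merge import path from llama_stack_api, which the import rewrites elsewhere in this commit suggest; the "weather" API name is purely illustrative.

from llama_stack_api import Api  # assumed post-merge import path

# Register an external API at runtime via the DynamicApiMeta.add hook.
weather_api = Api.add("weather")

# Lookups of the registered value resolve to the same dynamic member.
assert Api("weather") is weather_api
assert weather_api.value == "weather"

# Unregistered values still fail loudly and point callers at Api.add().
try:
    Api("not-registered")
except ValueError as err:
    print(err)  # API 'not-registered' does not exist. Use Api.add() to register new APIs.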

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .file_processor import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .files import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inference import *

View file

@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from termcolor import cprint
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
)
class LogEvent:
def __init__(
self,
content: str = "",
end: str = "\n",
color="white",
):
self.content = content
self.color = color
self.end = "\n" if end is None else end
def print(self, flush=True):
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
class EventLogger:
async def log(self, event_generator):
async for chunk in event_generator:
if isinstance(chunk, ChatCompletionResponseStreamChunk):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == ChatCompletionResponseEventType.progress:
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == ChatCompletionResponseEventType.complete:
yield LogEvent("")
else:
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk.completion_message.content, color="yellow")
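
A hedged sketch of how this EventLogger is typically driven (illustrative, not part of the diff): it consumes an async generator of inference chunks and yields LogEvent objects to print. The chunk generator here is a stand-in; in practice it would come from a streaming chat completion call.

import asyncio

async def print_stream(chunk_generator):
    # chunk_generator is assumed to yield ChatCompletionResponseStreamChunk
    # objects (or a final non-streaming response), as the logger expects.
    async for log_event in EventLogger().log(chunk_generator):
        log_event.print()

# asyncio.run(print_stream(stream_from_some_inference_call()))  # hypothetical generator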

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inspect import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .models import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .post_training import *

View file

@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .prompts import ListPromptsResponse, Prompt, Prompts
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .safety import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .shields import *

View file

@ -1,8 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .rag_tool import *
from .tools import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_io import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_stores import *

View file

@ -21,7 +21,7 @@ from llama_stack.core.datatypes import (
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.stack import replace_env_vars
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"

View file

@ -32,7 +32,7 @@ from llama_stack.core.storage.datatypes import (
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions"

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import sys
from pydantic import BaseModel
@ -12,12 +11,9 @@ from termcolor import cprint
from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api
log = get_logger(name=__name__, category="core")
@ -101,64 +97,3 @@ def print_pip_install_help(config: BuildConfig):
for special_dep in special_deps:
cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
print()
def build_image(
build_config: BuildConfig,
image_name: str,
distro_or_config: str,
run_config: str | None = None,
):
container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
args = [
script,
"--distro-or-config",
distro_or_config,
"--image-name",
image_name,
"--container-base",
container_base,
"--normal-deps",
" ".join(normal_deps),
]
# When building from a config file (not a template), include the run config path in the
# build arguments
if run_config is not None:
args.extend(["--run-config", run_config])
else:
script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
args = [
script,
"--env-name",
str(image_name),
"--normal-deps",
" ".join(normal_deps),
]
# Always pass both arguments, even if empty, to maintain consistent positional arguments
if special_deps:
args.extend(["--optional-deps", "#".join(special_deps)])
if external_provider_deps:
args.extend(
["--external-provider-deps", "#".join(external_provider_deps)]
) # the script will install external provider module, get its deps, and install those too.
return_code = run_command(args)
if return_code != 0:
log.error(
f"Failed to build target {image_name} with return code {return_code}",
)
return return_code

View file

@ -15,7 +15,7 @@ import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.providers.datatypes import RemoteProviderConfig
from llama_stack_api import RemoteProviderConfig
_CLIENT_CLASSES = {}

View file

@ -20,7 +20,7 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack_api import Api, ProviderSpec
logger = get_logger(name=__name__, category="core")

View file

@ -10,7 +10,11 @@ from typing import Any, Literal
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.log import get_logger
from llama_stack_api import (
Conversation,
ConversationDeletedResource,
ConversationItem,
@ -20,11 +24,7 @@ from llama_stack.apis.conversations.conversations import (
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
logger = get_logger(name=__name__, category="openai_conversations")
@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations):
"item_data": item_dict,
}
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
try:
await self.sql_store.insert(table="conversation_items", data=item_record)
except Exception:
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_items",
data={"created_at": created_at, "item_data": item_dict},
where={"id": item_id},
)
await self.sql_store.upsert(
table="conversation_items",
data=item_record,
conflict_columns=["id"],
)
created_items.append(item_dict)
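
The change above replaces the insert-then-update fallback with a single upsert call. As a hedged clarification of the intended semantics (the sqlstore backend itself is not shown in this diff), here is a self-contained sqlite3 sketch of the equivalent statement; table and column names mirror the record built above.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE conversation_items (id TEXT PRIMARY KEY, conversation_id TEXT, "
    "created_at INTEGER, item_data TEXT)"
)

def upsert_item(item_id, conversation_id, created_at, item_data):
    # One atomic statement: insert, or update the existing row on an id conflict.
    conn.execute(
        "INSERT INTO conversation_items (id, conversation_id, created_at, item_data) "
        "VALUES (?, ?, ?, ?) "
        "ON CONFLICT(id) DO UPDATE SET created_at = excluded.created_at, "
        "item_data = excluded.item_data",
        (item_id, conversation_id, created_at, item_data),
    )

upsert_item("item_1", "conv_1", 1, '{"type": "message"}')
upsert_item("item_1", "conv_1", 2, '{"type": "message", "edited": true}')  # updates in place, no error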

View file

@ -11,20 +11,6 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.storage.datatypes import (
KVStoreReference,
@ -32,7 +18,32 @@ from llama_stack.core.storage.datatypes import (
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack_api import (
Api,
Benchmark,
BenchmarkInput,
Dataset,
DatasetInput,
DatasetIO,
Eval,
Inference,
Model,
ModelInput,
ProviderSpec,
Resource,
Safety,
Scoring,
ScoringFn,
ScoringFnInput,
Shield,
ShieldInput,
ToolGroup,
ToolGroupInput,
ToolRuntime,
VectorIO,
VectorStore,
VectorStoreInput,
)
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2

View file

@ -15,7 +15,7 @@ from pydantic import BaseModel
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
from llama_stack_api import (
Api,
InlineProviderSpec,
ProviderSpec,

View file

@ -7,9 +7,9 @@
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
from llama_stack_api import Api, ExternalApiSpec
logger = get_logger(name=__name__, category="core")

View file

@ -8,18 +8,17 @@ from importlib.metadata import version
from pydantic import BaseModel
from llama_stack.apis.inspect import (
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack_api import (
HealthInfo,
HealthStatus,
Inspect,
ListRoutesResponse,
RouteInfo,
VersionInfo,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
class DistributionInspectConfig(BaseModel):
@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect):
# Helper function to determine if a route should be included based on api_filter
def should_include_route(webmethod) -> bool:
if api_filter is None:
# Default: only non-deprecated v1 APIs
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
# Default: only non-deprecated APIs
return not webmethod.deprecated
elif api_filter == "deprecated":
# Special filter: show deprecated routes regardless of their actual level
return bool(webmethod.deprecated)

View file

@ -19,6 +19,8 @@ import httpx
import yaml
from fastapi import Response as FastAPIResponse
from llama_stack.core.utils.type_inspection import is_unwrapped_body_param
try:
from llama_stack_client import (
NOT_GIVEN,
@ -40,24 +42,16 @@ from termcolor import cprint
from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context
from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import (
Stack,
get_stack_run_config_from_distro,
replace_env_vars,
)
from llama_stack.core.stack import Stack, get_stack_run_config_from_distro, replace_env_vars
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger, setup_logging
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
logger = get_logger(name=__name__, category="core")
@ -389,6 +383,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
# Pass through params that aren't already handled as path params
if options.params:
extra_query_params = {k: v for k, v in options.params.items() if k not in path_params}
if extra_query_params:
body["extra_query"] = extra_query_params
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(matched_func, body, exclude_params=set(field_names))

View file

@ -9,9 +9,9 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
class PromptServiceConfig(BaseModel):

View file

@ -9,9 +9,8 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from llama_stack_api import HealthResponse, HealthStatus, ListProvidersResponse, ProviderInfo, Providers
from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields

View file

@ -8,30 +8,6 @@ import importlib.metadata
import inspect
from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.file_processor import FileProcessor
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
AccessRule,
@ -45,17 +21,45 @@ from llama_stack.core.external import load_external_apis
from llama_stack.core.store import DistributionRegistry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
from llama_stack_api import (
LLAMA_STACK_API_V1ALPHA,
Agents,
Api,
Batches,
Benchmarks,
BenchmarksProtocolPrivate,
Conversations,
DatasetIO,
Datasets,
DatasetsProtocolPrivate,
Eval,
ExternalApiSpec,
FileProcessor,
Files,
Inference,
InferenceProvider,
Inspect,
Models,
ModelsProtocolPrivate,
PostTraining,
Prompts,
ProviderSpec,
RemoteProviderConfig,
RemoteProviderSpec,
Safety,
Scoring,
ScoringFunctions,
ScoringFunctionsProtocolPrivate,
Shields,
ShieldsProtocolPrivate,
ToolGroups,
ToolGroupsProtocolPrivate,
ToolRuntime,
VectorIO,
VectorStore,
)
from llama_stack_api import (
Providers as ProvidersAPI,
)
logger = get_logger(name=__name__, category="core")

View file

@ -12,8 +12,8 @@ from llama_stack.core.datatypes import (
)
from llama_stack.core.stack import StackRunConfig
from llama_stack.core.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import Api, RoutingTable
async def get_routing_table_impl(

View file

@ -6,11 +6,8 @@
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack_api import DatasetIO, DatasetPurpose, DataSource, PaginatedResponse, RoutingTable
logger = get_logger(name=__name__, category="core::routers")

View file

@ -6,15 +6,18 @@
from typing import Any
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack_api import (
BenchmarkConfig,
Eval,
EvaluateResponse,
Job,
RoutingTable,
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core::routers")

View file

@ -15,13 +15,25 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
from llama_stack.core.telemetry.telemetry import MetricEvent
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import (
HealthResponse,
HealthStatus,
Inference,
ListOpenAIChatCompletionResponse,
ModelNotFoundError,
ModelType,
ModelTypeError,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
@ -35,19 +47,8 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
RerankResponse,
RoutingTable,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.telemetry.telemetry import MetricEvent
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
logger = get_logger(name=__name__, category="core::routers")
@ -416,7 +417,7 @@ class InferenceRouter(Inference):
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model_id=fully_qualified_model_id,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in metrics:

View file

@ -6,13 +6,9 @@
from typing import Any
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
logger = get_logger(name=__name__, category="core::routers")
@ -52,7 +48,7 @@ class SafetyRouter(Safety):
async def run_shield(
self,
shield_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")

View file

@ -6,14 +6,12 @@
from typing import Any
from llama_stack.apis.common.content_types import (
from llama_stack.log import get_logger
from llama_stack_api import (
URL,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolRuntime,
)
from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
@ -36,16 +34,16 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
authorization=authorization,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)
return await self.routing_table.list_tools(tool_group_id, authorization=authorization)

View file

@ -10,13 +10,20 @@ from typing import Annotated, Any
from fastapi import Body
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack_api import (
Chunk,
HealthResponse,
HealthStatus,
InterleavedContent,
ModelNotFoundError,
ModelType,
ModelTypeError,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
RoutingTable,
SearchRankingOptions,
VectorIO,
VectorStoreChunkingStrategy,
@ -24,7 +31,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategyStaticConfig,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileContentsResponse,
VectorStoreFileContentResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFilesListInBatchResponse,
@ -33,9 +40,6 @@ from llama_stack.apis.vector_io import (
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core::routers")
@ -122,6 +126,14 @@ class VectorIORouter(VectorIO):
if embedding_model is not None and embedding_dimension is None:
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
# Validate that embedding model exists and is of the correct type
if embedding_model is not None:
model = await self.routing_table.get_object_by_identifier("model", embedding_model)
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
# Auto-select provider if not specified
if provider_id is None:
num_providers = len(self.routing_table.impls_by_provider_id)
@ -247,6 +259,13 @@ class VectorIORouter(VectorIO):
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
# Check if provider_id is being changed (not supported)
if metadata and "provider_id" in metadata:
current_store = await self.routing_table.get_object_by_identifier("vector_store", vector_store_id)
if current_store and current_store.provider_id != metadata["provider_id"]:
raise ValueError("provider_id cannot be changed after vector store creation")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
@ -338,12 +357,19 @@ class VectorIORouter(VectorIO):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
include_embeddings: bool | None = False,
include_metadata: bool | None = False,
) -> VectorStoreFileContentResponse:
logger.debug(
f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}, "
f"include_embeddings={include_embeddings}, include_metadata={include_metadata}"
)
return await self.routing_table.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
include_embeddings=include_embeddings,
include_metadata=include_metadata,
)
async def openai_update_vector_store_file(

View file

@ -6,11 +6,11 @@
from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.core.datatypes import (
BenchmarkWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import Benchmark, Benchmarks, ListBenchmarksResponse
from .common import CommonRoutingTableImpl

View file

@ -6,9 +6,6 @@
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import (
@ -21,7 +18,7 @@ from llama_stack.core.datatypes import (
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack_api import Api, Model, ModelNotFoundError, ResourceType, RoutingTable
logger = get_logger(name=__name__, category="core::routing_tables")

View file

@ -7,22 +7,22 @@
import uuid
from typing import Any
from llama_stack.apis.common.errors import DatasetNotFoundError
from llama_stack.apis.datasets import (
from llama_stack.core.datatypes import (
DatasetWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import (
Dataset,
DatasetNotFoundError,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
ResourceType,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
DatasetWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl

View file

@ -7,8 +7,6 @@
import time
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
@ -16,6 +14,15 @@ from llama_stack.core.datatypes import (
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack_api import (
ListModelsResponse,
Model,
ModelNotFoundError,
Models,
ModelType,
OpenAIListModelsResponse,
OpenAIModel,
)
from .common import CommonRoutingTableImpl, lookup_model

View file

@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.core.datatypes import (
ScoringFnWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import (
ListScoringFunctionsResponse,
ParamType,
ResourceType,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from .common import CommonRoutingTableImpl

View file

@ -6,12 +6,11 @@
from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.core.datatypes import (
ShieldWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import ListShieldsResponse, ResourceType, Shield, Shields
from .common import CommonRoutingTableImpl

View file

@ -6,11 +6,17 @@
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ToolGroupNotFoundError
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger
from llama_stack_api import (
URL,
ListToolDefsResponse,
ListToolGroupsResponse,
ToolDef,
ToolGroup,
ToolGroupNotFoundError,
ToolGroups,
)
from .common import CommonRoutingTableImpl
@ -43,7 +49,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
async def list_tools(
self, toolgroup_id: str | None = None, authorization: str | None = None
) -> ListToolDefsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
@ -55,7 +63,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
try:
await self._index_tools(toolgroup)
await self._index_tools(toolgroup, authorization=authorization)
except AuthenticationRequiredError:
# Send authentication errors back to the client so it knows
# that it needs to supply credentials for remote MCP servers.
@ -70,9 +78,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolDefsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None):
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
tooldefs_response = await provider_impl.list_runtime_tools(
toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization
)
tooldefs = tooldefs_response.data
for t in tooldefs:

View file

@ -6,26 +6,27 @@
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
VectorStoreWithOwner,
)
from llama_stack.log import get_logger
# Removed VectorStores import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
from llama_stack_api import (
ModelNotFoundError,
ModelType,
ModelTypeError,
ResourceType,
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileContentResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorStoreWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
@ -195,12 +196,17 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
self,
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
include_embeddings: bool | None = False,
include_metadata: bool | None = False,
) -> VectorStoreFileContentResponse:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
include_embeddings=include_embeddings,
include_metadata=include_metadata,
)
async def openai_update_vector_store_file(

View file

@ -13,7 +13,6 @@ import httpx
import jwt
from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import TokenValidationError
from llama_stack.core.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
@ -23,6 +22,7 @@ from llama_stack.core.datatypes import (
User,
)
from llama_stack.log import get_logger
from llama_stack_api import TokenValidationError
logger = get_logger(name=__name__, category="core::auth")

View file

@ -11,9 +11,9 @@ from datetime import UTC, datetime, timedelta
from starlette.types import ASGIApp, Receive, Scope, Send
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
from llama_stack.core.storage.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
from llama_stack_api.internal.kvstore import KVStore
logger = get_logger(name=__name__, category="core::server")

View file

@ -12,9 +12,8 @@ from typing import Any
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
from llama_stack_api import Api, ExternalApiSpec, WebMethod
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]

View file

@ -31,8 +31,6 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
AuthenticationRequiredError,
@ -58,7 +56,7 @@ from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import LoggingConfig, get_logger, setup_logging
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api, ConflictError, PaginatedResponse, ResourceNotFoundError
from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
@ -526,8 +524,8 @@ def extract_path_params(route: str) -> list[str]:
def remove_disabled_providers(obj):
if isinstance(obj, dict):
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
# Filter out items where provider_id is explicitly disabled or empty
if "provider_id" in obj and obj["provider_id"] in ("__disabled__", "", None):
return None
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
elif isinstance(obj, list):
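
A hedged illustration of the behavioral change above (inputs are made up, and it assumes the unchanged tail of the function returns non-dict, non-list values unmodified): entries are now dropped only when provider_id itself is disabled or empty, not when some other id field is.

entry = {"provider_id": "ollama", "model_id": ""}
assert remove_disabled_providers(entry) == entry  # the old keys-based check would have dropped this
assert remove_disabled_providers({"provider_id": "__disabled__"}) is None
assert remove_disabled_providers({"provider_id": ""}) is None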

View file

@ -13,26 +13,6 @@ from typing import Any
import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
@ -54,7 +34,30 @@ from llama_stack.core.storage.datatypes import (
from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import (
Agents,
Api,
Batches,
Benchmarks,
Conversations,
DatasetIO,
Datasets,
Eval,
Files,
Inference,
Inspect,
Models,
PostTraining,
Prompts,
Providers,
Safety,
Scoring,
ScoringFunctions,
Shields,
ToolGroups,
ToolRuntime,
VectorIO,
)
logger = get_logger(name=__name__, category="core")
@ -382,8 +385,8 @@ def _initialize_storage(run_config: StackRunConfig):
else:
raise ValueError(f"Unknown storage backend type: {type}")
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
register_kvstore_backends(kv_backends)
register_sqlstore_backends(sql_backends)

View file

@ -12,6 +12,8 @@ from typing import Annotated, Literal
from pydantic import BaseModel, Field, field_validator
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
class StorageBackendType(StrEnum):
KV_REDIS = "kv_redis"
@ -256,15 +258,24 @@ class ResponsesStoreReference(InferenceStoreReference):
class ServerStoresConfig(BaseModel):
metadata: KVStoreReference | None = Field(
default=None,
default=KVStoreReference(
backend="kv_default",
namespace="registry",
),
description="Metadata store configuration (uses KV backend)",
)
inference: InferenceStoreReference | None = Field(
default=None,
default=InferenceStoreReference(
backend="sql_default",
table_name="inference_store",
),
description="Inference store configuration (uses SQL backend)",
)
conversations: SqlStoreReference | None = Field(
default=None,
default=SqlStoreReference(
backend="sql_default",
table_name="openai_conversations",
),
description="Conversations store configuration (uses SQL backend)",
)
responses: ResponsesStoreReference | None = Field(
@ -272,13 +283,21 @@ class ServerStoresConfig(BaseModel):
description="Responses store configuration (uses SQL backend)",
)
prompts: KVStoreReference | None = Field(
default=None,
default=KVStoreReference(backend="kv_default", namespace="prompts"),
description="Prompts store configuration (uses KV backend)",
)
class StorageConfig(BaseModel):
backends: dict[str, StorageBackendConfig] = Field(
default={
"kv_default": SqliteKVStoreConfig(
db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/kvstore.db",
),
"sql_default": SqliteSqlStoreConfig(
db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/sql_store.db",
),
},
description="Named backend configurations (e.g., 'default', 'cache')",
)
stores: ServerStoresConfig = Field(
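
A hedged sketch of what the new defaults above provide (illustrative): constructing ServerStoresConfig with no arguments now yields working kv_default / sql_default references instead of Nones.

stores = ServerStoresConfig()
assert stores.metadata.backend == "kv_default" and stores.metadata.namespace == "registry"
assert stores.inference.table_name == "inference_store"
assert stores.conversations.table_name == "openai_conversations"
assert stores.prompts.namespace == "prompts"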

View file

@ -4,4 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack_api.internal.kvstore import KVStore as KVStore
from .kvstore import * # noqa: F401, F403

View file

@ -11,10 +11,21 @@
from __future__ import annotations
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
import asyncio
from collections import defaultdict
from datetime import datetime
from typing import cast
from .api import KVStore
from .config import KVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig
from llama_stack_api.internal.kvstore import KVStore
from .config import (
KVStoreConfig,
MongoDBKVStoreConfig,
PostgresKVStoreConfig,
RedisKVStoreConfig,
SqliteKVStoreConfig,
)
def kvstore_dependencies():
@ -30,7 +41,7 @@ def kvstore_dependencies():
class InmemoryKVStoreImpl(KVStore):
def __init__(self):
self._store = {}
self._store: dict[str, str] = {}
async def initialize(self) -> None:
pass
@ -38,7 +49,7 @@ class InmemoryKVStoreImpl(KVStore):
async def get(self, key: str) -> str | None:
return self._store.get(key)
async def set(self, key: str, value: str) -> None:
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
self._store[key] = value
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
@ -53,45 +64,65 @@ class InmemoryKVStoreImpl(KVStore):
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available KV store backends for reference resolution."""
global _KVSTORE_BACKENDS
global _KVSTORE_INSTANCES
global _KVSTORE_LOCKS
_KVSTORE_BACKENDS.clear()
_KVSTORE_INSTANCES.clear()
_KVSTORE_LOCKS.clear()
for name, cfg in backends.items():
_KVSTORE_BACKENDS[name] = cfg
typed_cfg = cast(KVStoreConfig, cfg)
_KVSTORE_BACKENDS[name] = typed_cfg
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
backend_name = reference.backend
cache_key = (backend_name, reference.namespace)
existing = _KVSTORE_INSTANCES.get(cache_key)
if existing:
return existing
backend_config = _KVSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
config = backend_config.model_copy()
config.namespace = reference.namespace
lock = _KVSTORE_LOCKS[cache_key]
async with lock:
existing = _KVSTORE_INSTANCES.get(cache_key)
if existing:
return existing
if config.type == StorageBackendType.KV_REDIS.value:
from .redis import RedisKVStoreImpl
config = backend_config.model_copy()
config.namespace = reference.namespace
impl = RedisKVStoreImpl(config)
elif config.type == StorageBackendType.KV_SQLITE.value:
from .sqlite import SqliteKVStoreImpl
impl: KVStore
if isinstance(config, RedisKVStoreConfig):
from .redis import RedisKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif config.type == StorageBackendType.KV_POSTGRES.value:
from .postgres import PostgresKVStoreImpl
impl = RedisKVStoreImpl(config)
elif isinstance(config, SqliteKVStoreConfig):
from .sqlite import SqliteKVStoreImpl
impl = PostgresKVStoreImpl(config)
elif config.type == StorageBackendType.KV_MONGODB.value:
from .mongodb import MongoDBKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif isinstance(config, PostgresKVStoreConfig):
from .postgres import PostgresKVStoreImpl
impl = MongoDBKVStoreImpl(config)
else:
raise ValueError(f"Unknown kvstore type {config.type}")
impl = PostgresKVStoreImpl(config)
elif isinstance(config, MongoDBKVStoreConfig):
from .mongodb import MongoDBKVStoreImpl
await impl.initialize()
return impl
impl = MongoDBKVStoreImpl(config)
else:
raise ValueError(f"Unknown kvstore type {config.type}")
await impl.initialize()
_KVSTORE_INSTANCES[cache_key] = impl
return impl
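
A hedged end-to-end sketch of the register-then-resolve flow above (the backend name, namespace, and file path are illustrative, and the import path for the config classes is an assumption based on this diff):

import asyncio

from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig  # assumed paths

# Register the configured backends once at startup.
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path="/tmp/kvstore.db")})

async def demo():
    ref = KVStoreReference(backend="kv_default", namespace="registry")
    store = await kvstore_impl(ref)
    await store.set("hello", "world")
    assert await store.get("hello") == "world"
    # Resolving the same (backend, namespace) pair returns the cached instance.
    assert await kvstore_impl(ref) is store

asyncio.run(demo())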

View file

@ -9,8 +9,8 @@ from datetime import datetime
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from llama_stack.core.storage.kvstore import KVStore
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig

View file

@ -6,12 +6,13 @@
from datetime import datetime
import psycopg2
from psycopg2.extras import DictCursor
import psycopg2 # type: ignore[import-not-found]
from psycopg2.extensions import connection as PGConnection # type: ignore[import-not-found]
from psycopg2.extras import DictCursor # type: ignore[import-not-found]
from llama_stack.log import get_logger
from llama_stack_api.internal.kvstore import KVStore
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = get_logger(name=__name__, category="providers::utils")
@ -20,12 +21,12 @@ log = get_logger(name=__name__, category="providers::utils")
class PostgresKVStoreImpl(KVStore):
def __init__(self, config: PostgresKVStoreConfig):
self.config = config
self.conn = None
self.cursor = None
self._conn: PGConnection | None = None
self._cursor: DictCursor | None = None
async def initialize(self) -> None:
try:
self.conn = psycopg2.connect(
self._conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.db,
@ -34,11 +35,11 @@ class PostgresKVStoreImpl(KVStore):
sslmode=self.config.ssl_mode,
sslrootcert=self.config.ca_cert_path,
)
self.conn.autocommit = True
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
self._conn.autocommit = True
self._cursor = self._conn.cursor(cursor_factory=DictCursor)
# Create table if it doesn't exist
self.cursor.execute(
self._cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.config.table_name} (
key TEXT PRIMARY KEY,
@ -51,6 +52,11 @@ class PostgresKVStoreImpl(KVStore):
log.exception("Could not connect to PostgreSQL database server")
raise RuntimeError("Could not connect to PostgreSQL database server") from e
def _cursor_or_raise(self) -> DictCursor:
if self._cursor is None:
raise RuntimeError("Postgres client not initialized")
return self._cursor
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
@ -58,7 +64,8 @@ class PostgresKVStoreImpl(KVStore):
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"""
INSERT INTO {self.config.table_name} (key, value, expiration)
VALUES (%s, %s, %s)
@ -70,7 +77,8 @@ class PostgresKVStoreImpl(KVStore):
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key = %s
@ -78,12 +86,13 @@ class PostgresKVStoreImpl(KVStore):
""",
(key,),
)
result = self.cursor.fetchone()
result = cursor.fetchone()
return result[0] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"DELETE FROM {self.config.table_name} WHERE key = %s",
(key,),
)
@ -92,7 +101,8 @@ class PostgresKVStoreImpl(KVStore):
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key >= %s AND key < %s
@ -101,14 +111,15 @@ class PostgresKVStoreImpl(KVStore):
""",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
return [row[0] for row in cursor.fetchall()]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
cursor = self._cursor_or_raise()
cursor.execute(
f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
return [row[0] for row in cursor.fetchall()]

View file

@ -6,18 +6,25 @@
from datetime import datetime
from redis.asyncio import Redis
from redis.asyncio import Redis # type: ignore[import-not-found]
from llama_stack_api.internal.kvstore import KVStore
from ..api import KVStore
from ..config import RedisKVStoreConfig
class RedisKVStoreImpl(KVStore):
def __init__(self, config: RedisKVStoreConfig):
self.config = config
self._redis: Redis | None = None
async def initialize(self) -> None:
self.redis = Redis.from_url(self.config.url)
self._redis = Redis.from_url(self.config.url)
def _client(self) -> Redis:
if self._redis is None:
raise RuntimeError("Redis client not initialized")
return self._redis
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
@ -26,30 +33,37 @@ class RedisKVStoreImpl(KVStore):
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
client = self._client()
await client.set(key, value)
if expiration:
await self.redis.expireat(key, expiration)
await client.expireat(key, expiration)
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
value = await self.redis.get(key)
client = self._client()
value = await client.get(key)
if value is None:
return None
await self.redis.ttl(key)
return value
await client.ttl(key)
if isinstance(value, bytes):
return value.decode("utf-8")
if isinstance(value, str):
return value
return str(value)
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
await self.redis.delete(key)
await self._client().delete(key)
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
client = self._client()
cursor = 0
pattern = start_key + "*" # Match all keys starting with start_key prefix
matching_keys = []
matching_keys: list[str | bytes] = []
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
for key in keys:
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
@ -61,7 +75,7 @@ class RedisKVStoreImpl(KVStore):
# Then fetch all values in a single MGET call
if matching_keys:
values = await self.redis.mget(matching_keys)
values = await client.mget(matching_keys)
return [
value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
]
@ -70,7 +84,18 @@ class RedisKVStoreImpl(KVStore):
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
if not matching_keys:
return []
return [k.decode("utf-8") for k in matching_keys]
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
client = self._client()
cursor = 0
pattern = start_key + "*"
result: list[str] = []
while True:
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
for key in keys:
key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
if start_key <= key_str <= end_key:
result.append(key_str)
if cursor == 0:
break
return result

View file

@ -10,8 +10,8 @@ from datetime import datetime
import aiosqlite
from llama_stack.log import get_logger
from llama_stack_api.internal.kvstore import KVStore
from ..api import KVStore
from ..config import SqliteKVStoreConfig
logger = get_logger(name=__name__, category="providers::utils")

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack_api.internal.sqlstore import (
ColumnDefinition as ColumnDefinition,
)
from llama_stack_api.internal.sqlstore import (
ColumnType as ColumnType,
)
from llama_stack_api.internal.sqlstore import (
SqlStore as SqlStore,
)
from .sqlstore import * # noqa: F401,F403

View file

@ -14,8 +14,8 @@ from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from llama_stack_api import PaginatedResponse
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore
logger = get_logger(name=__name__, category="providers::utils")
@ -45,8 +45,13 @@ def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: Use
enhanced["owner_principal"] = current_user.principal
enhanced["access_attributes"] = current_user.attributes
else:
enhanced["owner_principal"] = None
enhanced["access_attributes"] = None
# IMPORTANT: Use empty string and null value (not None) to match public access filter
# The public access filter in _get_public_access_conditions() expects:
# - owner_principal = '' (empty string)
# - access_attributes = null (JSON null, which serializes to the string 'null')
# Setting them to None (SQL NULL) will cause rows to be filtered out on read.
enhanced["owner_principal"] = ""
enhanced["access_attributes"] = None # Pydantic/JSON will serialize this as JSON null
return enhanced
@ -124,6 +129,23 @@ class AuthorizedSqlStore:
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
await self.sql_store.insert(table, enhanced_data)
async def upsert(
self,
table: str,
data: Mapping[str, Any],
conflict_columns: list[str],
update_columns: list[str] | None = None,
) -> None:
"""Upsert a row with automatic access control attribute capture."""
current_user = get_authenticated_user()
enhanced_data = _enhance_item_with_access_control(data, current_user)
await self.sql_store.upsert(
table=table,
data=enhanced_data,
conflict_columns=conflict_columns,
update_columns=update_columns,
)
async def fetch_all(
self,
table: str,
@ -188,8 +210,9 @@ class AuthorizedSqlStore:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
# IMPORTANT: Use empty string for owner_principal to match public access filter
enhanced_data["owner_principal"] = ""
enhanced_data["access_attributes"] = None # Will serialize as JSON null
await self.sql_store.update(table, enhanced_data, where)
@ -245,14 +268,24 @@ class AuthorizedSqlStore:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
"""Get the SQL conditions for public access.
Public records are those with:
- owner_principal = '' (empty string)
- access_attributes is either SQL NULL or JSON null
Note: Different databases serialize None differently:
        - SQLite: None -> JSON null (text = 'null')
        - Postgres: None -> SQL NULL (IS NULL)
"""
conditions = ["owner_principal = ''"]
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
# Postgres stores JSON null as 'null'
conditions.append("access_attributes::text = 'null'")
# Accept both SQL NULL and JSON null for Postgres compatibility
# This handles both old rows (SQL NULL) and new rows (JSON null)
conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
conditions.append("access_attributes = 'null'")
# SQLite serializes None as JSON null
conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions

View file

@ -26,11 +26,10 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.elements import ColumnElement
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
from llama_stack_api import PaginatedResponse
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore
logger = get_logger(name=__name__, category="providers::utils")
@ -72,13 +71,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
class SqlAlchemySqlStoreImpl(SqlStore):
def __init__(self, config: SqlAlchemySqlStoreConfig):
self.config = config
self._is_sqlite_backend = "sqlite" in self.config.engine_str
self.async_session = async_sessionmaker(self.create_engine())
self.metadata = MetaData()
def create_engine(self) -> AsyncEngine:
# Configure connection args for better concurrency support
connect_args = {}
if "sqlite" in self.config.engine_str:
if self._is_sqlite_backend:
# SQLite-specific optimizations for concurrent access
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
connect_args["timeout"] = 5.0
@ -91,7 +91,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
)
# Enable WAL mode for SQLite to support concurrent readers and writers
if "sqlite" in self.config.engine_str:
if self._is_sqlite_backend:
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_conn, connection_record):
@ -151,6 +151,29 @@ class SqlAlchemySqlStoreImpl(SqlStore):
await session.execute(self.metadata.tables[table].insert(), data)
await session.commit()
async def upsert(
self,
table: str,
data: Mapping[str, Any],
conflict_columns: list[str],
update_columns: list[str] | None = None,
) -> None:
table_obj = self.metadata.tables[table]
dialect_insert = self._get_dialect_insert(table_obj)
insert_stmt = dialect_insert.values(**data)
if update_columns is None:
update_columns = [col for col in data.keys() if col not in conflict_columns]
update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
conflict_cols = [table_obj.c[col] for col in conflict_columns]
stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
async with self.async_session() as session:
await session.execute(stmt)
await session.commit()
async def fetch_all(
self,
table: str,
@ -333,9 +356,18 @@ class SqlAlchemySqlStoreImpl(SqlStore):
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
await conn.execute(add_column_sql)
except Exception as e:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass
def _get_dialect_insert(self, table: Table):
if self._is_sqlite_backend:
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
return sqlite_insert(table)
else:
from sqlalchemy.dialects.postgresql import insert as pg_insert
return pg_insert(table)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from threading import Lock
from typing import Annotated, cast
from pydantic import Field
@ -15,12 +16,13 @@ from llama_stack.core.storage.datatypes import (
StorageBackendConfig,
StorageBackendType,
)
from .api import SqlStore
from llama_stack_api.internal.sqlstore import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
_SQLSTORE_LOCKS: dict[str, Lock] = {}
SqlStoreConfig = Annotated[
@ -52,19 +54,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
    if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
        from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
        config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
        return SqlAlchemySqlStoreImpl(config)
    else:
        raise ValueError(f"Unknown sqlstore type {backend_config.type}")
    existing = _SQLSTORE_INSTANCES.get(backend_name)
    if existing:
        return existing
lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
with lock:
existing = _SQLSTORE_INSTANCES.get(backend_name)
if existing:
return existing
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
instance = SqlAlchemySqlStoreImpl(config)
_SQLSTORE_INSTANCES[backend_name] = instance
return instance
else:
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available SQL store backends for reference resolution."""
global _SQLSTORE_BACKENDS
global _SQLSTORE_INSTANCES
_SQLSTORE_BACKENDS.clear()
_SQLSTORE_INSTANCES.clear()
_SQLSTORE_LOCKS.clear()
for name, cfg in backends.items():
_SQLSTORE_BACKENDS[name] = cfg

View file

@ -12,8 +12,8 @@ import pydantic
from llama_stack.core.datatypes import RoutableObjectWithProvider
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
logger = get_logger(__name__, category="core::registry")

View file

@ -28,7 +28,7 @@ from pydantic import BaseModel, Field
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema
from llama_stack_api import json_schema_type, register_schema
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]

View file

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Utility functions for type inspection and parameter handling.
"""
import inspect
import typing
from typing import Any, get_args, get_origin
from pydantic import BaseModel
from pydantic.fields import FieldInfo
def is_unwrapped_body_param(param_type: Any) -> bool:
"""
Check if a parameter type represents an unwrapped body parameter.
An unwrapped body parameter is an Annotated type with Body(embed=False).
This is used to determine whether request parameters should be flattened
in OpenAPI specs and client libraries (matching FastAPI's embed=False behavior).
Args:
param_type: The parameter type annotation to check
Returns:
True if the parameter should be treated as an unwrapped body parameter
"""
# Check if it's Annotated with Body(embed=False)
if get_origin(param_type) is typing.Annotated:
args = get_args(param_type)
base_type = args[0]
metadata = args[1:]
# Look for Body annotation with embed=False
# Body() returns a FieldInfo object, so we check for that type and the embed attribute
for item in metadata:
if isinstance(item, FieldInfo) and hasattr(item, "embed") and not item.embed:
return inspect.isclass(base_type) and issubclass(base_type, BaseModel)
return False
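# Illustrative usage sketch (not part of this module). Assuming FastAPI's Body()
# is in scope and a hypothetical Pydantic model `UpdateItemRequest`, an endpoint
# signature like
#     def update_item(params: Annotated[UpdateItemRequest, Body(embed=False)]) -> Item: ...
# would make is_unwrapped_body_param(...) return True for that annotation, so the
# model's fields are flattened into the top-level request body instead of being
# nested under a "params" key.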

View file

@ -13,6 +13,5 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template(name="ci-tests")
template.description = "CI tests for Llama Stack"
template.run_configs.pop("run-with-postgres-store.yaml", None)
return template

View file

@ -0,0 +1,296 @@
version: 2
image_name: ci-tests
apis:
- agents
- batches
- datasetio
- eval
- file_processor
- files
- inference
- post_training
- safety
- scoring
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:=}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
persistence:
namespace: vector_io::faiss
backend: kv_default
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
persistence:
namespace: vector_io::sqlite_vec
backend: kv_default
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
persistence:
namespace: vector_io::milvus
backend: kv_default
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
persistence:
namespace: vector_io::chroma_remote
backend: kv_default
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:=localhost}
port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
persistence:
namespace: vector_io::pgvector
backend: kv_default
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
persistence:
namespace: vector_io::qdrant_remote
backend: kv_default
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
persistence:
namespace: vector_io::weaviate
backend: kv_default
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
metadata_store:
table_name: files_metadata
backend: sql_default
file_processor:
- provider_id: reference
provider_type: inline::reference
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
post_training:
- provider_id: torchtune-cpu
provider_type: inline::torchtune-cpu
config:
checkpoint_format: meta
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
namespace: batches
backend: kv_default
storage:
backends:
kv_default:
type: kv_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
sql_default:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5
safety:
default_shield_id: llama-guard

View file

@ -18,44 +18,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -77,18 +76,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
BuildProvider,
ModelInput,
@ -17,6 +16,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
from llama_stack_api import ModelType
def get_distribution_template() -> DistributionTemplate:

View file

@ -6,7 +6,6 @@
from pathlib import Path
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
BuildProvider,
ModelInput,
@ -22,6 +21,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack_api import ModelType
def get_distribution_template() -> DistributionTemplate:

View file

@ -16,9 +16,8 @@ providers:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: nvidia
provider_type: remote::nvidia
config:

View file

@ -16,9 +16,8 @@ providers:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
vector_io:
- provider_id: faiss
provider_type: inline::faiss

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .agents import *
from .oci import get_distribution_template # noqa: F401

View file

@ -0,0 +1,35 @@
version: 2
distribution_spec:
description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM
inference with scalable cloud services
providers:
inference:
- provider_type: remote::oci
vector_io:
- provider_type: inline::faiss
- provider_type: remote::chromadb
- provider_type: remote::pgvector
safety:
- provider_type: inline::llama-guard
agents:
- provider_type: inline::meta-reference
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@ -0,0 +1,140 @@
---
orphan: true
---
# OCI Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
## Prerequisites
### Oracle Cloud Infrastructure Setup
Before using the OCI Generative AI distribution, ensure you have:
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
4. **Authentication**: Configure authentication using either:
- **Instance Principal** (recommended for cloud-hosted deployments)
- **API Key** (for on-premises or development environments)
### Authentication Methods
#### Instance Principal Authentication (Recommended)
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
Requirements:
- Instance must be running in an Oracle Cloud Infrastructure compartment
- Instance must have appropriate IAM policies to access Generative AI services, as shown in the example policy below
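For instance principal setups, access is typically granted through a dynamic group that matches your compute instances. The group name, compartment name, and matching rule below are illustrative placeholders; adjust them to your tenancy:

```
Allow dynamic-group <dynamic_group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
```

The dynamic group itself would match your instances with a rule such as `ALL {instance.compartment.id = '<compartment_ocid>'}`.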
#### API Key Authentication
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
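The resulting configuration file (by default `~/.oci/config`) defines one or more profiles; the values below are placeholders for illustration:

```
[DEFAULT]
user=ocid1.user.oc1..<unique_id>
fingerprint=<key_fingerprint>
key_file=~/.oci/oci_api_key.pem
tenancy=ocid1.tenancy.oc1..<unique_id>
region=us-chicago-1
```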
### Required IAM Policies
Ensure your OCI user or instance has the following policy statements:
```
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
```
## Supported Services
### Inference: OCI Generative AI
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
- **Chat Completions**: Conversational AI with context awareness
- **Text Generation**: Complete prompts and generate text content
#### Available Models
OCI Generative AI commonly provides access to models from Meta, Cohere, OpenAI, Grok, and others.
### Safety: Llama Guard
For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide:
- Content filtering and moderation
- Policy compliance checking
- Harmful content detection
### Vector Storage: Multiple Options
The distribution supports several vector storage providers:
- **FAISS**: Local in-memory vector search
- **ChromaDB**: Distributed vector database
- **PGVector**: PostgreSQL with vector extensions
### Additional Services
- **Dataset I/O**: Local filesystem and Hugging Face integration
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
- **Evaluation**: Meta reference evaluation framework
## Running Llama Stack with OCI
You can run the OCI distribution via Docker or local virtual environment.
### Via venv
If you've set up your local development environment, you can run the stack directly from your local virtual environment.
```bash
OCI_AUTH_TYPE=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
```
### Configuration Examples
#### Using Instance Principal (Recommended for Production)
```bash
export OCI_AUTH_TYPE=instance_principal
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
```
#### Using API Key Authentication (Development)
```bash
export OCI_AUTH_TYPE=config_file
export OCI_CONFIG_FILE_PATH=~/.oci/config
export OCI_CLI_PROFILE=DEFAULT
export OCI_REGION=us-chicago-1
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
```
## Regional Endpoints
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
## Troubleshooting
### Common Issues
1. **Authentication Errors**: Verify your OCI credentials and IAM policies
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
3. **Permission Denied**: Check compartment permissions and Generative AI service access
4. **Region Unavailable**: Verify the specified region supports Generative AI services
### Getting Help
For additional support:
- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)

View file

@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.oci.config import OCIConfig
def get_distribution_template(name: str = "oci") -> DistributionTemplate:
providers = {
"inference": [BuildProvider(provider_type="remote::oci")],
"vector_io": [
BuildProvider(provider_type="inline::faiss"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
BuildProvider(provider_type="inline::localfs"),
],
"scoring": [
BuildProvider(provider_type="inline::basic"),
BuildProvider(provider_type="inline::llm-as-judge"),
BuildProvider(provider_type="inline::braintrust"),
],
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
}
inference_provider = Provider(
provider_id="oci",
provider_type="remote::oci",
config=OCIConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
]
return DistributionTemplate(
name=name,
distro_type="remote_hosted",
description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"vector_io": [vector_io_provider],
"files": [files_provider],
},
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"OCI_AUTH_TYPE": (
"instance_principal",
"OCI authentication type (instance_principal or config_file)",
),
"OCI_REGION": (
"",
"OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)",
),
"OCI_COMPARTMENT_OCID": (
"",
"OCI compartment ID for the Generative AI service",
),
"OCI_CONFIG_FILE_PATH": (
"~/.oci/config",
"OCI config file path (required if OCI_AUTH_TYPE is config_file)",
),
"OCI_CLI_PROFILE": (
"DEFAULT",
"OCI CLI profile name to use from config file",
),
},
)

View file

@ -0,0 +1,136 @@
version: 2
image_name: oci
apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
- tool_runtime
- vector_io
providers:
inference:
- provider_id: oci
provider_type: remote::oci
config:
oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
oci_region: ${env.OCI_REGION:=us-ashburn-1}
oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
persistence:
namespace: vector_io::faiss
backend: kv_default
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
namespace: eval
backend: kv_default
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
namespace: datasetio::huggingface
backend: kv_default
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
namespace: datasetio::localfs
backend: kv_default
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files}
metadata_store:
table_name: files_metadata
backend: sql_default
storage:
backends:
kv_default:
type: kv_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db
sql_default:
type: sql_sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db
stores:
metadata:
namespace: registry
backend: kv_default
inference:
table_name: inference_store
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_default
prompts:
namespace: prompts
backend: kv_default
registered_resources:
models: []
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
server:
port: 8321
telemetry:
enabled: true

View file

@ -5,8 +5,6 @@
# the root directory of this source tree.
from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
BenchmarkInput,
BuildProvider,
@ -34,6 +32,7 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack_api import DatasetPurpose, ModelType, URIDataSource
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:

View file

@ -27,12 +27,12 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
vector_io:
- provider_id: sqlite-vec

View file

@ -11,7 +11,7 @@ providers:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=http://localhost:8000/v1}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}

View file

@ -18,44 +18,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -77,18 +76,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
@ -169,20 +168,15 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
post_training:
- provider_id: huggingface-gpu
provider_type: inline::huggingface-gpu
@ -241,10 +235,10 @@ providers:
config:
kvstore:
namespace: batches
backend: kv_postgres
backend: kv_default
storage:
backends:
kv_postgres:
kv_default:
type: kv_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
@ -252,7 +246,7 @@ storage:
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
sql_postgres:
sql_default:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
@ -262,27 +256,44 @@ storage:
stores:
metadata:
namespace: registry
backend: kv_postgres
backend: kv_default
inference:
table_name: inference_store
backend: sql_postgres
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_postgres
backend: sql_default
prompts:
namespace: prompts
backend: kv_postgres
backend: kv_default
registered_resources:
models: []
shields: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5
safety:
default_shield_id: llama-guard

View file

@ -18,44 +18,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -77,18 +76,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers

View file

@ -18,44 +18,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -77,18 +76,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
@ -169,20 +168,15 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
persistence:
agent_state:
namespace: agents
backend: kv_default
responses:
table_name: responses
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
post_training:
- provider_id: torchtune-cpu
provider_type: inline::torchtune-cpu
@ -238,10 +232,10 @@ providers:
config:
kvstore:
namespace: batches
backend: kv_postgres
backend: kv_default
storage:
backends:
kv_postgres:
kv_default:
type: kv_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
@ -249,7 +243,7 @@ storage:
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
sql_postgres:
sql_default:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
@ -259,27 +253,44 @@ storage:
stores:
metadata:
namespace: registry
backend: kv_postgres
backend: kv_default
inference:
table_name: inference_store
backend: sql_postgres
backend: sql_default
max_write_queue_size: 10000
num_writers: 4
conversations:
table_name: openai_conversations
backend: sql_postgres
backend: sql_default
prompts:
namespace: prompts
backend: kv_postgres
backend: kv_default
registered_resources:
models: []
shields: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
telemetry:
enabled: true
vector_stores:
default_provider_id: faiss
default_embedding_model:
provider_id: sentence-transformers
model_id: nomic-ai/nomic-embed-text-v1.5
safety:
default_shield_id: llama-guard

View file

@ -18,44 +18,43 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@ -77,18 +76,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
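
Most of the changes in this file rename provider "url" config keys to "base_url" and fold the API version path (for example /v1 or /openai/v1) into the default value. The ${env.NAME:=default} and ${env.NAME:+value} placeholders follow a bash-like substitution syntax; the snippet below is only an approximate sketch of those semantics, not the llama-stack config loader, and the treatment of empty-but-set variables here is an assumption.

# Approximate sketch of the env-placeholder semantics used in these run configs:
#   ${env.NAME:=default} -> value of NAME if set (non-empty), else "default"
#   ${env.NAME:+value}   -> "value" if NAME is set (non-empty), else ""
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.(\w+):([=+])([^}]*)\}")

def substitute_env(text: str) -> str:
    def repl(match: re.Match) -> str:
        name, op, arg = match.group(1), match.group(2), match.group(3)
        value = os.environ.get(name, "")
        if op == "=":
            return value if value else arg   # default when unset
        return arg if value else ""          # conditional value when set

    return _ENV_PATTERN.sub(repl, text)

# Example (with OLLAMA_URL unset):
#   substitute_env("base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}")
#   returns "base_url: http://localhost:11434/v1"
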


@ -17,14 +17,10 @@ from llama_stack.core.datatypes import (
ToolGroupInput,
VectorStoresConfig,
)
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
SqlStoreReference,
)
from llama_stack.core.storage.kvstore.config import PostgresKVStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
@ -41,8 +37,7 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
)
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack_api import RemoteProviderSpec
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
@ -155,10 +150,11 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
BuildProvider(provider_type="inline::reference"),
],
}
files_config = LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}")
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
config=files_config,
)
embedding_provider = Provider(
provider_id="sentence-transformers",
@ -188,7 +184,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
),
]
postgres_config = PostgresSqlStoreConfig.sample_run_config()
postgres_sql_config = PostgresSqlStoreConfig.sample_run_config()
postgres_kv_config = PostgresKVStoreConfig.sample_run_config()
default_overrides = {
"inference": remote_inference_providers + [embedding_provider],
"vector_io": [
@ -245,6 +242,33 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
"files": [files_provider],
}
base_run_settings = RunConfigSettings(
provider_overrides=default_overrides,
default_models=[],
default_tool_groups=default_tool_groups,
default_shields=default_shields,
vector_stores_config=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="sentence-transformers",
model_id="nomic-ai/nomic-embed-text-v1.5",
),
),
safety_config=SafetyConfig(
default_shield_id="llama-guard",
),
)
postgres_run_settings = base_run_settings.model_copy(
update={
"storage_backends": {
"kv_default": postgres_kv_config,
"sql_default": postgres_sql_config,
}
},
deep=True,
)
return DistributionTemplate(
name=name,
distro_type="self_hosted",
@ -254,71 +278,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
providers=providers,
additional_pip_packages=list(set(PostgresSqlStoreConfig.pip_packages() + PostgresKVStoreConfig.pip_packages())),
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides=default_overrides,
default_models=[],
default_tool_groups=default_tool_groups,
default_shields=default_shields,
vector_stores_config=VectorStoresConfig(
default_provider_id="faiss",
default_embedding_model=QualifiedModel(
provider_id="sentence-transformers",
model_id="nomic-ai/nomic-embed-text-v1.5",
),
),
safety_config=SafetyConfig(
default_shield_id="llama-guard",
),
),
"run-with-postgres-store.yaml": RunConfigSettings(
provider_overrides={
**default_overrides,
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
config=dict(
persistence_store=postgres_config,
responses_store=postgres_config,
),
)
],
"batches": [
Provider(
provider_id="reference",
provider_type="inline::reference",
config=dict(
kvstore=KVStoreReference(
backend="kv_postgres",
namespace="batches",
).model_dump(exclude_none=True),
),
)
],
},
storage_backends={
"kv_postgres": PostgresKVStoreConfig.sample_run_config(),
"sql_postgres": postgres_config,
},
storage_stores={
"metadata": KVStoreReference(
backend="kv_postgres",
namespace="registry",
).model_dump(exclude_none=True),
"inference": InferenceStoreReference(
backend="sql_postgres",
table_name="inference_store",
).model_dump(exclude_none=True),
"conversations": SqlStoreReference(
backend="sql_postgres",
table_name="openai_conversations",
).model_dump(exclude_none=True),
"prompts": KVStoreReference(
backend="kv_postgres",
namespace="prompts",
).model_dump(exclude_none=True),
},
),
"run.yaml": base_run_settings,
"run-with-postgres-store.yaml": postgres_run_settings,
},
run_config_env_vars={
"LLAMA_STACK_PORT": (


@ -12,8 +12,6 @@ import rich
import yaml
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
Api,
@ -37,13 +35,14 @@ from llama_stack.core.storage.datatypes import (
SqlStoreReference,
StorageBackendType,
)
from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages
from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
from llama_stack_api import DatasetPurpose, ModelType
def filter_empty_values(obj: Any) -> Any:


@ -15,7 +15,7 @@ providers:
- provider_id: watsonx
provider_type: remote::watsonx
config:
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=}
vector_io:


@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
)
from termcolor import cprint
from llama_stack.models.llama.datatypes import ToolPromptFormat
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer


@ -15,13 +15,10 @@ from pathlib import Path
from termcolor import colored
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
from ..datatypes import (
BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat

Some files were not shown because too many files have changed in this diff.