Merge branch 'main' into HuggingfacePostTrainingConfig-branch

This commit is contained in:
Sarthak Deshpande 2025-08-25 11:59:15 +05:30 committed by GitHub
commit d0d737680f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
193 changed files with 7108 additions and 881 deletions

View file

@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):
@runtime_checkable
class Batches(Protocol):
"""Protocol for batch processing API operations.
"""
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
The API is designed to allow use of openai client libraries for seamless integration.
This API provides the following extensions:
- idempotent batch creation
Note: This API is currently under active development and may undergo changes.
"""
@ -45,6 +49,7 @@ class Batches(Protocol):
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
@ -52,6 +57,7 @@ class Batches(Protocol):
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
:returns: The created batch object.
"""
...

View file

@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
embeddings: list[list[float]]
@json_schema_type
class RerankData(BaseModel):
"""A single rerank result from a reranking response.
:param index: The original index of the document in the input list
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
"""
index: int
relevance_score: float
@json_schema_type
class RerankResponse(BaseModel):
"""Response from a reranking request.
:param data: List of rerank result objects, sorted by relevance score (descending)
"""
data: list[RerankData]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
"""Text content part for OpenAI-compatible chat completion messages.
@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol):
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol):
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/inference/rerank", method="POST", experimental=True)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query.
:param model: The identifier of the reranking model to use.
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
:returns: RerankResponse with indices sorted by relevance score (descending).
"""
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,

View file

@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel):
timestamp: int
value: float
unit: str
@json_schema_type
@ -518,7 +519,7 @@ class Telemetry(Protocol):
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="cli")
class StackRun(Subcommand):

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import importlib.resources
import logging
import sys
from pydantic import BaseModel
@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import textwrap
from typing import Any
@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="core")
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:

View file

@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):

View file

@ -7,7 +7,7 @@
import asyncio
import inspect
import json
import logging
import logging # allow-direct-logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
@ -48,6 +48,7 @@ from llama_stack.core.stack import (
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
start_trace,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="core")
T = TypeVar("T")
@ -145,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
def initialize(self):
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self):
def initialize(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
Deprecated method for backward compatibility.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
pass
def request(self, *args, **kwargs):
loop = self.loop
@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# when using the library client, we should not log to console since many
@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool:
"""
Initialize the async client.
Returns:
bool: True if initialization was successful
"""
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)

View file

@ -6,15 +6,15 @@
import contextvars
import json
import logging
from contextlib import AbstractContextManager
from typing import Any
from llama_stack.core.datatypes import User
from llama_stack.log import get_logger
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)

View file

@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class DatasetIORouter(DatasetIO):

View file

@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ScoringRouter(Scoring):

View file

@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="core::routers")
class InferenceRouter(Inference):

View file

@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class SafetyRouter(Safety):

View file

@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):

View file

@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class VectorIORouter(VectorIO):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

View file

@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def get_impl_api(p: Any) -> Api:

View file

@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

View file

@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthenticationMiddleware:

View file

@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthResponse(BaseModel):

View file

@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
logger = get_logger(name=__name__, category="core::server")
class QuotaMiddleware:

View file

@ -9,7 +9,7 @@ import asyncio
import functools
import inspect
import json
import logging
import logging # allow-direct-logging
import os
import ssl
import sys
@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="core::server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env:
for env_pair in args.env:
try:
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:

View file

@ -16,7 +16,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
logger = get_logger(__name__, category="core::registry")
class DistributionRegistry(Protocol):

View file

@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
logger = get_logger(name=__name__, category="core")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import importlib
import os
import signal
import subprocess
@ -12,9 +12,9 @@ import sys
from termcolor import cprint
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
import importlib
log = get_logger(name=__name__, category="core")
def formulate_run_args(image_type: str, image_name: str) -> list:

View file

@ -6,7 +6,6 @@
import inspect
import json
import logging
from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin
@ -14,7 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
def is_list_of_primitives(field_type):

View file

@ -34,7 +34,7 @@ distribution_spec:
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::huggingface
- provider_type: inline::huggingface-cpu
eval:
- provider_type: inline::meta-reference
datasetio:

View file

@ -156,8 +156,8 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
- provider_id: huggingface-cpu
provider_type: inline::huggingface-cpu
config:
checkpoint_format: huggingface
distributed_backend: null

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .starter_gpu import get_distribution_template # noqa: F401

View file

@ -0,0 +1,59 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for GPU-enabled environments.
providers:
inference:
- provider_type: remote::cerebras
- provider_type: remote::ollama
- provider_type: remote::vllm
- provider_type: remote::tgi
- provider_type: remote::fireworks
- provider_type: remote::together
- provider_type: remote::bedrock
- provider_type: remote::nvidia
- provider_type: remote::openai
- provider_type: remote::anthropic
- provider_type: remote::gemini
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss
- provider_type: inline::sqlite-vec
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
files:
- provider_type: inline::localfs
safety:
- provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents:
- provider_type: inline::meta-reference
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::torchtune-gpu
eval:
- provider_type: inline::meta-reference
datasetio:
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference
image_type: venv
additional_pip_packages:
- aiosqlite
- asyncpg
- sqlalchemy[asyncio]

View file

@ -0,0 +1,238 @@
version: 2
image_name: starter-gpu
apis:
- agents
- batches
- datasetio
- eval
- files
- inference
- post_training
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
base_url: https://api.cerebras.ai
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
url: ${env.OLLAMA_URL:=http://localhost:11434}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:=}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:
api_key: ${env.ANTHROPIC_API_KEY:=}
- provider_id: gemini
provider_type: remote::gemini
config:
api_key: ${env.GEMINI_API_KEY:=}
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_type: remote::vertexai
config:
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
- provider_id: groq
provider_type: remote::groq
config:
url: https://api.groq.com
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus
config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector
config:
host: ${env.PGVECTOR_HOST:=localhost}
port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
responses_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: torchtune-gpu
provider_type: inline::torchtune-gpu
config:
checkpoint_format: meta
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:=}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
models: []
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distributions.template import BuildProvider, DistributionTemplate
from ..starter.starter import get_distribution_template as get_starter_distribution_template
def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template()
name = "starter-gpu"
template.name = name
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
template.providers["post_training"] = [
BuildProvider(provider_type="inline::torchtune-gpu"),
]
return template

View file

@ -1,6 +1,7 @@
version: 2
distribution_spec:
description: Quick start template for running Llama Stack with several popular providers
description: Quick start template for running Llama Stack with several popular providers.
This distribution is intended for CPU-only environments.
providers:
inference:
- provider_type: remote::cerebras
@ -34,7 +35,7 @@ distribution_spec:
telemetry:
- provider_type: inline::meta-reference
post_training:
- provider_type: inline::huggingface
- provider_type: inline::huggingface-cpu
eval:
- provider_type: inline::meta-reference
datasetio:

View file

@ -156,8 +156,8 @@ providers:
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
post_training:
- provider_id: huggingface
provider_type: inline::huggingface
- provider_id: huggingface-cpu
provider_type: inline::huggingface-cpu
config:
checkpoint_format: huggingface
distributed_backend: null

View file

@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")],
"post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [
BuildProvider(provider_type="remote::huggingface"),
@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Quick start template for running Llama Stack with several popular providers",
description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
container_image=None,
template_path=None,
providers=providers,

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import logging # allow-direct-logging
import os
import re
from logging.config import dictConfig
from logging.config import dictConfig # allow-direct-logging
from rich.console import Console
from rich.errors import MarkupError

View file

@ -13,14 +13,15 @@
# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
from logging import getLogger
import torch
import torch.nn.functional as F
from llama_stack.log import get_logger
from .utils import get_negative_inf_value, to_2tuple
logger = getLogger()
logger = get_logger(name=__name__, category="models::llama")
def resize_local_position_embedding(orig_pos_embed, grid_size):

View file

@ -13,7 +13,6 @@
import math
from collections import defaultdict
from logging import getLogger
from typing import Any
import torch
@ -21,9 +20,11 @@ import torchvision.transforms as tv
from PIL import Image
from torchvision.transforms import functional as F
from llama_stack.log import get_logger
IMAGE_RES = 224
logger = getLogger()
logger = get_logger(name=__name__, category="models::llama")
class VariableSizeImageTransform:

View file

@ -3,8 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import math
from collections.abc import Callable
from functools import partial
@ -22,6 +20,8 @@ from PIL import Image as PIL_Image
from torch import Tensor, nn
from torch.distributed import _functional_collectives as funcol
from llama_stack.log import get_logger
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
from .encoder_utils import (
build_encoder_attention_mask,
@ -34,9 +34,10 @@ from .encoder_utils import (
from .image_transform import VariableSizeImageTransform
from .utils import get_negative_inf_value, to_2tuple
logger = logging.getLogger(__name__)
MP_SCALE = 8
logger = get_logger(name=__name__, category="models::llama")
def reduce_from_tensor_model_parallel_region(input_):
"""All-reduce the input tensor across model parallel group."""
@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module):
if embed is not None:
# reshape the weights to the correct shape
nt_old, nt_old, _, w = embed.shape
logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
# assign the weights to the module
state_dict[prefix + "embedding"] = embed_new

View file

@ -4,8 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
Literal,
@ -14,11 +14,9 @@ from typing import (
import tiktoken
from llama_stack.log import get_logger
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000
_INSTANCE = None
logger = get_logger(name=__name__, category="models::llama")
class Tokenizer:
"""

View file

@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="models::llama")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from collections.abc import Callable
@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor, nn
from torch.nn import functional as F
from llama_stack.log import get_logger
from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="models::llama")
def swiglu_wrapper_no_reduce(

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
Literal,
@ -14,11 +13,9 @@ from typing import (
import tiktoken
from llama_stack.log import get_logger
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [
"<|fim_suffix|>",
]
logger = get_logger(name=__name__, category="models::llama")
class Tokenizer:
"""

View file

@ -6,9 +6,10 @@
# type: ignore
import collections
import logging
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="models::llama")
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401

View file

@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ChatAgent(ShieldRunnerMixin):

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = logging.getLogger()
logger = get_logger(name=__name__, category="agents::meta_reference")
class MetaReferenceAgentsImpl(Agents):

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
import uuid
from datetime import UTC, datetime
@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session):

View file

@ -41,7 +41,7 @@ from .utils import (
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="openai::responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

View file

@ -47,7 +47,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class StreamingResponseOrchestrator:

View file

@ -38,7 +38,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ToolExecutor:

View file

@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
)
from llama_stack.apis.inference import (
@ -99,14 +101,22 @@ async def convert_response_input_to_chat_messages(
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
# extract all OpenAIResponseInputFunctionToolCallOutput items
# so their corresponding OpenAIToolMessageParam instances can
# be added immediately following the corresponding
# OpenAIAssistantMessageParam
tool_call_results = {}
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
# skip as these have been extracted and inserted in order
pass
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
@ -117,6 +127,28 @@ async def convert_response_input_to_chat_messages(
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
if input_item.call_id in tool_call_results:
messages.append(tool_call_results[input_item.call_id])
del tool_call_results[input_item.call_id]
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
@ -125,6 +157,10 @@ async def convert_response_input_to_chat_messages(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages

View file

@ -5,13 +5,13 @@
# the root directory of this source tree.
import asyncio
import logging
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="agents::meta_reference")
class SafetyException(Exception): # noqa: N818

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import hashlib
import itertools
import json
import time
@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches):
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
Error handling by levels -
0. Input param handling, results in 40x errors before processing, e.g.
- Wrong completion_window
- Invalid metadata types
- Unknown endpoint
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
- input_file_id missing
- invalid json in file
- missing custom_id, method, url, body
- invalid model
- streaming
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g.
- Any error returned from inference endpoint
-> batch created, goes to completed status
This implementation provides optional idempotency: when an idempotency key
(idempotency_key) is provided, a deterministic ID is generated based on the input
parameters. If a batch with the same parameters already exists, it will be
returned instead of creating a duplicate. Without an idempotency key,
each request creates a new batch with a unique ID.
Args:
input_file_id: The ID of an uploaded file containing requests for the batch.
endpoint: The endpoint to be used for all requests in the batch.
completion_window: The time window within which the batch should be processed.
metadata: Optional metadata for the batch.
idempotency_key: Optional idempotency key for enabling idempotent behavior.
Returns:
The created or existing batch object.
"""
# Error handling by levels -
# 0. Input param handling, results in 40x errors before processing, e.g.
# - Wrong completion_window
# - Invalid metadata types
# - Unknown endpoint
# -> no batch created
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
# - input_file_id missing
# - invalid json in file
# - missing custom_id, method, url, body
# - invalid model
# - streaming
# -> batch created, validation sends to failed status
# 2. Processing errors, result in error_file_id entries, e.g.
# - Any error returned from inference endpoint
# -> batch created, goes to completed status
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]:
@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches):
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
batch_id = f"batch_{hash_digest}"
try:
existing_batch = await self.retrieve_batch(batch_id)
if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)
logger.info(f"Returning existing batch with ID: {batch_id}")
return existing_batch
except ResourceNotFoundError:
# Batch doesn't exist, continue with creation
pass
current_time = int(time.time())
batch = BatchObject(
@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches):
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
logger.info(f"Created new batch with ID: {batch_id}")
if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id))

View file

@ -11,6 +11,7 @@ from typing import Annotated
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
Files,
@ -20,12 +21,15 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import LocalfsFilesImplConfig
logger = get_logger(name=__name__, category="files")
class LocalfsFilesImpl(Files):
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
@ -65,6 +69,18 @@ class LocalfsFilesImpl(Files):
"""Get the filesystem path for a file ID."""
return Path(self.config.storage_dir) / file_id
async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
"""Look up a OpenAIFileObject and filesystem path from its ID."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
file_path = Path(row.pop("file_path"))
return OpenAIFileObject(**row), file_path
# OpenAI Files API Implementation
async def openai_upload_file(
self,
@ -157,37 +173,19 @@ class LocalfsFilesImpl(Files):
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
"""Returns information about a specific file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
file_obj, _ = await self._lookup_file_id(file_id)
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
return file_obj
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
"""Delete a file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Delete physical file
file_path = Path(row["file_path"])
_, file_path = await self._lookup_file_id(file_id)
if file_path.exists():
file_path.unlink()
# Delete metadata from database
assert self.sql_store is not None, "Files provider not initialized"
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(
@ -197,25 +195,17 @@ class LocalfsFilesImpl(Files):
async def openai_retrieve_file_content(self, file_id: str) -> Response:
"""Returns the contents of the specified file."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
# Get file metadata
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row:
raise ValueError(f"File with id {file_id} not found")
# Read file content
file_path = Path(row["file_path"])
if not file_path.exists():
raise ValueError(f"File content not found on disk: {file_path}")
file_obj, file_path = await self._lookup_file_id(file_id)
with open(file_path, "rb") as f:
content = f.read()
if not file_path.exists():
logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.")
await self.openai_delete_file(file_id)
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
# Return as binary response with appropriate content type
return Response(
content=content,
content=file_path.read_bytes(),
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'},
)

View file

@ -12,7 +12,6 @@
import copy
import json
import logging
import multiprocessing
import os
import tempfile
@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class ProcessingMessageName(str, Enum):

View file

@ -4,13 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
@ -21,6 +19,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@ -32,7 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl(
@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@ -6,7 +6,6 @@
import gc
import json
import logging
import multiprocessing
from pathlib import Path
from typing import Any
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
@ -44,7 +44,7 @@ from ..utils import (
split_dataset,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
class HFFinetuningSingleDevice:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import gc
import logging
import multiprocessing
from pathlib import Path
from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
DPOAlignmentConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
@ -40,7 +40,7 @@ from ..utils import (
split_dataset,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
class HFDPOAlignmentSingleDevice:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import signal
import sys
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from llama_stack.log import get_logger
from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
def setup_environment():

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import time
from datetime import UTC, datetime
@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
)
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils
@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
log = logging.getLogger(__name__)
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
log = get_logger(name=__name__, category="post_training")
class LoraFinetuningSingleDevice:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from typing import TYPE_CHECKING, Any
@ -20,13 +19,14 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import CodeScannerConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="safety")
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"code-scanner",

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import re
import uuid
from string import Template
@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Role
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -132,6 +132,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
logger = get_logger(name=__name__, category="safety")
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None:
@ -407,7 +409,7 @@ class LlamaGuardShield:
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
if invalid_codes:
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
logger.warning(f"Invalid safety codes returned: {invalid_codes}")
# just returning safe object, as we don't know what the invalid codes can map to
return ModerationObject(
id=f"modr-{uuid.uuid4()}",

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import torch
@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import PromptGuardConfig, PromptGuardType
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="safety")
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"

View file

@ -7,7 +7,6 @@
import collections
import functools
import json
import logging
import random
import re
import string
@ -20,7 +19,9 @@ import nltk
from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
from pythainlp.tokenize import word_tokenize as word_tokenize_thai
logger = logging.getLogger()
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scoring")
WORD_LIST = [
"western",

View file

@ -4,13 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import datetime
import threading
from typing import Any
from opentelemetry import metrics, trace
logger = logging.getLogger(__name__)
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
@ -40,6 +38,7 @@ from llama_stack.apis.telemetry import (
UnstructuredLogEvent,
)
from llama_stack.core.datatypes import Api
from llama_stack.log import get_logger
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
@ -61,6 +60,8 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
_global_lock = threading.Lock()
_TRACER_PROVIDER = None
logger = get_logger(name=__name__, category="telemetry")
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
granularity: str | None = None,
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
raise NotImplementedError("Querying metrics is not implemented")
"""Query metrics from the telemetry store.
Args:
metric_name: The name of the metric to query (e.g., "prompt_tokens")
start_time: Start time as Unix timestamp
end_time: End time as Unix timestamp (defaults to now if None)
granularity: Time granularity for aggregation
query_type: Type of query (RANGE or INSTANT)
label_matchers: Label filters to apply
Returns:
QueryMetricsResponse with metric time series data
"""
# Convert timestamps to datetime objects
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
# Use SQLite trace store if available
if hasattr(self, "trace_store") and self.trace_store:
return await self.trace_store.query_metrics(
metric_name=metric_name,
start_time=start_dt,
end_time=end_dt,
granularity=granularity,
query_type=query_type,
label_matchers=label_matchers,
)
else:
raise ValueError(
f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
)
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import secrets
import string
from typing import Any
@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="tool_runtime")
def make_random_string(length: int = 8):

View file

@ -8,7 +8,6 @@ import asyncio
import base64
import io
import json
import logging
from typing import Any
import faiss
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
@ -40,7 +40,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import FaissVectorIOConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="vector_io")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import re
import sqlite3
import struct
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
VectorDBWithIndex,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="vector_io")
# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"

View file

@ -5,9 +5,11 @@
# the root directory of this source tree.
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.",
),
remote_provider_spec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
),
]

View file

@ -5,34 +5,74 @@
# the root directory of this source tree.
from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
torchtune_def = dict(
api=Api.post_training,
pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
)
huggingface_def = dict(
api=Api.post_training,
pip_packages=["trl", "transformers", "peft", "datasets"],
module="llama_stack.providers.inline.post_training.huggingface",
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
)
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::torchtune",
pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
**{
**torchtune_def,
"provider_type": "inline::torchtune-cpu",
"pip_packages": (
cast(list[str], torchtune_def["pip_packages"])
+ ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"]
),
},
),
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::huggingface",
pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
module="llama_stack.providers.inline.post_training.huggingface",
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
**{
**huggingface_def,
"provider_type": "inline::huggingface-cpu",
"pip_packages": (
cast(list[str], huggingface_def["pip_packages"])
+ ["torch --index-url https://download.pytorch.org/whl/cpu"]
),
},
),
InlineProviderSpec(
**{
**torchtune_def,
"provider_type": "inline::torchtune-gpu",
"pip_packages": (
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"]
),
},
),
InlineProviderSpec(
**{
**huggingface_def,
"provider_type": "inline::huggingface-gpu",
"pip_packages": (cast(list[str], huggingface_def["pip_packages"]) + ["torch"]),
},
),
remote_provider_spec(
api=Api.post_training,

View file

@ -0,0 +1,237 @@
# S3 Files Provider
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
## Features
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
- **Metadata Management**: Uses SQL database for efficient file metadata queries
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
- **Flexible Authentication**: Support for IAM roles and access keys
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
## Configuration
### Basic Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Advanced Configuration
```yaml
api: files
provider_type: remote::s3
config:
bucket_name: my-llama-stack-files
region: us-east-1
aws_access_key_id: YOUR_ACCESS_KEY
aws_secret_access_key: YOUR_SECRET_KEY
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
metadata_store:
type: sqlite
db_path: ./s3_files_metadata.db
```
### Environment Variables
The configuration supports environment variable substitution:
```yaml
config:
bucket_name: "${env.S3_BUCKET_NAME}"
region: "${env.AWS_REGION:=us-east-1}"
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
```
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
## Authentication
### IAM Roles (Recommended)
For production deployments, use IAM roles:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
# No credentials needed - will use IAM role
```
### Access Keys
For development or specific use cases:
```yaml
config:
bucket_name: my-bucket
region: us-east-1
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
```
## S3 Bucket Setup
### Required Permissions
The S3 provider requires the following permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Automatic Bucket Creation
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
```yaml
config:
bucket_name: my-bucket
auto_create_bucket: true # Will create bucket if it doesn't exist
region: us-east-1
```
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject",
"s3:ListBucket",
"s3:CreateBucket"
],
"Resource": [
"arn:aws:s3:::your-bucket-name",
"arn:aws:s3:::your-bucket-name/*"
]
}
]
}
```
### Bucket Policy (Optional)
For additional security, you can add a bucket policy:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "LlamaStackAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:GetObject",
"s3:PutObject",
"s3:DeleteObject"
],
"Resource": "arn:aws:s3:::your-bucket-name/*"
},
{
"Sid": "LlamaStackBucketAccess",
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
},
"Action": [
"s3:ListBucket"
],
"Resource": "arn:aws:s3:::your-bucket-name"
}
]
}
```
## Features
### Metadata Persistence
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
- File ID
- Original filename
- Purpose (assistants, batch, etc.)
- File size in bytes
- Created and expiration timestamps
### TTL and Cleanup
Files currently have a fixed long expiration time (100 years).
## Development and Testing
### Using MinIO
For self-hosted S3-compatible storage:
```yaml
config:
bucket_name: test-bucket
region: us-east-1
endpoint_url: http://localhost:9000
aws_access_key_id: minioadmin
aws_secret_access_key: minioadmin
```
## Monitoring and Logging
The provider logs important operations and errors. For production deployments, consider:
- CloudWatch monitoring for S3 operations
- Custom metrics for file upload/download rates
- Error rate monitoring
- Performance metrics tracking
## Error Handling
The provider handles various error scenarios:
- S3 connectivity issues
- Bucket access permissions
- File not found errors
- Metadata consistency checks
## Known Limitations
- Fixed long TTL (100 years) instead of configurable expiration
- No server-side encryption enabled by default
- No support for AWS session tokens
- No S3 key prefix organization support
- No multipart upload support (all files uploaded as single objects)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.core.datatypes import Api
from .config import S3FilesImplConfig
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
from .files import S3FilesImpl
# TODO: authorization policies and user separation
impl = S3FilesImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
class S3FilesImplConfig(BaseModel):
"""Configuration for S3-based files provider."""
bucket_name: str = Field(description="S3 bucket name to store files")
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret access key (optional if using IAM roles)"
)
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
auto_create_bucket: bool = Field(
default=False, description="Automatically create the S3 bucket if it doesn't exist"
)
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
"region": "${env.AWS_REGION:=us-east-1}",
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="s3_files_metadata.db",
),
}

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from typing import Annotated
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
from .config import S3FilesImplConfig
# TODO: provider data for S3 credentials
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
try:
s3_config = {
"region_name": config.region,
}
# endpoint URL if specified (for MinIO, LocalStack, etc.)
if config.endpoint_url:
s3_config["endpoint_url"] = config.endpoint_url
if config.aws_access_key_id and config.aws_secret_access_key:
s3_config.update(
{
"aws_access_key_id": config.aws_access_key_id,
"aws_secret_access_key": config.aws_secret_access_key,
}
)
return boto3.client("s3", **s3_config)
except (BotoCoreError, NoCredentialsError) as e:
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
try:
client.head_bucket(Bucket=config.bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code == "404":
if not config.auto_create_bucket:
raise RuntimeError(
f"S3 bucket '{config.bucket_name}' does not exist. "
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
) from e
try:
# For us-east-1, we can't specify LocationConstraint
if config.region == "us-east-1":
client.create_bucket(Bucket=config.bucket_name)
else:
client.create_bucket(
Bucket=config.bucket_name,
CreateBucketConfiguration={"LocationConstraint": config.region},
)
except ClientError as create_error:
raise RuntimeError(
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
) from create_error
elif error_code == "403":
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
else:
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
class S3FilesImpl(Files):
"""S3-based implementation of the Files API."""
# TODO: implement expiration, for now a silly offset
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
def __init__(self, config: S3FilesImplConfig) -> None:
self._config = config
self._client: boto3.client | None = None
self._sql_store: SqlStore | None = None
async def initialize(self) -> None:
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = sqlstore_impl(self._config.metadata_store)
await self._sql_store.create_table(
"openai_files",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"filename": ColumnType.STRING,
"purpose": ColumnType.STRING,
"bytes": ColumnType.INTEGER,
"created_at": ColumnType.INTEGER,
"expires_at": ColumnType.INTEGER,
# TODO: add s3_etag field for integrity checking
},
)
async def shutdown(self) -> None:
pass
@property
def client(self) -> boto3.client:
assert self._client is not None, "Provider not initialized"
return self._client
@property
def sql_store(self) -> SqlStore:
assert self._sql_store is not None, "Provider not initialized"
return self._sql_store
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
filename = getattr(file, "filename", None) or "uploaded_file"
created_at = int(time.time())
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
content = await file.read()
file_size = len(content)
await self.sql_store.insert(
"openai_files",
{
"id": file_id,
"filename": filename,
"purpose": purpose.value,
"bytes": file_size,
"created_at": created_at,
"expires_at": expires_at,
},
)
try:
self.client.put_object(
Bucket=self._config.bucket_name,
Key=file_id,
Body=content,
# TODO: enable server-side encryption
)
except ClientError as e:
await self.sql_store.delete("openai_files", where={"id": file_id})
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
return OpenAIFileObject(
id=file_id,
filename=filename,
purpose=purpose,
bytes=file_size,
created_at=created_at,
expires_at=expires_at,
)
async def openai_list_files(
self,
after: str | None = None,
limit: int | None = 10000,
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
# this purely defensive. it should not happen because the router also default to Order.desc.
if not order:
order = Order.desc
where_conditions = {}
if purpose:
where_conditions["purpose"] = purpose.value
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
)
files = [
OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
for row in paginated_result.data
]
return ListOpenAIFileResponse(
data=files,
has_more=paginated_result.has_more,
# empty string or None? spec says str, ref impl returns str | None, we go with spec
first_id=files[0].id if files else "",
last_id=files[-1].id if files else "",
)
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
return OpenAIFileObject(
id=row["id"],
filename=row["filename"],
purpose=OpenAIFilePurpose(row["purpose"]),
bytes=row["bytes"],
created_at=row["created_at"],
expires_at=row["expires_at"],
)
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
self.client.delete_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
except ClientError as e:
if e.response["Error"]["Code"] != "NoSuchKey":
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
await self.sql_store.delete("openai_files", where={"id": file_id})
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
async def openai_retrieve_file_content(self, file_id: str) -> Response:
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "files.list()")
try:
response = self.client.get_object(
Bucket=self._config.bucket_name,
Key=row["id"],
)
# TODO: can we stream this instead of loading it into memory
content = response["Body"].read()
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
await self.sql_store.delete("openai_files", where={"id": file_id})
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
raise RuntimeError(f"Failed to download file from S3: {e}") from e
return Response(
content=content,
media_type="application/octet-stream",
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
)

View file

@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -3,15 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):

View file

@ -41,6 +41,11 @@ client.initialize()
### Create Completion
> Note on Completion API
>
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does.
```python
response = client.inference.completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
@ -76,7 +81,78 @@ response = client.inference.chat_completion(
print(f"Response: {response.completion_message.content}")
```
### Tool Calling Example ###
```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
tool_definition = ToolDefinition(
tool_name="get_weather",
description="Get current weather information for a location",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
required=True,
),
"unit": ToolParamDefinition(
param_type="string",
description="Temperature unit (celsius or fahrenheit)",
required=False,
default="celsius",
),
},
)
tool_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.completion_message.content}")
if tool_response.completion_message.tool_calls:
for tool_call in tool_response.completion_message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
```
### Structured Output Example
```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
response_format=response_format,
)
print(f"Structured Response: {structured_response.completion_message.content}")
```
### Create Embeddings
> Note on OpenAI embeddings compatibility
>
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
```python
response = client.inference.embeddings(
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",

View file

@ -4,11 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from collections.abc import AsyncIterator
from openai import APIConnectionError, BadRequestError
from openai import NOT_GIVEN, APIConnectionError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -27,12 +26,16 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
@ -54,7 +57,7 @@ from .openai_utils import (
)
from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@ -194,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
}
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
except BadRequestError as e:
raise ValueError(f"Failed to get embeddings: {e}") from e
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
# ->
@ -210,6 +209,57 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""
OpenAI-compatible embeddings for NVIDIA NIM.
Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
We default this to "query" to ensure requests succeed when using the
OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
`task_type='document'`.
"""
extra_body: dict[str, object] = {"input_type": "query"}
logger.warning(
"NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
"For passage embeddings, use the embeddings API with task_type='document'."
)
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
extra_body=extra_body,
)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
async def chat_completion(
self,
model_id: str,

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import httpx
from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference::nvidia")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
response.id = id
return response
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:

View file

@ -4,15 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference::openai")
#

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries():

View file

@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():
@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user=user,
)
return await self.client.chat.completions.create(**params) # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")

View file

@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.log import get_logger
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training::nvidia")
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
from typing import Any
from llama_stack.apis.inference import Message
@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety::bedrock")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -4,20 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety::nvidia")
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
class NeMoGuardrails:
"""

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
from typing import Any
import litellm
@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety::sambanova")
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
import logging
from typing import Any
from urllib.parse import urlparse
@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import os
from typing import Any
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="vector_io::milvus")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import psycopg2
@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorVectorIOConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io::pgvector")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import uuid
from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileObject,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io::qdrant")
CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any
import weaviate
@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
import base64
import logging
import struct
from typing import TYPE_CHECKING
from llama_stack.log import get_logger
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
@ -27,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
EMBEDDING_MODELS = {}
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="providers::utils")
class SentenceTransformerEmbeddingMixin:

View file

@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="providers::utils")
class LiteLLMOpenAIMixin(
@ -429,28 +429,6 @@ class LiteLLMOpenAIMixin(
)
return await litellm.acompletion(**params)
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available via LiteLLM for the current

Some files were not shown because too many files have changed in this diff Show more