mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into HuggingfacePostTrainingConfig-branch
This commit is contained in:
commit
d0d737680f
193 changed files with 7108 additions and 881 deletions
|
@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Batches(Protocol):
|
||||
"""Protocol for batch processing API operations.
|
||||
|
||||
"""
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
The API is designed to allow use of openai client libraries for seamless integration.
|
||||
|
||||
This API provides the following extensions:
|
||||
- idempotent batch creation
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
|
@ -45,6 +49,7 @@ class Batches(Protocol):
|
|||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BatchObject:
|
||||
"""Create a new batch for processing multiple API requests.
|
||||
|
||||
|
@ -52,6 +57,7 @@ class Batches(Protocol):
|
|||
:param endpoint: The endpoint to be used for all requests in the batch.
|
||||
:param completion_window: The time window within which the batch should be processed.
|
||||
:param metadata: Optional metadata for the batch.
|
||||
:param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
|
||||
:returns: The created batch object.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
|
|||
embeddings: list[list[float]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RerankData(BaseModel):
|
||||
"""A single rerank result from a reranking response.
|
||||
|
||||
:param index: The original index of the document in the input list
|
||||
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
|
||||
"""
|
||||
|
||||
index: int
|
||||
relevance_score: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RerankResponse(BaseModel):
|
||||
"""Response from a reranking request.
|
||||
|
||||
:param data: List of rerank result objects, sorted by relevance score (descending)
|
||||
"""
|
||||
|
||||
data: list[RerankData]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
||||
"""Text content part for OpenAI-compatible chat completion messages.
|
||||
|
@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol):
|
|||
:returns: A BatchCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch completion is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/inference/chat-completion", method="POST")
|
||||
async def chat_completion(
|
||||
|
@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol):
|
|||
:returns: A BatchChatCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch chat completion is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/inference/embeddings", method="POST")
|
||||
async def embeddings(
|
||||
|
@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/rerank", method="POST", experimental=True)
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
"""Rerank a list of documents based on their relevance to a query.
|
||||
|
||||
:param model: The identifier of the reranking model to use.
|
||||
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
|
||||
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
|
||||
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
|
||||
:returns: RerankResponse with indices sorted by relevance score (descending).
|
||||
"""
|
||||
raise NotImplementedError("Reranking is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/openai/v1/completions", method="POST")
|
||||
async def openai_completion(
|
||||
self,
|
||||
|
|
|
@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel):
|
|||
|
||||
timestamp: int
|
||||
value: float
|
||||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -518,7 +519,7 @@ class Telemetry(Protocol):
|
|||
metric_name: str,
|
||||
start_time: int,
|
||||
end_time: int | None = None,
|
||||
granularity: str | None = "1d",
|
||||
granularity: str | None = None,
|
||||
query_type: MetricQueryType = MetricQueryType.RANGE,
|
||||
label_matchers: list[MetricLabelMatcher] | None = None,
|
||||
) -> QueryMetricsResponse:
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
logger = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import importlib.resources
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis
|
|||
from llama_stack.core.utils.exec import run_command
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.distributions.template import DistributionTemplate
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
# `llama-stack` is automatically installed by the installation script.
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
|
@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
|||
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
|
||||
|
|
|
@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
|
|||
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||
|
||||
|
||||
class CORSConfig(BaseModel):
|
||||
allow_origins: list[str] = Field(default_factory=list)
|
||||
allow_origin_regex: str | None = Field(default=None)
|
||||
allow_methods: list[str] = Field(default=["OPTIONS"])
|
||||
allow_headers: list[str] = Field(default_factory=list)
|
||||
allow_credentials: bool = Field(default=False)
|
||||
expose_headers: list[str] = Field(default_factory=list)
|
||||
max_age: int = Field(default=600, ge=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_credentials_config(self) -> Self:
|
||||
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
|
||||
raise ValueError("Cannot use wildcard origins with credentials enabled")
|
||||
return self
|
||||
|
||||
|
||||
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
|
||||
if cors_config is False or cors_config is None:
|
||||
return None
|
||||
|
||||
if cors_config is True:
|
||||
# dev mode: allow localhost on any port
|
||||
return CORSConfig(
|
||||
allow_origins=[],
|
||||
allow_origin_regex=r"https?://localhost:\d+",
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
|
||||
)
|
||||
|
||||
if isinstance(cors_config, CORSConfig):
|
||||
return cors_config
|
||||
|
||||
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
|
|||
default=None,
|
||||
description="Per client quota request configuration",
|
||||
)
|
||||
cors: bool | CORSConfig | None = Field(
|
||||
default=None,
|
||||
description="CORS configuration for cross-origin requests. Can be:\n"
|
||||
"- true: Enable localhost CORS for development\n"
|
||||
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
@ -48,6 +48,7 @@ from llama_stack.core.stack import (
|
|||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.exec import in_notebook
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
|
@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
start_trace,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -145,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
config_path_or_distro_name, custom_provider_registry, provider_data
|
||||
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
|
||||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
self.skip_logger_removal = skip_logger_removal
|
||||
self.provider_data = provider_data
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
def initialize(self):
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
if not self.skip_logger_removal:
|
||||
self._remove_root_logger_handlers()
|
||||
|
||||
# use a new event loop to avoid interfering with the main event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(self.async_client.initialize())
|
||||
loop.run_until_complete(self.async_client.initialize())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
def initialize(self):
|
||||
"""
|
||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||
Deprecated method for backward compatibility.
|
||||
"""
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
pass
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
loop = self.loop
|
||||
|
@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
config_path_or_distro_name: str,
|
||||
custom_provider_registry: ProviderRegistry | None = None,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
skip_logger_removal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# when using the library client, we should not log to console since many
|
||||
|
@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
if not skip_logger_removal:
|
||||
self._remove_root_logger_handlers()
|
||||
|
||||
if config_path_or_distro_name.endswith(".yaml"):
|
||||
config_path = Path(config_path_or_distro_name)
|
||||
if not config_path.exists():
|
||||
|
@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
self.provider_data = provider_data
|
||||
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
"""
|
||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||
"""
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""
|
||||
Initialize the async client.
|
||||
|
||||
Returns:
|
||||
bool: True if initialization was successful
|
||||
"""
|
||||
|
||||
try:
|
||||
self.route_impls = None
|
||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||
|
|
|
@ -6,15 +6,15 @@
|
|||
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
# Context variable for request provider data and auth attributes
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class DatasetIORouter(DatasetIO):
|
||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class ScoringRouter(Scoring):
|
||||
|
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
|
|||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
|
|
|
@ -22,7 +22,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class ToolRuntimeRouter(ToolRuntime):
|
||||
|
|
|
@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
|
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||
|
|
|
@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
|
|
|
@ -26,7 +26,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
|
|
|
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
|
|
@ -19,7 +19,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
|
||||
|
|
|
@ -30,7 +30,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
|
|||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
logger = get_logger(name=__name__, category="core::auth")
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
|
|
|
@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
logger = get_logger(name=__name__, category="core::auth")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
|
|||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="quota")
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
class QuotaMiddleware:
|
||||
|
|
|
@ -9,7 +9,7 @@ import asyncio
|
|||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
|
@ -28,6 +28,7 @@ from aiohttp import hdrs
|
|||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
|
|||
AuthenticationRequiredError,
|
||||
LoggingConfig,
|
||||
StackRunConfig,
|
||||
process_cors_config,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||
|
@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="server", config=logger_config)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
|
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
|
|||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
if config.server.cors:
|
||||
logger.info("Enabling CORS")
|
||||
cors_config = process_cors_config(config.server.cors)
|
||||
if cors_config:
|
||||
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
logger = get_logger(__name__, category="core::registry")
|
||||
|
||||
|
||||
class DistributionRegistry(Protocol):
|
||||
|
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
|||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="config_resolution")
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import importlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
|
@ -12,9 +12,9 @@ import sys
|
|||
|
||||
from termcolor import cprint
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
import importlib
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def formulate_run_args(image_type: str, image_name: str) -> list:
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal, Union, get_args, get_origin
|
||||
|
||||
|
@ -14,7 +13,9 @@ from pydantic import BaseModel
|
|||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def is_list_of_primitives(field_type):
|
||||
|
|
|
@ -34,7 +34,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::huggingface
|
||||
- provider_type: inline::huggingface-cpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
|
|
|
@ -156,8 +156,8 @@ providers:
|
|||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: huggingface
|
||||
provider_type: inline::huggingface
|
||||
- provider_id: huggingface-cpu
|
||||
provider_type: inline::huggingface-cpu
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
|
|
7
llama_stack/distributions/starter-gpu/__init__.py
Normal file
7
llama_stack/distributions/starter-gpu/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .starter_gpu import get_distribution_template # noqa: F401
|
59
llama_stack/distributions/starter-gpu/build.yaml
Normal file
59
llama_stack/distributions/starter-gpu/build.yaml
Normal file
|
@ -0,0 +1,59 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: Quick start template for running Llama Stack with several popular providers.
|
||||
This distribution is intended for GPU-enabled environments.
|
||||
providers:
|
||||
inference:
|
||||
- provider_type: remote::cerebras
|
||||
- provider_type: remote::ollama
|
||||
- provider_type: remote::vllm
|
||||
- provider_type: remote::tgi
|
||||
- provider_type: remote::fireworks
|
||||
- provider_type: remote::together
|
||||
- provider_type: remote::bedrock
|
||||
- provider_type: remote::nvidia
|
||||
- provider_type: remote::openai
|
||||
- provider_type: remote::anthropic
|
||||
- provider_type: remote::gemini
|
||||
- provider_type: remote::vertexai
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
- provider_type: inline::sqlite-vec
|
||||
- provider_type: inline::milvus
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
safety:
|
||||
- provider_type: inline::llama-guard
|
||||
- provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::torchtune-gpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_type: inline::reference
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- asyncpg
|
||||
- sqlalchemy[asyncio]
|
238
llama_stack/distributions/starter-gpu/run.yaml
Normal file
238
llama_stack/distributions/starter-gpu/run.yaml
Normal file
|
@ -0,0 +1,238 @@
|
|||
version: 2
|
||||
image_name: starter-gpu
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
api_key: ${env.OPENAI_API_KEY:=}
|
||||
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
|
||||
- provider_id: anthropic
|
||||
provider_type: remote::anthropic
|
||||
config:
|
||||
api_key: ${env.ANTHROPIC_API_KEY:=}
|
||||
- provider_id: gemini
|
||||
provider_type: remote::gemini
|
||||
config:
|
||||
api_key: ${env.GEMINI_API_KEY:=}
|
||||
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
|
||||
provider_type: remote::vertexai
|
||||
config:
|
||||
project: ${env.VERTEX_AI_PROJECT:=}
|
||||
location: ${env.VERTEX_AI_LOCATION:=us-central1}
|
||||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
host: ${env.PGVECTOR_HOST:=localhost}
|
||||
port: ${env.PGVECTOR_PORT:=5432}
|
||||
db: ${env.PGVECTOR_DB:=}
|
||||
user: ${env.PGVECTOR_USER:=}
|
||||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
- provider_id: code-scanner
|
||||
provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: torchtune-gpu
|
||||
provider_type: inline::torchtune-gpu
|
||||
config:
|
||||
checkpoint_format: meta
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
22
llama_stack/distributions/starter-gpu/starter_gpu.py
Normal file
22
llama_stack/distributions/starter-gpu/starter_gpu.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.distributions.template import BuildProvider, DistributionTemplate
|
||||
|
||||
from ..starter.starter import get_distribution_template as get_starter_distribution_template
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
template = get_starter_distribution_template()
|
||||
name = "starter-gpu"
|
||||
template.name = name
|
||||
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
|
||||
|
||||
template.providers["post_training"] = [
|
||||
BuildProvider(provider_type="inline::torchtune-gpu"),
|
||||
]
|
||||
return template
|
|
@ -1,6 +1,7 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: Quick start template for running Llama Stack with several popular providers
|
||||
description: Quick start template for running Llama Stack with several popular providers.
|
||||
This distribution is intended for CPU-only environments.
|
||||
providers:
|
||||
inference:
|
||||
- provider_type: remote::cerebras
|
||||
|
@ -34,7 +35,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::huggingface
|
||||
- provider_type: inline::huggingface-cpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
|
|
|
@ -156,8 +156,8 @@ providers:
|
|||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: huggingface
|
||||
provider_type: inline::huggingface
|
||||
- provider_id: huggingface-cpu
|
||||
provider_type: inline::huggingface-cpu
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
|
|
|
@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"post_training": [BuildProvider(provider_type="inline::huggingface")],
|
||||
"post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
|
@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
description="Quick start template for running Llama Stack with several popular providers",
|
||||
description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.",
|
||||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
|
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import re
|
||||
from logging.config import dictConfig
|
||||
from logging.config import dictConfig # allow-direct-logging
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
|
|
|
@ -13,14 +13,15 @@
|
|||
|
||||
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
||||
import math
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .utils import get_negative_inf_value, to_2tuple
|
||||
|
||||
logger = getLogger()
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
||||
def resize_local_position_embedding(orig_pos_embed, grid_size):
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from logging import getLogger
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
@ -21,9 +20,11 @@ import torchvision.transforms as tv
|
|||
from PIL import Image
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
IMAGE_RES = 224
|
||||
|
||||
logger = getLogger()
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
||||
class VariableSizeImageTransform:
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
|
@ -22,6 +20,8 @@ from PIL import Image as PIL_Image
|
|||
from torch import Tensor, nn
|
||||
from torch.distributed import _functional_collectives as funcol
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
|
||||
from .encoder_utils import (
|
||||
build_encoder_attention_mask,
|
||||
|
@ -34,9 +34,10 @@ from .encoder_utils import (
|
|||
from .image_transform import VariableSizeImageTransform
|
||||
from .utils import get_negative_inf_value, to_2tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MP_SCALE = 8
|
||||
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
||||
def reduce_from_tensor_model_parallel_region(input_):
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
|
@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module):
|
|||
if embed is not None:
|
||||
# reshape the weights to the correct shape
|
||||
nt_old, nt_old, _, w = embed.shape
|
||||
logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
|
||||
logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
|
||||
embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
|
||||
# assign the weights to the module
|
||||
state_dict[prefix + "embedding"] = embed_new
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Collection, Iterator, Sequence, Set
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Literal,
|
||||
|
@ -14,11 +14,9 @@ from typing import (
|
|||
|
||||
import tiktoken
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# The tiktoken tokenizer can handle <=400k chars without
|
||||
# pyo3_runtime.PanicException.
|
||||
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||
|
@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000
|
|||
|
||||
_INSTANCE = None
|
||||
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
|
||||
|
@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
|||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ...datatypes import QuantizationMode
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..moe import MoE
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
||||
def swiglu_wrapper_no_reduce(
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Collection, Iterator, Sequence, Set
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Literal,
|
||||
|
@ -14,11 +13,9 @@ from typing import (
|
|||
|
||||
import tiktoken
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# The tiktoken tokenizer can handle <=400k chars without
|
||||
# pyo3_runtime.PanicException.
|
||||
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||
|
@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [
|
|||
"<|fim_suffix|>",
|
||||
]
|
||||
|
||||
logger = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
|
|
|
@ -6,9 +6,10 @@
|
|||
|
||||
# type: ignore
|
||||
import collections
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
log = get_logger(name=__name__, category="models::llama")
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
|
|
@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
|
|||
WEB_SEARCH_TOOL = "web_search"
|
||||
RAG_TOOL_GROUP = "builtin::rag"
|
||||
|
||||
logger = get_logger(name=__name__, category="agents")
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
|
|||
from .persistence import AgentInfo
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class MetaReferenceAgentsImpl(Agents):
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
|
|||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class AgentSessionInfo(Session):
|
||||
|
|
|
@ -41,7 +41,7 @@ from .utils import (
|
|||
convert_response_text_to_chat_response_format,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
logger = get_logger(name=__name__, category="openai::responses")
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
|
|
|
@ -47,7 +47,7 @@ from llama_stack.log import get_logger
|
|||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class StreamingResponseOrchestrator:
|
||||
|
|
|
@ -38,7 +38,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
|
|
|
@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -99,14 +101,22 @@ async def convert_response_input_to_chat_messages(
|
|||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
# extract all OpenAIResponseInputFunctionToolCallOutput items
|
||||
# so their corresponding OpenAIToolMessageParam instances can
|
||||
# be added immediately following the corresponding
|
||||
# OpenAIAssistantMessageParam
|
||||
tool_call_results = {}
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
tool_call_results[input_item.call_id] = OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
# skip as these have been extracted and inserted in order
|
||||
pass
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
|
@ -117,6 +127,28 @@ async def convert_response_input_to_chat_messages(
|
|||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
if input_item.call_id in tool_call_results:
|
||||
messages.append(tool_call_results[input_item.call_id])
|
||||
del tool_call_results[input_item.call_id]
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
|
||||
# the tool list will be handled separately
|
||||
pass
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
|
@ -125,6 +157,10 @@ async def convert_response_input_to_chat_messages(
|
|||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
if len(tool_call_results):
|
||||
raise ValueError(
|
||||
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
|
||||
)
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
|
|
@ -5,13 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import time
|
||||
|
@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches):
|
|||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BatchObject:
|
||||
"""
|
||||
Create a new batch for processing multiple API requests.
|
||||
|
||||
Error handling by levels -
|
||||
0. Input param handling, results in 40x errors before processing, e.g.
|
||||
- Wrong completion_window
|
||||
- Invalid metadata types
|
||||
- Unknown endpoint
|
||||
-> no batch created
|
||||
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
|
||||
- input_file_id missing
|
||||
- invalid json in file
|
||||
- missing custom_id, method, url, body
|
||||
- invalid model
|
||||
- streaming
|
||||
-> batch created, validation sends to failed status
|
||||
2. Processing errors, result in error_file_id entries, e.g.
|
||||
- Any error returned from inference endpoint
|
||||
-> batch created, goes to completed status
|
||||
This implementation provides optional idempotency: when an idempotency key
|
||||
(idempotency_key) is provided, a deterministic ID is generated based on the input
|
||||
parameters. If a batch with the same parameters already exists, it will be
|
||||
returned instead of creating a duplicate. Without an idempotency key,
|
||||
each request creates a new batch with a unique ID.
|
||||
|
||||
Args:
|
||||
input_file_id: The ID of an uploaded file containing requests for the batch.
|
||||
endpoint: The endpoint to be used for all requests in the batch.
|
||||
completion_window: The time window within which the batch should be processed.
|
||||
metadata: Optional metadata for the batch.
|
||||
idempotency_key: Optional idempotency key for enabling idempotent behavior.
|
||||
|
||||
Returns:
|
||||
The created or existing batch object.
|
||||
"""
|
||||
|
||||
# Error handling by levels -
|
||||
# 0. Input param handling, results in 40x errors before processing, e.g.
|
||||
# - Wrong completion_window
|
||||
# - Invalid metadata types
|
||||
# - Unknown endpoint
|
||||
# -> no batch created
|
||||
# 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
|
||||
# - input_file_id missing
|
||||
# - invalid json in file
|
||||
# - missing custom_id, method, url, body
|
||||
# - invalid model
|
||||
# - streaming
|
||||
# -> batch created, validation sends to failed status
|
||||
# 2. Processing errors, result in error_file_id entries, e.g.
|
||||
# - Any error returned from inference endpoint
|
||||
# -> batch created, goes to completed status
|
||||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions"]:
|
||||
|
@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches):
|
|||
)
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# For idempotent requests, use the idempotency key for the batch ID
|
||||
# This ensures the same key always maps to the same batch ID,
|
||||
# allowing us to detect parameter conflicts
|
||||
if idempotency_key is not None:
|
||||
hash_input = idempotency_key.encode("utf-8")
|
||||
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
|
||||
batch_id = f"batch_{hash_digest}"
|
||||
|
||||
try:
|
||||
existing_batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
if (
|
||||
existing_batch.input_file_id != input_file_id
|
||||
or existing_batch.endpoint != endpoint
|
||||
or existing_batch.completion_window != completion_window
|
||||
or existing_batch.metadata != metadata
|
||||
):
|
||||
raise ConflictError(
|
||||
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
|
||||
"Either use a new idempotency key or ensure all parameters match the original request."
|
||||
)
|
||||
|
||||
logger.info(f"Returning existing batch with ID: {batch_id}")
|
||||
return existing_batch
|
||||
except ResourceNotFoundError:
|
||||
# Batch doesn't exist, continue with creation
|
||||
pass
|
||||
|
||||
current_time = int(time.time())
|
||||
|
||||
batch = BatchObject(
|
||||
|
@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
|
||||
logger.info(f"Created new batch with ID: {batch_id}")
|
||||
|
||||
if self.process_batches:
|
||||
task = asyncio.create_task(self._process_batch(batch_id))
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import Annotated
|
|||
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
Files,
|
||||
|
@ -20,12 +21,15 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="files")
|
||||
|
||||
|
||||
class LocalfsFilesImpl(Files):
|
||||
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
|
@ -65,6 +69,18 @@ class LocalfsFilesImpl(Files):
|
|||
"""Get the filesystem path for a file ID."""
|
||||
return Path(self.config.storage_dir) / file_id
|
||||
|
||||
async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]:
|
||||
"""Look up a OpenAIFileObject and filesystem path from its ID."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||
|
||||
file_path = Path(row.pop("file_path"))
|
||||
return OpenAIFileObject(**row), file_path
|
||||
|
||||
# OpenAI Files API Implementation
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
|
@ -157,37 +173,19 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
"""Returns information about a specific file."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
file_obj, _ = await self._lookup_file_id(file_id)
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
return file_obj
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
"""Delete a file."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
# Delete physical file
|
||||
file_path = Path(row["file_path"])
|
||||
_, file_path = await self._lookup_file_id(file_id)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
# Delete metadata from database
|
||||
assert self.sql_store is not None, "Files provider not initialized"
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
return OpenAIFileDeleteResponse(
|
||||
|
@ -197,25 +195,17 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
"""Returns the contents of the specified file."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
# Get file metadata
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
# Read file content
|
||||
file_path = Path(row["file_path"])
|
||||
if not file_path.exists():
|
||||
raise ValueError(f"File content not found on disk: {file_path}")
|
||||
file_obj, file_path = await self._lookup_file_id(file_id)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
content = f.read()
|
||||
if not file_path.exists():
|
||||
logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.")
|
||||
await self.openai_delete_file(file_id)
|
||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||
|
||||
# Return as binary response with appropriate content type
|
||||
return Response(
|
||||
content=content,
|
||||
content=file_path.read_bytes(),
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||
headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'},
|
||||
)
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
|
@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
from pydantic import BaseModel, Field
|
||||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class ProcessingMessageName(str, Enum):
|
||||
|
|
|
@ -4,13 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
InferenceProvider,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
|
@ -21,6 +19,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
@ -32,7 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
|
@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
|
|||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support chat completion")
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
|
|||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
@ -44,7 +44,7 @@ from ..utils import (
|
|||
split_dataset,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
class HFFinetuningSingleDevice:
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
|
|||
DPOAlignmentConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
@ -40,7 +40,7 @@ from ..utils import (
|
|||
split_dataset,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
class HFDPOAlignmentSingleDevice:
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
|
|||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
def setup_environment():
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
|
|||
from torchtune import modules, training
|
||||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
|
@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
|
|||
)
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
|
@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
log = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
||||
class LoraFinetuningSingleDevice:
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
@ -20,13 +19,14 @@ from llama_stack.apis.safety import (
|
|||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||
"code-scanner",
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from string import Template
|
||||
|
@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
|
|||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Role
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
@ -132,6 +132,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
|
|||
|
||||
PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
|
||||
|
||||
logger = get_logger(name=__name__, category="safety")
|
||||
|
||||
|
||||
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: LlamaGuardConfig, deps) -> None:
|
||||
|
@ -407,7 +409,7 @@ class LlamaGuardShield:
|
|||
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
|
||||
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
|
||||
if invalid_codes:
|
||||
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
|
||||
logger.warning(f"Invalid safety codes returned: {invalid_codes}")
|
||||
# just returning safe object, as we don't know what the invalid codes can map to
|
||||
return ModerationObject(
|
||||
id=f"modr-{uuid.uuid4()}",
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
|
|||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import collections
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
|
@ -20,7 +19,9 @@ import nltk
|
|||
from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
|
||||
from pythainlp.tokenize import word_tokenize as word_tokenize_thai
|
||||
|
||||
logger = logging.getLogger()
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="scoring")
|
||||
|
||||
WORD_LIST = [
|
||||
"western",
|
||||
|
|
|
@ -4,13 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import datetime
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
|
@ -40,6 +38,7 @@ from llama_stack.apis.telemetry import (
|
|||
UnstructuredLogEvent,
|
||||
)
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
|
@ -61,6 +60,8 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
|||
_global_lock = threading.Lock()
|
||||
_TRACER_PROVIDER = None
|
||||
|
||||
logger = get_logger(name=__name__, category="telemetry")
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
|
@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
metric_name: str,
|
||||
start_time: int,
|
||||
end_time: int | None = None,
|
||||
granularity: str | None = "1d",
|
||||
granularity: str | None = None,
|
||||
query_type: MetricQueryType = MetricQueryType.RANGE,
|
||||
label_matchers: list[MetricLabelMatcher] | None = None,
|
||||
) -> QueryMetricsResponse:
|
||||
raise NotImplementedError("Querying metrics is not implemented")
|
||||
"""Query metrics from the telemetry store.
|
||||
|
||||
Args:
|
||||
metric_name: The name of the metric to query (e.g., "prompt_tokens")
|
||||
start_time: Start time as Unix timestamp
|
||||
end_time: End time as Unix timestamp (defaults to now if None)
|
||||
granularity: Time granularity for aggregation
|
||||
query_type: Type of query (RANGE or INSTANT)
|
||||
label_matchers: Label filters to apply
|
||||
|
||||
Returns:
|
||||
QueryMetricsResponse with metric time series data
|
||||
"""
|
||||
# Convert timestamps to datetime objects
|
||||
start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC)
|
||||
end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None
|
||||
|
||||
# Use SQLite trace store if available
|
||||
if hasattr(self, "trace_store") and self.trace_store:
|
||||
return await self.trace_store.query_metrics(
|
||||
metric_name=metric_name,
|
||||
start_time=start_dt,
|
||||
end_time=end_dt,
|
||||
granularity=granularity,
|
||||
query_type=query_type,
|
||||
label_matchers=label_matchers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks"
|
||||
)
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="tool_runtime")
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
|
|
|
@ -8,7 +8,6 @@ import asyncio
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import faiss
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
|
@ -40,7 +40,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
import struct
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
# Specifying search mode is dependent on the VectorIO provider.
|
||||
VECTOR_SEARCH = "vector"
|
||||
|
|
|
@ -5,9 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
|
||||
|
||||
|
@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
|
||||
description="Local filesystem-based file storage provider for managing files and documents locally.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.files,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -5,34 +5,74 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import cast
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
|
||||
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
|
||||
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
||||
torchtune_def = dict(
|
||||
api=Api.post_training,
|
||||
pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"],
|
||||
module="llama_stack.providers.inline.post_training.torchtune",
|
||||
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
|
||||
)
|
||||
|
||||
huggingface_def = dict(
|
||||
api=Api.post_training,
|
||||
pip_packages=["trl", "transformers", "peft", "datasets"],
|
||||
module="llama_stack.providers.inline.post_training.huggingface",
|
||||
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||
)
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.post_training,
|
||||
provider_type="inline::torchtune",
|
||||
pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"],
|
||||
module="llama_stack.providers.inline.post_training.torchtune",
|
||||
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
|
||||
**{
|
||||
**torchtune_def,
|
||||
"provider_type": "inline::torchtune-cpu",
|
||||
"pip_packages": (
|
||||
cast(list[str], torchtune_def["pip_packages"])
|
||||
+ ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"]
|
||||
),
|
||||
},
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.post_training,
|
||||
provider_type="inline::huggingface",
|
||||
pip_packages=["torch", "trl", "transformers", "peft", "datasets"],
|
||||
module="llama_stack.providers.inline.post_training.huggingface",
|
||||
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||
**{
|
||||
**huggingface_def,
|
||||
"provider_type": "inline::huggingface-cpu",
|
||||
"pip_packages": (
|
||||
cast(list[str], huggingface_def["pip_packages"])
|
||||
+ ["torch --index-url https://download.pytorch.org/whl/cpu"]
|
||||
),
|
||||
},
|
||||
),
|
||||
InlineProviderSpec(
|
||||
**{
|
||||
**torchtune_def,
|
||||
"provider_type": "inline::torchtune-gpu",
|
||||
"pip_packages": (
|
||||
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"]
|
||||
),
|
||||
},
|
||||
),
|
||||
InlineProviderSpec(
|
||||
**{
|
||||
**huggingface_def,
|
||||
"provider_type": "inline::huggingface-gpu",
|
||||
"pip_packages": (cast(list[str], huggingface_def["pip_packages"]) + ["torch"]),
|
||||
},
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.post_training,
|
||||
|
|
237
llama_stack/providers/remote/files/s3/README.md
Normal file
237
llama_stack/providers/remote/files/s3/README.md
Normal file
|
@ -0,0 +1,237 @@
|
|||
# S3 Files Provider
|
||||
|
||||
A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence.
|
||||
|
||||
## Features
|
||||
|
||||
- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage
|
||||
- **Metadata Management**: Uses SQL database for efficient file metadata queries
|
||||
- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints
|
||||
- **Flexible Authentication**: Support for IAM roles and access keys
|
||||
- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
api: files
|
||||
provider_type: remote::s3
|
||||
config:
|
||||
bucket_name: my-llama-stack-files
|
||||
region: us-east-1
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ./s3_files_metadata.db
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```yaml
|
||||
api: files
|
||||
provider_type: remote::s3
|
||||
config:
|
||||
bucket_name: my-llama-stack-files
|
||||
region: us-east-1
|
||||
aws_access_key_id: YOUR_ACCESS_KEY
|
||||
aws_secret_access_key: YOUR_SECRET_KEY
|
||||
endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ./s3_files_metadata.db
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The configuration supports environment variable substitution:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: "${env.S3_BUCKET_NAME}"
|
||||
region: "${env.AWS_REGION:=us-east-1}"
|
||||
aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}"
|
||||
aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}"
|
||||
endpoint_url: "${env.S3_ENDPOINT_URL:=}"
|
||||
```
|
||||
|
||||
Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique.
|
||||
|
||||
## Authentication
|
||||
|
||||
### IAM Roles (Recommended)
|
||||
|
||||
For production deployments, use IAM roles:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
region: us-east-1
|
||||
# No credentials needed - will use IAM role
|
||||
```
|
||||
|
||||
### Access Keys
|
||||
|
||||
For development or specific use cases:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
region: us-east-1
|
||||
aws_access_key_id: AKIAIOSFODNN7EXAMPLE
|
||||
aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
|
||||
```
|
||||
|
||||
## S3 Bucket Setup
|
||||
|
||||
### Required Permissions
|
||||
|
||||
The S3 provider requires the following permissions:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject",
|
||||
"s3:ListBucket"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:s3:::your-bucket-name",
|
||||
"arn:aws:s3:::your-bucket-name/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Automatic Bucket Creation
|
||||
|
||||
By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: my-bucket
|
||||
auto_create_bucket: true # Will create bucket if it doesn't exist
|
||||
region: us-east-1
|
||||
```
|
||||
|
||||
**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject",
|
||||
"s3:ListBucket",
|
||||
"s3:CreateBucket"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:s3:::your-bucket-name",
|
||||
"arn:aws:s3:::your-bucket-name/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Bucket Policy (Optional)
|
||||
|
||||
For additional security, you can add a bucket policy:
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Sid": "LlamaStackAccess",
|
||||
"Effect": "Allow",
|
||||
"Principal": {
|
||||
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
|
||||
},
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:PutObject",
|
||||
"s3:DeleteObject"
|
||||
],
|
||||
"Resource": "arn:aws:s3:::your-bucket-name/*"
|
||||
},
|
||||
{
|
||||
"Sid": "LlamaStackBucketAccess",
|
||||
"Effect": "Allow",
|
||||
"Principal": {
|
||||
"AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
|
||||
},
|
||||
"Action": [
|
||||
"s3:ListBucket"
|
||||
],
|
||||
"Resource": "arn:aws:s3:::your-bucket-name"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Metadata Persistence
|
||||
|
||||
File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
|
||||
|
||||
- File ID
|
||||
- Original filename
|
||||
- Purpose (assistants, batch, etc.)
|
||||
- File size in bytes
|
||||
- Created and expiration timestamps
|
||||
|
||||
### TTL and Cleanup
|
||||
|
||||
Files currently have a fixed long expiration time (100 years).
|
||||
|
||||
## Development and Testing
|
||||
|
||||
### Using MinIO
|
||||
|
||||
For self-hosted S3-compatible storage:
|
||||
|
||||
```yaml
|
||||
config:
|
||||
bucket_name: test-bucket
|
||||
region: us-east-1
|
||||
endpoint_url: http://localhost:9000
|
||||
aws_access_key_id: minioadmin
|
||||
aws_secret_access_key: minioadmin
|
||||
```
|
||||
|
||||
## Monitoring and Logging
|
||||
|
||||
The provider logs important operations and errors. For production deployments, consider:
|
||||
|
||||
- CloudWatch monitoring for S3 operations
|
||||
- Custom metrics for file upload/download rates
|
||||
- Error rate monitoring
|
||||
- Performance metrics tracking
|
||||
|
||||
## Error Handling
|
||||
|
||||
The provider handles various error scenarios:
|
||||
|
||||
- S3 connectivity issues
|
||||
- Bucket access permissions
|
||||
- File not found errors
|
||||
- Metadata consistency checks
|
||||
|
||||
## Known Limitations
|
||||
|
||||
- Fixed long TTL (100 years) instead of configurable expiration
|
||||
- No server-side encryption enabled by default
|
||||
- No support for AWS session tokens
|
||||
- No S3 key prefix organization support
|
||||
- No multipart upload support (all files uploaded as single objects)
|
20
llama_stack/providers/remote/files/s3/__init__.py
Normal file
20
llama_stack/providers/remote/files/s3/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
|
||||
from .files import S3FilesImpl
|
||||
|
||||
# TODO: authorization policies and user separation
|
||||
impl = S3FilesImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
42
llama_stack/providers/remote/files/s3/config.py
Normal file
42
llama_stack/providers/remote/files/s3/config.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
|
||||
|
||||
|
||||
class S3FilesImplConfig(BaseModel):
|
||||
"""Configuration for S3-based files provider."""
|
||||
|
||||
bucket_name: str = Field(description="S3 bucket name to store files")
|
||||
region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
|
||||
aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)")
|
||||
aws_secret_access_key: str | None = Field(
|
||||
default=None, description="AWS secret access key (optional if using IAM roles)"
|
||||
)
|
||||
endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
|
||||
auto_create_bucket: bool = Field(
|
||||
default=False, description="Automatically create the S3 bucket if it doesn't exist"
|
||||
)
|
||||
metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique
|
||||
"region": "${env.AWS_REGION:=us-east-1}",
|
||||
"aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}",
|
||||
"aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}",
|
||||
"endpoint_url": "${env.S3_ENDPOINT_URL:=}",
|
||||
"auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}",
|
||||
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="s3_files_metadata.db",
|
||||
),
|
||||
}
|
272
llama_stack/providers/remote/files/s3/files.py
Normal file
272
llama_stack/providers/remote/files/s3/files.py
Normal file
|
@ -0,0 +1,272 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
# TODO: provider data for S3 credentials
|
||||
|
||||
|
||||
def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
|
||||
try:
|
||||
s3_config = {
|
||||
"region_name": config.region,
|
||||
}
|
||||
|
||||
# endpoint URL if specified (for MinIO, LocalStack, etc.)
|
||||
if config.endpoint_url:
|
||||
s3_config["endpoint_url"] = config.endpoint_url
|
||||
|
||||
if config.aws_access_key_id and config.aws_secret_access_key:
|
||||
s3_config.update(
|
||||
{
|
||||
"aws_access_key_id": config.aws_access_key_id,
|
||||
"aws_secret_access_key": config.aws_secret_access_key,
|
||||
}
|
||||
)
|
||||
|
||||
return boto3.client("s3", **s3_config)
|
||||
|
||||
except (BotoCoreError, NoCredentialsError) as e:
|
||||
raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
|
||||
|
||||
|
||||
async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
|
||||
try:
|
||||
client.head_bucket(Bucket=config.bucket_name)
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "404":
|
||||
if not config.auto_create_bucket:
|
||||
raise RuntimeError(
|
||||
f"S3 bucket '{config.bucket_name}' does not exist. "
|
||||
f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration."
|
||||
) from e
|
||||
try:
|
||||
# For us-east-1, we can't specify LocationConstraint
|
||||
if config.region == "us-east-1":
|
||||
client.create_bucket(Bucket=config.bucket_name)
|
||||
else:
|
||||
client.create_bucket(
|
||||
Bucket=config.bucket_name,
|
||||
CreateBucketConfiguration={"LocationConstraint": config.region},
|
||||
)
|
||||
except ClientError as create_error:
|
||||
raise RuntimeError(
|
||||
f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
|
||||
) from create_error
|
||||
elif error_code == "403":
|
||||
raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
|
||||
else:
|
||||
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
|
||||
|
||||
|
||||
class S3FilesImpl(Files):
|
||||
"""S3-based implementation of the Files API."""
|
||||
|
||||
# TODO: implement expiration, for now a silly offset
|
||||
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
|
||||
|
||||
def __init__(self, config: S3FilesImplConfig) -> None:
|
||||
self._config = config
|
||||
self._client: boto3.client | None = None
|
||||
self._sql_store: SqlStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = sqlstore_impl(self._config.metadata_store)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"filename": ColumnType.STRING,
|
||||
"purpose": ColumnType.STRING,
|
||||
"bytes": ColumnType.INTEGER,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"expires_at": ColumnType.INTEGER,
|
||||
# TODO: add s3_etag field for integrity checking
|
||||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> boto3.client:
|
||||
assert self._client is not None, "Provider not initialized"
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def sql_store(self) -> SqlStore:
|
||||
assert self._sql_store is not None, "Provider not initialized"
|
||||
return self._sql_store
|
||||
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
|
||||
created_at = int(time.time())
|
||||
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_files",
|
||||
{
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
self.client.put_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=file_id,
|
||||
Body=content,
|
||||
# TODO: enable server-side encryption
|
||||
)
|
||||
except ClientError as e:
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=file_id,
|
||||
filename=filename,
|
||||
purpose=purpose,
|
||||
bytes=file_size,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
# this purely defensive. it should not happen because the router also default to Order.desc.
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where_conditions = {}
|
||||
if purpose:
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [
|
||||
OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
for row in paginated_result.data
|
||||
]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
has_more=paginated_result.has_more,
|
||||
# empty string or None? spec says str, ref impl returns str | None, we go with spec
|
||||
first_id=files[0].id if files else "",
|
||||
last_id=files[-1].id if files else "",
|
||||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=row["id"],
|
||||
)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] != "NoSuchKey":
|
||||
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
try:
|
||||
response = self.client.get_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=row["id"],
|
||||
)
|
||||
# TODO: can we stream this instead of loading it into memory
|
||||
content = response["Body"].read()
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||
raise RuntimeError(f"Failed to download file from S3: {e}") from e
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||
)
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
|
|
|
@ -3,15 +3,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import logging
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
|
|
|
@ -41,6 +41,11 @@ client.initialize()
|
|||
|
||||
### Create Completion
|
||||
|
||||
> Note on Completion API
|
||||
>
|
||||
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does.
|
||||
|
||||
|
||||
```python
|
||||
response = client.inference.completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
|
@ -76,7 +81,78 @@ response = client.inference.chat_completion(
|
|||
print(f"Response: {response.completion_message.content}")
|
||||
```
|
||||
|
||||
### Tool Calling Example ###
|
||||
```python
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
|
||||
tool_definition = ToolDefinition(
|
||||
tool_name="get_weather",
|
||||
description="Get current weather information for a location",
|
||||
parameters={
|
||||
"location": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="The city and state, e.g. San Francisco, CA",
|
||||
required=True,
|
||||
),
|
||||
"unit": ToolParamDefinition(
|
||||
param_type="string",
|
||||
description="Temperature unit (celsius or fahrenheit)",
|
||||
required=False,
|
||||
default="celsius",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
tool_response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
|
||||
tools=[tool_definition],
|
||||
)
|
||||
|
||||
print(f"Tool Response: {tool_response.completion_message.content}")
|
||||
if tool_response.completion_message.tool_calls:
|
||||
for tool_call in tool_response.completion_message.tool_calls:
|
||||
print(f"Tool Called: {tool_call.tool_name}")
|
||||
print(f"Arguments: {tool_call.arguments}")
|
||||
```
|
||||
|
||||
### Structured Output Example
|
||||
```python
|
||||
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
|
||||
|
||||
person_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"occupation": {"type": "string"},
|
||||
},
|
||||
"required": ["name", "age", "occupation"],
|
||||
}
|
||||
|
||||
response_format = JsonSchemaResponseFormat(
|
||||
type=ResponseFormatType.json_schema, json_schema=person_schema
|
||||
)
|
||||
|
||||
structured_response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
|
||||
}
|
||||
],
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
print(f"Structured Response: {structured_response.completion_message.content}")
|
||||
```
|
||||
|
||||
### Create Embeddings
|
||||
> Note on OpenAI embeddings compatibility
|
||||
>
|
||||
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
|
||||
|
||||
```python
|
||||
response = client.inference.embeddings(
|
||||
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
|
|
|
@ -4,11 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from openai import APIConnectionError, BadRequestError
|
||||
from openai import NOT_GIVEN, APIConnectionError
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
@ -27,12 +26,16 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
|
@ -54,7 +57,7 @@ from .openai_utils import (
|
|||
)
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||
|
@ -194,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
|||
}
|
||||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
except BadRequestError as e:
|
||||
raise ValueError(f"Failed to get embeddings: {e}") from e
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
#
|
||||
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
|
||||
# ->
|
||||
|
@ -210,6 +209,57 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
|||
#
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
OpenAI-compatible embeddings for NVIDIA NIM.
|
||||
|
||||
Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
|
||||
We default this to "query" to ensure requests succeed when using the
|
||||
OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
|
||||
`task_type='document'`.
|
||||
"""
|
||||
extra_body: dict[str, object] = {"input_type": "query"}
|
||||
logger.warning(
|
||||
"NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
|
||||
"For passage embeddings, use the embeddings API with task_type='document'."
|
||||
)
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(model),
|
||||
input=input,
|
||||
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
|
||||
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
|
||||
user=user if user is not None else NOT_GIVEN,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=response.model,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from . import NVIDIAConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||
|
|
|
@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
|
@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
|
|||
response.id = id
|
||||
return response
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch completion is not supported for Ollama")
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for Ollama")
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
|
|
|
@ -4,15 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="inference::openai")
|
||||
|
||||
|
||||
#
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
|
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
|
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference::tgi")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
|
|
|
@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
|
|
|
@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference::vllm")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
|
@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
user=user,
|
||||
)
|
||||
return await self.client.chat.completions.create(**params) # type: ignore
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch completion is not supported for Ollama")
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for Ollama")
|
||||
|
|
|
@ -4,18 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.post_training import TrainingConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
|
||||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="post_training::nvidia")
|
||||
|
||||
|
||||
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
|
@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
|
|||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="safety::bedrock")
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
|
|
|
@ -4,20 +4,20 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
||||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="safety::nvidia")
|
||||
|
||||
|
||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
|
@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
||||
return await self.shield.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
|
||||
|
||||
|
||||
class NeMoGuardrails:
|
||||
"""
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
||||
|
||||
from .config import SambaNovaSafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="safety::sambanova")
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="vector_io::chroma")
|
||||
|
||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="vector_io::milvus")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
|
@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="vector_io::pgvector")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreChunkingStrategy,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="vector_io::qdrant")
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
# KV store prefixes for vector databases
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import weaviate
|
||||
|
@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
|
|||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import WeaviateVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="vector_io::weaviate")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
|
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
@ -27,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
|
|||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingMixin:
|
||||
|
|
|
@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class LiteLLMOpenAIMixin(
|
||||
|
@ -429,28 +429,6 @@ class LiteLLMOpenAIMixin(
|
|||
)
|
||||
return await litellm.acompletion(**params)
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available via LiteLLM for the current
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue