feat: add auto-generated CI documentation pre-commit hook (#2890)

Our CI is entirely undocumented; this commit adds a README.md file with
a table of the current CI workflows and what each does.

---------

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
Nathan Weinberg 2025-07-25 11:57:01 -04:00 committed by Mustafa Elbehery
parent 7f834339ba
commit b381ed6d64
93 changed files with 495 additions and 477 deletions


@ -145,6 +145,24 @@ repos:
pass_filenames: false
require_serial: true
files: ^.github/workflows/.*$
- id: check-log-usage
name: Check for proper log usage (use llama_stack.log instead)
entry: bash
language: system
types: [python]
pass_filenames: true
args:
- -c
- |
matches=$(grep -EnH '^[^#]*\b(import logging|from logging\b)' "$@" | grep -v '# allow-direct-logging' || true)
if [ -n "$matches" ]; then
# GitHub Actions annotation format
while IFS=: read -r file line_num rest; do
echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging"
done <<< "$matches"
exit 1
fi
exit 0
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

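The check-log-usage hook above flags any direct "import logging" or "from logging ..." statement in Python files, emits each match as a GitHub Actions ::error file=...,line=... annotation so it shows up inline on pull requests, and can be silenced per line with a trailing "# allow-direct-logging" comment. Below is a minimal sketch of the pattern the hook steers code toward, assuming the llama_stack package is importable; the category value is just the subsystem name used throughout this diff (for example "core", "server", or "cli"), and the log message is illustrative only.

# Flagged by the hook: direct use of the stdlib logging module.
#   import logging
#   logger = logging.getLogger(__name__)

# Expected: go through the project's logger factory instead.
from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")
log.info("example message")  # illustrative call only

# When the stdlib module is genuinely needed (e.g. to strip handlers from the
# root logger, as library_client.py does below), the marker keeps the hook quiet:
import logging  # allow-direct-logging

root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
    root_logger.removeHandler(handler)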

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
log = get_logger(name=__name__, category="server")
class StackRun(Subcommand):
@ -126,7 +126,7 @@ class StackRun(Subcommand):
self.parser.error("Config file is required for venv environment")
if config_file:
logger.info(f"Using run configuration: {config_file}")
log.info(f"Using run configuration: {config_file}")
try:
config_dict = yaml.safe_load(config_file.read_text())
@ -145,7 +145,7 @@ class StackRun(Subcommand):
# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
if not image_type and not image_name:
logger.info("No image type or image name provided. Assuming environment packages.")
log.info("No image type or image name provided. Assuming environment packages.")
from llama_stack.core.server.server import main as server_main
# Build the server args from the current args passed to the CLI
@ -185,11 +185,11 @@ class StackRun(Subcommand):
run_command(run_args)
def _start_ui_development_server(self, stack_server_port: int):
logger.info("Attempting to start UI development server...")
log.info("Attempting to start UI development server...")
# Check if npm is available
npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
if npm_check.returncode != 0:
logger.warning(
log.warning(
f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
)
return
@ -214,13 +214,13 @@ class StackRun(Subcommand):
stderr=stderr_log_file,
env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
)
logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
log.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
log.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
log.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
except FileNotFoundError:
logger.error(
log.error(
"Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
)
except Exception as e:
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
log.error(f"Failed to start UI development server in {ui_dir}: {e}")


@ -8,7 +8,7 @@ import argparse
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="cli")
log = get_logger(name=__name__, category="cli")
# TODO: this can probably just be inlined now?


@ -5,7 +5,6 @@
# the root directory of this source tree.
import importlib.resources
import logging
import sys
from pydantic import BaseModel
@ -17,10 +16,9 @@ from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
@ -33,6 +31,8 @@ SERVER_DEPENDENCIES = [
"opentelemetry-exporter-otlp-proto-http",
]
log = get_logger(name=__name__, category="core")
class ApiInput(BaseModel):
api: Api


@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import textwrap
from typing import Any
@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
@ -49,7 +49,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
is_nux = len(config.providers) == 0
if is_nux:
logger.info(
log.info(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
@ -75,12 +75,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
log.info(f"Re-configuring existing providers for API `{api_str}`...")
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
log.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(configure_single_provider(provider_registry[api], p))
logger.info("")
log.info("")
else:
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
@ -89,17 +89,17 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...")
log.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider in enumerate(plist):
if i >= 1:
others = ", ".join(p.provider_type for p in plist[i:])
logger.info(
log.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
logger.info(f"> Configuring provider `({provider.provider_type})`")
log.info(f"> Configuring provider `({provider.provider_type})`")
pid = provider.provider_type.split("::")[-1]
updated_providers.append(
configure_single_provider(
@ -111,7 +111,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
),
)
)
logger.info("")
log.info("")
config.providers[api_str] = updated_providers
@ -169,7 +169,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
log.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION


@ -23,7 +23,7 @@ from llama_stack.providers.datatypes import (
remote_provider_spec,
)
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
def stack_apis() -> list[Api]:
@ -141,18 +141,18 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
log.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
log.warning(f"Failed to import module {name}: {e}")
# Refresh providable APIs with external APIs if any
external_apis = load_external_apis(config)
for api, api_spec in external_apis.items():
name = api_spec.name.lower()
logger.info(f"Importing external API {name} module {api_spec.module}")
log.info(f"Importing external API {name} module {api_spec.module}")
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
@ -161,7 +161,7 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
# This assume that the in-tree provider(s) are not available for this API which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
log.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
@ -183,13 +183,13 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
def get_external_providers_from_dir(
registry: dict[Api, dict[str, ProviderSpec]], config
) -> dict[Api, dict[str, ProviderSpec]]:
logger.warning(
log.warning(
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
log.info(f"Loading external providers from {external_providers_dir}")
for api in providable_apis():
api_name = api.name.lower()
@ -198,13 +198,13 @@ def get_external_providers_from_dir(
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
log.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
log.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
@ -217,16 +217,16 @@ def get_external_providers_from_dir(
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
log.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
log.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
log.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
log.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
log.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
@ -241,7 +241,7 @@ def get_external_providers_from_module(
else:
provider_list = config.providers.items()
if provider_list is None:
logger.warning("Could not get list of providers from config")
log.warning("Could not get list of providers from config")
return registry
for provider_api, providers in provider_list:
for provider in providers:
@ -272,6 +272,6 @@ def get_external_providers_from_module(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
) from exc
except Exception as e:
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
log.error(f"Failed to load provider spec from module {provider.module}: {e}")
raise e
return registry


@ -11,7 +11,7 @@ from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
@ -28,10 +28,10 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
log.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
log.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
@ -42,13 +42,13 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
log.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
log.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
log.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis


@ -7,7 +7,6 @@
import asyncio
import inspect
import json
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
@ -48,6 +47,7 @@ from llama_stack.core.stack import (
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
@ -55,7 +55,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
start_trace,
)
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
T = TypeVar("T")
@ -84,7 +84,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
try:
return [convert_to_pydantic(item_type, item) for item in value]
except Exception:
logger.error(f"Error converting list {value} into {item_type}")
log.error(f"Error converting list {value} into {item_type}")
return value
elif origin is dict:
@ -92,7 +92,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
try:
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
except Exception:
logger.error(f"Error converting dict {value} into {val_type}")
log.error(f"Error converting dict {value} into {val_type}")
return value
try:
@ -108,7 +108,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
return convert_to_pydantic(union_type, value)
except Exception:
continue
logger.warning(
log.warning(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
)
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
@ -171,13 +171,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
Remove all handlers from the root log. Needed to avoid polluting the console with logs.
"""
import logging # allow-direct-logging
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
log.info(f"Removed handler {handler.__class__.__name__} from root log")
def request(self, *args, **kwargs):
loop = self.loop


@ -16,7 +16,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ProviderImplConfig(BaseModel):
@ -38,7 +38,7 @@ class ProviderImpl(Providers):
pass
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
log.debug("ProviderImpl.shutdown")
pass
async def list_providers(self) -> ListProvidersResponse:


@ -6,19 +6,19 @@
import contextvars
import json
import logging
from contextlib import AbstractContextManager
from typing import Any
from llama_stack.core.datatypes import User
from llama_stack.log import get_logger
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
log = get_logger(name=__name__, category="core")
class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""


@ -54,7 +54,7 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception):
@ -101,7 +101,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
protocols[api] = api_class
except (ImportError, AttributeError):
logger.exception(f"Failed to load external API {api_spec.name}")
log.exception(f"Failed to load external API {api_spec.name}")
return protocols
@ -223,7 +223,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
log.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
@ -245,10 +245,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logger.error(p.deprecation_error)
log.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logger.warning(
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
@ -261,9 +261,9 @@ def sort_providers_by_deps(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
logger.debug(f"Resolved {len(sorted_providers)} providers")
log.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")
log.debug(f" {api_str} => {provider.provider_id}")
return sorted_providers
@ -348,7 +348,7 @@ async def instantiate_provider(
if not hasattr(provider_spec, "module") or provider_spec.module is None:
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
log.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
@ -418,7 +418,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method has a concrete implementation (not just a protocol stub)


@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class DatasetIORouter(DatasetIO):
@ -20,15 +20,15 @@ class DatasetIORouter(DatasetIO):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
log.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
log.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
log.debug("DatasetIORouter.shutdown")
pass
async def register_dataset(
@ -38,7 +38,7 @@ class DatasetIORouter(DatasetIO):
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
log.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
@ -54,7 +54,7 @@ class DatasetIORouter(DatasetIO):
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
log.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
provider = await self.routing_table.get_provider_impl(dataset_id)
@ -65,7 +65,7 @@ class DatasetIORouter(DatasetIO):
)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
log.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.append_rows(
dataset_id=dataset_id,


@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ScoringRouter(Scoring):
@ -24,15 +24,15 @@ class ScoringRouter(Scoring):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
log.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
log.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
log.debug("ScoringRouter.shutdown")
pass
async def score_batch(
@ -41,7 +41,7 @@ class ScoringRouter(Scoring):
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
log.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
provider = await self.routing_table.get_provider_impl(fn_identifier)
@ -63,7 +63,7 @@ class ScoringRouter(Scoring):
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
log.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
@ -82,15 +82,15 @@ class EvalRouter(Eval):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
log.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
log.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
log.debug("EvalRouter.shutdown")
pass
async def run_eval(
@ -98,7 +98,7 @@ class EvalRouter(Eval):
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
log.debug(f"EvalRouter.run_eval: {benchmark_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.run_eval(
benchmark_id=benchmark_id,
@ -112,7 +112,7 @@ class EvalRouter(Eval):
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
log.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.evaluate_rows(
benchmark_id=benchmark_id,
@ -126,7 +126,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
log.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_status(benchmark_id, job_id)
@ -135,7 +135,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
log.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
await provider.job_cancel(
benchmark_id,
@ -147,7 +147,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
log.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_result(
benchmark_id,


@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class InferenceRouter(Inference):
@ -70,7 +70,7 @@ class InferenceRouter(Inference):
telemetry: Telemetry | None = None,
store: InferenceStore | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
log.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
self.store = store
@ -79,10 +79,10 @@ class InferenceRouter(Inference):
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
log.debug("InferenceRouter.initialize")
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
log.debug("InferenceRouter.shutdown")
async def register_model(
self,
@ -92,7 +92,7 @@ class InferenceRouter(Inference):
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> None:
logger.debug(
log.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
@ -117,7 +117,7 @@ class InferenceRouter(Inference):
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
log.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
@ -182,7 +182,7 @@ class InferenceRouter(Inference):
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
log.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
@ -288,7 +288,7 @@ class InferenceRouter(Inference):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
logger.debug(
log.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
@ -313,7 +313,7 @@ class InferenceRouter(Inference):
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
logger.debug(
log.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
@ -374,7 +374,7 @@ class InferenceRouter(Inference):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
logger.debug(
log.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
@ -388,7 +388,7 @@ class InferenceRouter(Inference):
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
log.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
@ -426,7 +426,7 @@ class InferenceRouter(Inference):
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
logger.debug(
log.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self.routing_table.get_model(model)
@ -487,7 +487,7 @@ class InferenceRouter(Inference):
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
log.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
model_obj = await self.routing_table.get_model(model)
@ -558,7 +558,7 @@ class InferenceRouter(Inference):
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
logger.debug(
log.debug(
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
)
model_obj = await self.routing_table.get_model(model)


@ -14,7 +14,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class SafetyRouter(Safety):
@ -22,15 +22,15 @@ class SafetyRouter(Safety):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
log.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
log.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
log.debug("SafetyRouter.shutdown")
pass
async def register_shield(
@ -40,7 +40,7 @@ class SafetyRouter(Safety):
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
log.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def unregister_shield(self, identifier: str) -> None:
@ -53,7 +53,7 @@ class SafetyRouter(Safety):
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
log.debug(f"SafetyRouter.run_shield: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield(
shield_id=shield_id,


@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ToolRuntimeRouter(ToolRuntime):
@ -31,7 +31,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
log.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
@ -40,7 +40,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
log.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_db_ids, query_config)
@ -50,7 +50,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
log.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
@ -60,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
log.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -69,15 +69,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
log.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
log.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
log.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
@ -87,5 +87,5 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
log.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)


@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
@ -40,15 +40,15 @@ class VectorIORouter(VectorIO):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
log.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
log.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
log.debug("VectorIORouter.shutdown")
pass
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
@ -70,10 +70,10 @@ class VectorIORouter(VectorIO):
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
return embedding_models[0].identifier, dimension
else:
logger.warning("No embedding models found in the routing table")
log.warning("No embedding models found in the routing table")
return None
except Exception as e:
logger.error(f"Error getting embedding models: {e}")
log.error(f"Error getting embedding models: {e}")
return None
async def register_vector_db(
@ -85,7 +85,7 @@ class VectorIORouter(VectorIO):
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
log.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
@ -101,7 +101,7 @@ class VectorIORouter(VectorIO):
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
log.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
provider = await self.routing_table.get_provider_impl(vector_db_id)
@ -113,7 +113,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
log.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
@ -129,7 +129,7 @@ class VectorIORouter(VectorIO):
embedding_dimension: int | None = None,
provider_id: str | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
log.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one
if embedding_model is None:
@ -137,7 +137,7 @@ class VectorIORouter(VectorIO):
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
log.info(f"No embedding model specified, using first available: {embedding_model}")
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
@ -168,7 +168,7 @@ class VectorIORouter(VectorIO):
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
log.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
@ -179,7 +179,7 @@ class VectorIORouter(VectorIO):
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
log.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
continue
# Sort by created_at
@ -215,7 +215,7 @@ class VectorIORouter(VectorIO):
self,
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
@ -225,7 +225,7 @@ class VectorIORouter(VectorIO):
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
return await self.routing_table.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
@ -237,7 +237,7 @@ class VectorIORouter(VectorIO):
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
@ -250,7 +250,7 @@ class VectorIORouter(VectorIO):
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
return await self.routing_table.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
@ -268,7 +268,7 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
return await self.routing_table.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
@ -285,7 +285,7 @@ class VectorIORouter(VectorIO):
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
@ -300,7 +300,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
@ -311,7 +311,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
@ -323,7 +323,7 @@ class VectorIORouter(VectorIO):
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
@ -335,7 +335,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,


@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):


@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
def get_impl_api(p: Any) -> Api:
@ -177,7 +177,7 @@ class CommonRoutingTableImpl(RoutingTable):
# Check if user has permission to access this object
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}'")
log.debug(f"Access denied to {type} '{identifier}'")
return None
return obj
@ -205,7 +205,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise AccessDeniedError("create", obj, creator)
if creator:
obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
log.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
@ -250,7 +250,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
if model is not None:
return model
logger.warning(
log.warning(
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
"searching in all providers. This is only for backwards compatibility and will stop working "
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."


@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):


@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
log.exception(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
@ -132,7 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_ids[model.provider_resource_id] = model.identifier
continue
logger.debug(f"unregistering model {model.identifier}")
log.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
@ -143,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model.identifier == model.provider_resource_id:
model.identifier = f"{provider_id}/{model.provider_resource_id}"
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
log.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
ModelWithOwner(
identifier=model.identifier,


@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):


@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):


@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:


@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
@ -57,7 +57,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
log.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:


@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
log = get_logger(name=__name__, category="auth")
class AuthenticationMiddleware:
@ -105,13 +105,13 @@ class AuthenticationMiddleware:
try:
validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
log.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except ValueError as e:
logger.exception("Error during authentication")
log.exception("Error during authentication")
return await self._send_auth_error(send, str(e))
except Exception:
logger.exception("Error during authentication")
log.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
@ -122,7 +122,7 @@ class AuthenticationMiddleware:
scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug(
log.debug(
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)


@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
log = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
@ -163,7 +163,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Token introspection failed with status code: {response.status_code}")
log.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}")
fields = response.json()
@ -176,13 +176,13 @@ class OAuth2TokenAuthProvider(AuthProvider):
attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
log.exception("Token introspection request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during token introspection")
log.exception("Error during token introspection")
raise ValueError("Token introspection error") from e
async def close(self):
@ -273,7 +273,7 @@ class CustomAuthProvider(AuthProvider):
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}")
log.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}")
# Parse and validate the auth response
@ -282,17 +282,17 @@ class CustomAuthProvider(AuthProvider):
auth_response = AuthResponse(**response_data)
return User(principal=auth_response.principal, attributes=auth_response.attributes)
except Exception as e:
logger.exception("Error parsing authentication response")
log.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
log.exception("Authentication request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during authentication")
log.exception("Error during authentication")
raise ValueError("Authentication service error") from e
async def close(self):
@ -329,7 +329,7 @@ class GitHubTokenAuthProvider(AuthProvider):
try:
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
except httpx.HTTPStatusError as e:
logger.warning(f"GitHub token validation failed: {e}")
log.warning(f"GitHub token validation failed: {e}")
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
principal = user_info["user"]["login"]


@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
log = get_logger(name=__name__, category="quota")
class QuotaMiddleware:
@ -46,7 +46,7 @@ class QuotaMiddleware:
self.window_seconds = window_seconds
if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning(
log.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
@ -84,11 +84,11 @@ class QuotaMiddleware:
else:
await kv.set(key, str(count))
except Exception:
logger.exception("Failed to access KV store for quota")
log.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error")
if count > limit:
logger.warning(
log.warning(
"Quota exceeded for client %s: %d/%d",
key_id,
count,


@ -9,7 +9,7 @@ import asyncio
import functools
import inspect
import json
import logging
import logging # allow-direct-logging
import os
import ssl
import sys
@ -80,7 +80,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
log = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -157,9 +157,9 @@ async def shutdown(app):
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up")
log.info("Starting up")
yield
logger.info("Shutting down")
log.info("Shutting down")
await shutdown(app)
@ -182,11 +182,11 @@ async def sse_generator(event_gen_coroutine):
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("Generator cancelled")
log.info("Generator cancelled")
if event_gen:
await event_gen.aclose()
except Exception as e:
logger.exception("Error in sse_generator")
log.exception("Error in sse_generator")
yield create_sse_event(
{
"error": {
@ -206,11 +206,11 @@ async def log_request_pre_validation(request: Request):
log_output = rich.pretty.pretty_repr(parsed_body)
except (json.JSONDecodeError, UnicodeDecodeError):
log_output = repr(body_bytes)
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
log.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
else:
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
log.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
except Exception as e:
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
log.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@ -238,10 +238,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
result.url = route
return result
except Exception as e:
if logger.isEnabledFor(logging.DEBUG):
logger.exception(f"Error executing endpoint {route=} {method=}")
if log.isEnabledFor(logging.DEBUG):
log.exception(f"Error executing endpoint {route=} {method=}")
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
log.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
sig = inspect.signature(func)
@@ -291,7 +291,7 @@ class TracingMiddleware:
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
log.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
@@ -303,7 +303,7 @@ class TracingMiddleware:
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
log.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path}
@@ -404,15 +404,15 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
log = get_logger(name=__name__, category="server", config=logger_config)
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}")
log.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
log.error(f"Error: {str(e)}")
sys.exit(1)
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
@@ -438,16 +438,16 @@ def main(args: argparse.Namespace | None = None):
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
log.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
log.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else:
if config.server.quota:
quota = config.server.quota
logger.warning(
log.warning(
"Configured authenticated_max_requests (%d) but no auth is enabled; "
"falling back to anonymous_max_requests (%d) for all the requests",
quota.authenticated_max_requests,
@ -455,7 +455,7 @@ def main(args: argparse.Namespace | None = None):
)
if config.server.quota:
logger.info("Enabling quota middleware for authenticated and anonymous clients")
log.info("Enabling quota middleware for authenticated and anonymous clients")
quota = config.server.quota
anonymous_max_requests = quota.anonymous_max_requests
@ -516,7 +516,7 @@ def main(args: argparse.Namespace | None = None):
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
log.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
@ -528,7 +528,7 @@ def main(args: argparse.Namespace | None = None):
)
)
logger.debug(f"serving APIs: {apis_to_serve}")
log.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
@ -553,21 +553,21 @@ def main(args: argparse.Namespace | None = None):
if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
log.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
log.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}")
log.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_level": log.getEffectiveLevel(),
"log_config": logger_config,
}
if ssl_config:
@ -586,19 +586,19 @@ def main(args: argparse.Namespace | None = None):
try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
log.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
log.debug("Closing event loop")
loop.close()
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
log.info("Run configuration:")
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
clean_config = remove_disabled_providers(safe_config)
logger.info(yaml.dump(clean_config, indent=2))
log.info(yaml.dump(clean_config, indent=2))
def extract_path_params(route: str) -> list[str]:

View file

@@ -45,7 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class LlamaStack(
@@ -105,11 +105,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
log.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
log.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue
# we want to maintain the type information in arguments to method.
@@ -123,7 +123,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
logger.debug(
log.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
@@ -160,7 +160,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
try:
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
if resolved_provider_id == "__disabled__":
logger.debug(
log.debug(
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
)
# Create a copy with resolved provider_id but original config
@@ -315,7 +315,7 @@ async def construct_stack(
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
log.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
policy = run_config.server.auth.access_policy if run_config.server.auth else []
@@ -337,12 +337,12 @@ async def construct_stack(
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
log.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
log.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
log.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls
@@ -351,23 +351,23 @@ async def construct_stack(
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
log.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
log.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
log.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
log.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
log.error(f"Error during inference recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
@@ -375,14 +375,14 @@ async def shutdown_stack(impls: dict[Api, Any]):
async def refresh_registry_once(impls: dict[Api, Any]):
logger.debug("refreshing registry")
log.debug("refreshing registry")
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
for routing_table in routing_tables:
await routing_table.refresh()
async def refresh_registry_task(impls: dict[Api, Any]):
logger.info("starting registry refresh task")
log.info("starting registry refresh task")
while True:
await refresh_registry_once(impls)

View file

@@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
log = get_logger(name=__name__, category="config_resolution")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
@@ -42,25 +42,25 @@ def resolve_config_or_distro(
# Strategy 1: Try as file path first
config_path = Path(config_or_distro)
if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}")
log.info(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as distribution name (if no .yaml extension)
if not config_or_distro.endswith(".yaml"):
distro_config = _get_distro_config_path(config_or_distro, mode)
if distro_config.exists():
logger.info(f"Using distribution: {distro_config}")
log.info(f"Using distribution: {distro_config}")
return distro_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
log.info(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
log.info(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import importlib
import os
import signal
import subprocess
@@ -12,9 +12,9 @@ import sys
from termcolor import cprint
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
import importlib
log = get_logger(name=__name__, category="core")
def formulate_run_args(image_type: str, image_name: str) -> list:
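
The recurring change in the files above and below is the move from the standard-library logging module to the project logger. As a point of reference only (not part of the commit's diff), a minimal sketch of the pattern these hunks converge on follows; the category value and the example function are assumptions for illustration.

# Illustrative sketch, not part of the diff: obtain the project logger once per module.
from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")  # the category "core" is assumed here

def divide(x: int, y: int) -> int | None:
    # The returned logger exposes the familiar stdlib-style API (debug/info/warning/exception).
    log.debug("dividing %d by %d", x, y)
    try:
        return x // y
    except ZeroDivisionError:
        # exception() records the traceback, mirroring logging.exception
        log.exception("division failed")
        return None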

View file

@@ -6,7 +6,6 @@
import inspect
import json
import logging
from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin
@@ -14,7 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
def is_list_of_primitives(field_type):

View file

@@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import logging # allow-direct-logging
import os
import re
import sys
from logging.config import dictConfig
from logging.config import dictConfig # allow-direct-logging
from rich.console import Console
from rich.errors import MarkupError

View file

@@ -13,7 +13,7 @@
# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
from logging import getLogger
from logging import getLogger # allow-direct-logging
import torch
import torch.nn.functional as F

View file

@@ -13,7 +13,7 @@
import math
from collections import defaultdict
from logging import getLogger
from logging import getLogger # allow-direct-logging
from typing import Any
import torch

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import math
from collections.abc import Callable
from functools import partial
@@ -22,6 +21,8 @@ from PIL import Image as PIL_Image
from torch import Tensor, nn
from torch.distributed import _functional_collectives as funcol
from llama_stack.log import get_logger
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
from .encoder_utils import (
build_encoder_attention_mask,
@@ -34,9 +35,10 @@ from .encoder_utils import (
from .image_transform import VariableSizeImageTransform
from .utils import get_negative_inf_value, to_2tuple
logger = logging.getLogger(__name__)
MP_SCALE = 8
log = get_logger(name=__name__, category="core")
def reduce_from_tensor_model_parallel_region(input_):
"""All-reduce the input tensor across model parallel group."""
@@ -415,7 +417,7 @@ class VisionEncoder(nn.Module):
)
state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype)
logger.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
log.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
else:
global_pos_embed = resize_global_position_embedding(
state_dict[prefix + "gated_positional_embedding"],
@@ -423,7 +425,7 @@ class VisionEncoder(nn.Module):
self.max_num_tiles,
self.max_num_tiles,
)
logger.info(
log.info(
f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}"
)
state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
@@ -771,7 +773,7 @@ class TilePositionEmbedding(nn.Module):
if embed is not None:
# reshape the weights to the correct shape
nt_old, nt_old, _, w = embed.shape
logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
log.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
# assign the weights to the module
state_dict[prefix + "embedding"] = embed_new

View file

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from logging import getLogger # allow-direct-logging
from pathlib import Path
from typing import (
Literal,

View file

@@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@@ -215,7 +215,7 @@ class ToolUtils:
# FIXME: Enable multiple tool calls
return function_calls[0]
else:
logger.debug(f"Did not parse tool call from message body: {message_body}")
log.debug(f"Did not parse tool call from message body: {message_body}")
return None
@staticmethod

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from collections.abc import Callable
@@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor, nn
from torch.nn import functional as F
from llama_stack.log import get_logger
from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
logger = get_logger(__name__, category="core")
def swiglu_wrapper_no_reduce(
@@ -186,7 +187,7 @@ def logging_callbacks(
if use_rich_progress:
console.print(message)
elif rank == 0: # Only log from rank 0 for non-rich logging
log.info(message)
logger.info(message)
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
progress = None
@ -220,6 +221,6 @@ def logging_callbacks(
if completed is not None:
progress.update(task_id, completed=completed)
elif rank == 0 and completed and completed % 10 == 0:
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
logger.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
return progress, log_status, update_status

View file

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from logging import getLogger # allow-direct-logging
from pathlib import Path
from typing import (
Literal,

View file

@@ -6,16 +6,17 @@
# type: ignore
import collections
import logging
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
logger.info("Using efficient FP8 or INT4 operators in FBGEMM.")
except ImportError:
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
logger.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
raise
import torch

View file

@@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents")
class ChatAgent(ShieldRunnerMixin):
@@ -612,7 +612,7 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
log.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
message.stop_reason = StopReason.end_of_turn
@ -620,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin):
break
if stop_reason == StopReason.out_of_tokens:
logger.info("out of token budget, exiting.")
log.info("out of token budget, exiting.")
yield message
break
@ -634,7 +634,7 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments
yield message
else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
log.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
else:
input_messages = input_messages + [message]
@ -889,7 +889,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
log.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
@ -899,7 +899,7 @@ class ChatAgent(ShieldRunnerMixin):
**self.tool_name_to_args.get(tool_name_str, {}),
},
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
log.debug(f"tool call {tool_name_str} completed with result: {result}")
return result

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo
logger = logging.getLogger()
log = get_logger(name=__name__, category="agents")
class MetaReferenceAgentsImpl(Agents):
@@ -268,7 +268,7 @@ class MetaReferenceAgentsImpl(Agents):
# Get the agent info using the key
agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
log.error(f"Could not find agent info for key {agent_key}")
continue
try:
@ -281,7 +281,7 @@ class MetaReferenceAgentsImpl(Agents):
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
log.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries

View file

@@ -75,7 +75,7 @@ from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefiniti
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
logger = get_logger(name=__name__, category="openai_responses")
log = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
@@ -544,12 +544,12 @@ class OpenAIResponsesImpl:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
log.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
log.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
break
messages = next_turn_messages
@@ -698,7 +698,7 @@ class OpenAIResponsesImpl:
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
log.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
import uuid
from datetime import UTC, datetime
@@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="agents")
class AgentSessionInfo(Session):

View file

@@ -5,13 +5,13 @@
# the root directory of this source tree.
import asyncio
import logging
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="agents")
class SafetyException(Exception): # noqa: N818

View file

@@ -73,11 +73,12 @@ from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator
log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
logger = get_logger(__name__, category="inference")
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
@@ -144,7 +145,7 @@ class MetaReferenceInferenceImpl(
return model
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
logger.info(f"Loading model `{model_id}`")
builder_params = [self.config, model_id, llama_model]
@ -166,7 +167,7 @@ class MetaReferenceInferenceImpl(
self.model_id = model_id
self.llama_model = llama_model
log.info("Warming up...")
logger.info("Warming up...")
await self.completion(
model_id=model_id,
content="Hello, world!",
@ -177,7 +178,7 @@ class MetaReferenceInferenceImpl(
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
)
log.info("Warmed up!")
logger.info("Warmed up!")
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:

View file

@@ -12,7 +12,6 @@
import copy
import json
import logging
import multiprocessing
import os
import tempfile
@@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class ProcessingMessageName(str, Enum):
@@ -236,7 +236,7 @@ def worker_process_entrypoint(
except StopIteration:
break
log.info("[debug] worker process done")
log.info("[debug] worker process done")
def launch_dist_group(

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
@@ -32,8 +31,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,

View file

@@ -6,7 +6,6 @@
import gc
import json
import logging
import multiprocessing
from pathlib import Path
from typing import Any
@@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
@@ -44,7 +44,7 @@ from ..utils import (
split_dataset,
)
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
class HFFinetuningSingleDevice:
@@ -69,14 +69,14 @@ class HFFinetuningSingleDevice:
try:
messages = json.loads(row["chat_completion_input"])
if not isinstance(messages, list) or len(messages) != 1:
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
log.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
return None, None
if "content" not in messages[0]:
logger.warning(f"Message missing content: {messages[0]}")
log.warning(f"Message missing content: {messages[0]}")
return None, None
return messages[0]["content"], row["expected_answer"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
log.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
return None, None
return None, None
@@ -86,13 +86,13 @@ class HFFinetuningSingleDevice:
try:
dialog = json.loads(row["dialog"])
if not isinstance(dialog, list) or len(dialog) < 2:
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
log.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
return None, None
if dialog[0].get("role") != "user":
logger.warning(f"First message must be from user: {dialog[0]}")
log.warning(f"First message must be from user: {dialog[0]}")
return None, None
if not any(msg.get("role") == "assistant" for msg in dialog):
logger.warning("Dialog must have at least one assistant message")
log.warning("Dialog must have at least one assistant message")
return None, None
# Convert to human/gpt format
@@ -100,14 +100,14 @@ class HFFinetuningSingleDevice:
conversations = []
for msg in dialog:
if "role" not in msg or "content" not in msg:
logger.warning(f"Message missing role or content: {msg}")
log.warning(f"Message missing role or content: {msg}")
continue
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
# Format as a single conversation
return conversations[0]["value"], conversations[1]["value"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse dialog: {row['dialog']}")
log.warning(f"Failed to parse dialog: {row['dialog']}")
return None, None
return None, None
@@ -198,7 +198,7 @@ class HFFinetuningSingleDevice:
"""
import asyncio
logger.info("Starting training process with async wrapper")
log.info("Starting training process with async wrapper")
asyncio.run(
self._run_training(
model=model,
@@ -228,14 +228,14 @@ class HFFinetuningSingleDevice:
raise ValueError("DataConfig is required for training")
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
log.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
log.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
log.info(f"Initializing tokenizer for model: {model}")
try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
@@ -257,16 +257,16 @@ class HFFinetuningSingleDevice:
# This ensures consistent sequence lengths across the training process
tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully")
log.info("Tokenizer initialized successfully")
except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset")
log.info("Creating and preprocessing dataset")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
log.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
@@ -293,11 +293,11 @@ class HFFinetuningSingleDevice:
Returns:
Configured SFTConfig object
"""
logger.info("Configuring training arguments")
log.info("Configuring training arguments")
lr = 2e-5
if config.optimizer_config:
lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}")
log.info(f"Using custom learning rate: {lr}")
# Validate data config
if not config.data_config:
@@ -350,17 +350,17 @@ class HFFinetuningSingleDevice:
peft_config: Optional LoRA configuration
output_dir_path: Path to save the model
"""
logger.info("Saving final model")
log.info("Saving final model")
model_obj.config.use_cache = True
if peft_config:
logger.info("Merging LoRA weights with base model")
log.info("Merging LoRA weights with base model")
model_obj = trainer.model.merge_and_unload()
else:
model_obj = trainer.model
save_path = output_dir_path / "merged_model"
logger.info(f"Saving model to {save_path}")
log.info(f"Saving model to {save_path}")
model_obj.save_pretrained(save_path)
async def _run_training(
@ -380,13 +380,13 @@ class HFFinetuningSingleDevice:
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
log.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config)
# Initialize and validate device
device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'")
log.info(f"Using device '{device}'")
# Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
@@ -409,7 +409,7 @@ class HFFinetuningSingleDevice:
model_obj = load_model(model, device, provider_config_obj)
# Initialize trainer
logger.info("Initializing SFTTrainer")
log.info("Initializing SFTTrainer")
trainer = SFTTrainer(
model=model_obj,
train_dataset=train_dataset,
@@ -420,9 +420,9 @@ class HFFinetuningSingleDevice:
try:
# Train
logger.info("Starting training")
log.info("Starting training")
trainer.train()
logger.info("Training completed successfully")
log.info("Training completed successfully")
# Save final model if output directory is provided
if output_dir_path:
@@ -430,12 +430,12 @@ class HFFinetuningSingleDevice:
finally:
# Clean up resources
logger.info("Cleaning up resources")
log.info("Cleaning up resources")
if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type)
del trainer
gc.collect()
logger.info("Cleanup completed")
log.info("Cleanup completed")
async def train(
self,
@ -449,7 +449,7 @@ class HFFinetuningSingleDevice:
"""Train a model using HuggingFace's SFTTrainer"""
# Initialize and validate device
device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'")
log.info(f"Using device '{device}'")
output_dir_path = None
if output_dir:
@@ -479,7 +479,7 @@ class HFFinetuningSingleDevice:
raise ValueError("DataConfig is required for training")
# Train in a separate process
logger.info("Starting training in separate process")
log.info("Starting training in separate process")
try:
# Setup multiprocessing for device
if device.type in ["cuda", "mps"]:

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import gc
import logging
import multiprocessing
from pathlib import Path
from typing import Any
@@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
DPOAlignmentConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
@@ -40,7 +40,7 @@ from ..utils import (
split_dataset,
)
logger = logging.getLogger(__name__)
logger = get_logger(__name__, category="core")
class HFDPOAlignmentSingleDevice:

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import signal
import sys
@@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from llama_stack.log import get_logger
from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
logger = get_logger(__name__, category="core")
def setup_environment():

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import time
from datetime import UTC, datetime
@@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
@@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
)
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils
@@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
log = logging.getLogger(__name__)
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
log = get_logger(name=__name__, category="core")
class LoraFinetuningSingleDevice:

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
from llama_stack.apis.inference import Message
@@ -15,13 +14,14 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import CodeScannerConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="safety")
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import torch
@@ -19,6 +18,7 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
@@ -26,10 +26,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import PromptGuardConfig, PromptGuardType
log = logging.getLogger(__name__)
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
log = get_logger(name=__name__, category="safety")
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: PromptGuardConfig, _deps) -> None:

View file

@@ -7,7 +7,6 @@
import collections
import functools
import json
import logging
import random
import re
import string
@@ -20,7 +19,9 @@ import nltk
from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
from pythainlp.tokenize import word_tokenize as word_tokenize_thai
logger = logging.getLogger()
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
WORD_LIST = [
"western",
@@ -1726,7 +1727,7 @@ def get_langid(text: str, lid_path: str | None = None) -> str:
try:
line_langs.append(langdetect.detect(line))
except langdetect.LangDetectException as e:
logger.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037
log.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037
if len(line_langs) == 0:
return "en"
@@ -1885,7 +1886,7 @@ class ResponseLanguageChecker(Instruction):
return langdetect.detect(value) == self._language
except langdetect.LangDetectException as e:
# Count as instruction is followed.
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
return True
@@ -3110,7 +3111,7 @@ class CapitalLettersEnglishChecker(Instruction):
return value.isupper() and langdetect.detect(value) == "en"
except langdetect.LangDetectException as e:
# Count as instruction is followed.
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
return True
@@ -3139,7 +3140,7 @@ class LowercaseLettersEnglishChecker(Instruction):
return value.islower() and langdetect.detect(value) == "en"
except langdetect.LangDetectException as e:
# Count as instruction is followed.
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
return True

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import secrets
import string
from typing import Any
@@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
@@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="tools")
def make_random_string(length: int = 8):

View file

@@ -8,7 +8,6 @@ import asyncio
import base64
import io
import json
import logging
from typing import Any
import faiss
@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
@@ -39,7 +39,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import FaissVectorIOConfig
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
@@ -83,7 +83,7 @@ class FaissIndex(EmbeddingIndex):
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
except Exception as e:
logger.debug(e, exc_info=True)
log.debug(e, exc_info=True)
raise ValueError(
"Error deserializing Faiss index from storage. If you recently upgraded your Llama Stack, Faiss, "
"or NumPy versions, you may need to delete the index and re-create it again or downgrade versions.\n"
@@ -262,7 +262,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
assert self.kvstore is not None
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
log.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import re
import sqlite3
import struct
@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@@ -35,7 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
VectorDBWithIndex,
)
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"
@@ -257,7 +257,7 @@ class SQLiteVecIndex(EmbeddingIndex):
except sqlite3.Error as e:
connection.rollback()
logger.error(f"Error inserting into {self.vector_table}: {e}")
log.error(f"Error inserting into {self.vector_table}: {e}")
raise
finally:
@@ -306,7 +306,7 @@ class SQLiteVecIndex(EmbeddingIndex):
try:
chunk = Chunk.model_validate_json(chunk_json)
except Exception as e:
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
log.error(f"Error parsing chunk JSON for id {_id}: {e}")
continue
chunks.append(chunk)
scores.append(score)
@@ -352,7 +352,7 @@ class SQLiteVecIndex(EmbeddingIndex):
try:
chunk = Chunk.model_validate_json(chunk_json)
except Exception as e:
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
log.error(f"Error parsing chunk JSON for id {_id}: {e}")
continue
chunks.append(chunk)
scores.append(score)
@@ -447,7 +447,7 @@ class SQLiteVecIndex(EmbeddingIndex):
connection.commit()
except Exception as e:
connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}")
log.error(f"Error deleting chunk {chunk_id}: {e}")
raise
finally:
cur.close()
@@ -530,7 +530,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
log.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]

View file

@@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
@@ -256,7 +256,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
"stream": bool(request.stream),
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logger.debug(f"params to fireworks: {params}")
log.debug(f"params to fireworks: {params}")
return params

View file

@@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -11,8 +10,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from collections.abc import AsyncIterator
@@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
ToolChoice,
ToolConfig,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
@@ -54,7 +54,7 @@ from .openai_utils import (
)
from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@@ -75,7 +75,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
log.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
if _is_nvidia_hosted(config):
if not config.api_key:

View file

@@ -4,13 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import httpx
from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
@@ -44,7 +45,7 @@ async def check_health(config: NVIDIAConfig) -> None:
RuntimeError: If the server is not running or ready
"""
if not _is_nvidia_hosted(config):
logger.info("Checking NVIDIA NIM health...")
log.info("Checking NVIDIA NIM health...")
try:
is_live, is_ready = await _get_health(config.url)
if not is_live:

View file

@@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(
@@ -117,10 +117,10 @@ class OllamaInferenceAdapter(
return self._openai_client
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
log.info(f"checking connectivity to Ollama at `{self.config.url}`...")
health_response = await self.health()
if health_response["status"] == HealthStatus.ERROR:
logger.warning(
log.warning(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
)
@@ -339,7 +339,7 @@ class OllamaInferenceAdapter(
"options": sampling_options,
"stream": request.stream,
}
logger.debug(f"params to ollama: {params}")
log.debug(f"params to ollama: {params}")
return params
@@ -437,7 +437,7 @@ class OllamaInferenceAdapter(
if provider_resource_id not in available_models:
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
if provider_resource_id in available_models_latest:
logger.warning(
log.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
)
return model

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -12,8 +11,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
#
# This OpenAI adapter implements Inference methods using two mixins -

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = logging.getLogger(__name__)
logger = get_logger(__name__, category="core")
def build_hf_repo_model_entries():
@@ -307,7 +307,7 @@ class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
logger.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(
model=config.url,
)

View file

@@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
@@ -232,7 +232,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logger.debug(f"params to together: {params}")
log.debug(f"params to together: {params}")
return params
async def embeddings(

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from typing import Any
@@ -15,8 +14,6 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
from typing import Any
from llama_stack.apis.inference import Message
@@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="safety")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
@@ -76,13 +76,13 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
"""
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
log.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = self.bedrock_runtime_client.apply_guardrail(
guardrailIdentifier=shield.provider_resource_id,

View file

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import requests
@@ -17,8 +16,6 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
from typing import Any
import litellm
@@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="safety")
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
@@ -66,7 +66,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
"guard" not in shield.provider_resource_id.lower()
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
):
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
log.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
async def unregister_shield(self, identifier: str) -> None:
pass
@@ -79,9 +79,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
raise ValueError(f"Shield {shield_id} not found")
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
log.debug(f"run_shield::{shield_params}::messages={messages}")
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = litellm.completion(
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
import logging
from typing import Any
from urllib.parse import urlparse
@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -32,8 +32,6 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = logging.getLogger(__name__)
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
VERSION = "v3"
@ -43,6 +41,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::"
logger = get_logger(__name__, category="core")
# This is a helper that lets us use async and non-async Chroma clients interchangeably
async def maybe_await(result):
@ -92,7 +92,7 @@ class ChromaIndex(EmbeddingIndex):
doc = json.loads(doc)
chunk = Chunk(**doc)
except Exception:
log.exception(f"Failed to parse document: {doc}")
logger.exception(f"Failed to parse document: {doc}")
continue
score = 1.0 / float(dist) if dist != 0 else float("inf")
@ -137,7 +137,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
inference_api: Api.inference,
files_api: Files | None,
) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.client = None
@ -150,7 +150,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.vector_db_store = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
logger.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)
@ -159,7 +159,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
self.openai_vector_stores = await self._load_openai_vector_stores()
@ -182,7 +182,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found")
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import os
from typing import Any
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
@ -68,7 +68,7 @@ class MilvusIndex(EmbeddingIndex):
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
log.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
@ -147,7 +147,7 @@ class MilvusIndex(EmbeddingIndex):
data=data,
)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
log.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
@ -203,7 +203,7 @@ class MilvusIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
except Exception as e:
logger.error(f"Error performing BM25 search: {e}")
log.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
@ -247,7 +247,7 @@ class MilvusIndex(EmbeddingIndex):
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
)
except Exception as e:
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
log.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
raise
@ -288,10 +288,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
)
self.cache[vector_db.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
log.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
log.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import psycopg2
@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -33,8 +33,6 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorVectorIOConfig
log = logging.getLogger(__name__)
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
logger = get_logger(__name__, category="core")
def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@ -187,7 +187,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.metadatadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
logger.info(f"Initializing PGVector memory adapter with config: {self.config}")
self.kvstore = await kvstore_impl(self.config.kvstore)
await self.initialize_openai_vector_stores()
@ -203,7 +203,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
version = check_extension_version(cur)
if version:
log.info(f"Vector extension version: {version}")
logger.info(f"Vector extension version: {version}")
else:
raise RuntimeError("Vector extension is not installed.")
@ -216,13 +216,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
"""
)
except Exception as e:
log.exception("Could not connect to PGVector database server")
logger.exception("Could not connect to PGVector database server")
raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None:
if self.conn is not None:
self.conn.close()
log.info("Connection to PGVector database server closed")
logger.info("Connection to PGVector database server closed")
async def register_vector_db(self, vector_db: VectorDB) -> None:
# Persist vector DB metadata in the KV store

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import uuid
from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileObject,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
@ -35,13 +35,14 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
logger = get_logger(__name__, category="core")
def convert_id(_id: str) -> str:
"""
@ -96,7 +97,7 @@ class QdrantIndex(EmbeddingIndex):
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
)
except Exception as e:
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
logger.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
raise
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
@ -118,7 +119,7 @@ class QdrantIndex(EmbeddingIndex):
try:
chunk = Chunk(**point.payload["chunk_content"])
except Exception:
log.exception("Failed to parse chunk")
logger.exception("Failed to parse chunk")
continue
chunks.append(chunk)

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any
import weaviate
@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -33,8 +33,6 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig
log = logging.getLogger(__name__)
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
logger = get_logger(__name__, category="core")
class WeaviateIndex(EmbeddingIndex):
def __init__(
@ -102,7 +102,7 @@ class WeaviateIndex(EmbeddingIndex):
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
log.exception(f"Failed to parse document: {chunk_json}")
logger.exception(f"Failed to parse document: {chunk_json}")
continue
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
@ -171,7 +171,7 @@ class WeaviateVectorIOAdapter(
def _get_client(self) -> weaviate.Client:
if "localhost" in self.config.weaviate_cluster_url:
log.info("using Weaviate locally in container")
logger.info("using Weaviate locally in container")
host, port = self.config.weaviate_cluster_url.split(":")
key = "local_test"
client = weaviate.connect_to_local(
@ -179,7 +179,7 @@ class WeaviateVectorIOAdapter(
port=port,
)
else:
log.info("Using Weaviate remote cluster with URL")
logger.info("Using Weaviate remote cluster with URL")
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
if key in self.client_cache:
return self.client_cache[key]
@ -197,7 +197,7 @@ class WeaviateVectorIOAdapter(
self.kvstore = await kvstore_impl(self.config.kvstore)
else:
self.kvstore = None
log.info("No kvstore configured, registry will not persist across restarts")
logger.info("No kvstore configured, registry will not persist across restarts")
# Load existing vector DB definitions
if self.kvstore is not None:
@ -254,7 +254,7 @@ class WeaviateVectorIOAdapter(
client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
log.warning(f"Vector DB {sanitized_collection_name} not found")
logger.warning(f"Vector DB {sanitized_collection_name} not found")
return
client.collections.delete(sanitized_collection_name)
await self.cache[sanitized_collection_name].index.delete()

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import base64
import logging
import struct
from typing import TYPE_CHECKING
@ -27,7 +26,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
EMBEDDING_MODELS = {}
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="inference")
class SentenceTransformerEmbeddingMixin:

View file

@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
@ -157,7 +157,7 @@ class LiteLLMOpenAIMixin(
params = await self._get_params(request)
params["model"] = self.get_litellm_model_name(params["model"])
logger.debug(f"params to litellm (openai compat): {params}")
log.debug(f"params to litellm (openai compat): {params}")
# see https://docs.litellm.ai/docs/completion/stream#async-completion
response = await litellm.acompletion(**params)
if stream:
@ -460,7 +460,7 @@ class LiteLLMOpenAIMixin(
:return: True if the model is available dynamically, False otherwise.
"""
if self.litellm_provider_name not in litellm.models_by_provider:
logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
log.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
return False
return model in litellm.models_by_provider[self.litellm_provider_name]

View file

@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class RemoteInferenceProviderConfig(BaseModel):
@ -135,7 +135,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
logger.info(
log.info(
f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
)
return False

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import base64
import json
import logging
import struct
import time
import uuid
@ -116,6 +115,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.inference import (
OpenAIChoice as OpenAIChatCompletionChoice,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
@ -128,7 +128,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
decode_assistant_message,
)
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class OpenAICompatCompletionChoiceDelta(BaseModel):
@ -316,7 +316,7 @@ def process_chat_completion_response(
if t.tool_name in request_tools:
new_tool_calls.append(t)
else:
logger.warning(f"Tool {t.tool_name} not found in request tools")
log.warning(f"Tool {t.tool_name} not found in request tools")
if len(new_tool_calls) < len(raw_message.tool_calls):
raw_message.tool_calls = new_tool_calls
@ -477,7 +477,7 @@ async def process_chat_completion_stream_response(
)
)
else:
logger.warning(f"Tool {tool_call.tool_name} not found in request tools")
log.warning(f"Tool {tool_call.tool_name} not found in request tools")
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@ -1198,7 +1198,7 @@ async def convert_openai_chat_completion_stream(
)
for idx, buffer in tool_call_idx_to_buffer.items():
logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
log.debug(f"toolcall_buffer[{idx}]: {buffer}")
if buffer["name"]:
delta = ")"
buffer["content"] += delta

View file

@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class OpenAIMixin(ABC):
@ -125,9 +125,9 @@ class OpenAIMixin(ABC):
Direct OpenAI completion API call.
"""
if guided_choice is not None:
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
log.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
log.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# TODO: fix openai_completion to return a type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
@ -267,6 +267,6 @@ class OpenAIMixin(ABC):
pass
except Exception as e:
# All other errors (auth, rate limit, network, etc.)
logger.warning(f"Failed to check model availability for {model}: {e}")
log.warning(f"Failed to check model availability for {model}: {e}")
return False

View file

@ -4,16 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from datetime import datetime
from pymongo import AsyncMongoClient
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
class MongoDBKVStoreImpl(KVStore):

View file

@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from datetime import datetime
import psycopg2
from psycopg2.extras import DictCursor
from llama_stack.log import get_logger
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
class PostgresKVStoreImpl(KVStore):

View file

@ -6,7 +6,6 @@
import asyncio
import json
import logging
import mimetypes
import time
import uuid
@ -37,10 +36,11 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
@ -378,7 +378,7 @@ class OpenAIVectorStoreMixin(ABC):
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
log.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
@ -460,7 +460,7 @@ class OpenAIVectorStoreMixin(ABC):
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
log.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponsePage(
search_query=search_query,
@ -614,7 +614,7 @@ class OpenAIVectorStoreMixin(ABC):
)
vector_store_file_object.status = "completed"
except Exception as e:
logger.error(f"Error attaching file to vector store: {e}")
log.error(f"Error attaching file to vector store: {e}")
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error",

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import base64
import io
import logging
import re
import time
from abc import ABC, abstractmethod
@ -25,6 +24,7 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -32,12 +32,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = logging.getLogger(__name__)
# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
log = get_logger(name=__name__, category="memory")
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
log = get_logger(name=__name__, category="scheduler")
# TODO: revisit the list of possible statuses when defining a more coherent
@ -186,7 +186,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
log.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
@ -222,7 +222,7 @@ class Scheduler:
msg = (datetime.now(UTC), message)
# At least for the time being, until there's a better way to expose
# logs to users, log messages on console
logger.info(f"Job {job.id}: {message}")
log.info(f"Job {job.id}: {message}")
job.append_log(msg)
self._backend.on_log_message_cb(job, msg)

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="authorized_sqlstore")
log = get_logger(name=__name__, category="authorized_sqlstore")
# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly
@ -81,7 +81,7 @@ class AuthorizedSqlStore:
actual_default = default_policy()
if SQL_OPTIMIZED_POLICY != actual_default:
logger.warning(
log.warning(
f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
f"SQL filtering will use conservative mode. "
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",

View file

@ -29,7 +29,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig
logger = get_logger(name=__name__, category="sqlstore")
log = get_logger(name=__name__, category="sqlstore")
TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer,
@ -280,5 +280,5 @@ class SqlAlchemySqlStoreImpl(SqlStore):
except Exception as e:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
log.error(f"Error adding column {column_name} to table {table}: {e}")
pass

View file

@ -6,7 +6,7 @@
import asyncio
import contextvars
import logging
import logging # allow-direct-logging
import queue
import random
import threading

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import sys
import time
import uuid
@ -19,10 +18,9 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
skip_because_resource_intensive = pytest.mark.skip(
@ -71,14 +69,14 @@ class TestPostTraining:
)
@pytest.mark.timeout(360) # 6 minutes timeout
def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
logger.info("Starting supervised fine-tuning test")
log.info("Starting supervised fine-tuning test")
# register dataset to train
dataset = llama_stack_client.datasets.register(
purpose=purpose,
source=source,
)
logger.info(f"Registered dataset with ID: {dataset.identifier}")
log.info(f"Registered dataset with ID: {dataset.identifier}")
algorithm_config = LoraFinetuningConfig(
type="LoRA",
@ -105,7 +103,7 @@ class TestPostTraining:
)
job_uuid = f"test-job{uuid.uuid4()}"
logger.info(f"Starting training job with UUID: {job_uuid}")
log.info(f"Starting training job with UUID: {job_uuid}")
# train with HF trl SFTTrainer as the default
_ = llama_stack_client.post_training.supervised_fine_tune(
@ -121,21 +119,21 @@ class TestPostTraining:
while True:
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
if not status:
logger.error("Job not found")
log.error("Job not found")
break
logger.info(f"Current status: {status}")
log.info(f"Current status: {status}")
assert status.status in ["scheduled", "in_progress", "completed"]
if status.status == "completed":
break
logger.info("Waiting for job to complete...")
log.info("Waiting for job to complete...")
time.sleep(10) # Increased sleep time to reduce polling frequency
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
logger.info(f"Job artifacts: {artifacts}")
log.info(f"Job artifacts: {artifacts}")
logger.info(f"Registered dataset with ID: {dataset.identifier}")
log.info(f"Registered dataset with ID: {dataset.identifier}")
# TODO: Fix these tests to properly represent the Jobs API in training
#
@ -181,17 +179,21 @@ class TestPostTraining:
)
@pytest.mark.timeout(360)
def test_preference_optimize(self, llama_stack_client, purpose, source):
logger.info("Starting DPO preference optimization test")
log.info("Starting DPO preference optimization test")
# register preference dataset to train
dataset = llama_stack_client.datasets.register(
purpose=purpose,
source=source,
)
logger.info(f"Registered preference dataset with ID: {dataset.identifier}")
log.info(f"Registered preference dataset with ID: {dataset.identifier}")
# DPO algorithm configuration
algorithm_config = DPOAlignmentConfig(
reward_scale=1.0,
reward_clip=10.0,
epsilon=1e-8,
gamma=0.99,
beta=0.1,
loss_type=DPOLossType.sigmoid, # Default loss type
)
@ -211,7 +213,7 @@ class TestPostTraining:
)
job_uuid = f"test-dpo-job-{uuid.uuid4()}"
logger.info(f"Starting DPO training job with UUID: {job_uuid}")
log.info(f"Starting DPO training job with UUID: {job_uuid}")
# train with HuggingFace DPO implementation
_ = llama_stack_client.post_training.preference_optimize(
@ -226,15 +228,15 @@ class TestPostTraining:
while True:
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
if not status:
logger.error("DPO job not found")
log.error("DPO job not found")
break
logger.info(f"Current DPO status: {status}")
log.info(f"Current DPO status: {status}")
if status.status == "completed":
break
logger.info("Waiting for DPO job to complete...")
log.info("Waiting for DPO job to complete...")
time.sleep(10) # Increased sleep time to reduce polling frequency
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
logger.info(f"DPO job artifacts: {artifacts}")
log.info(f"DPO job artifacts: {artifacts}")

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import time
from io import BytesIO
@ -13,8 +12,10 @@ from llama_stack_client import BadRequestError, LlamaStackClient
from openai import BadRequestError as OpenAIBadRequestError
from llama_stack.apis.vector_io import Chunk
from llama_stack.core.library_client import LlamaStackAsLibraryClient
from llama_stack.log import get_logger
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector-io")
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
@ -99,7 +100,7 @@ def compat_client_with_empty_stores(compat_client):
compat_client.vector_stores.delete(vector_store_id=store.id)
except Exception:
# If the API is not available or fails, just continue
logger.warning("Failed to clear vector stores")
log.warning("Failed to clear vector stores")
pass
def clear_files():
@ -109,7 +110,7 @@ def compat_client_with_empty_stores(compat_client):
compat_client.files.delete(file_id=file.id)
except Exception:
# If the API is not available or fails, just continue
logger.warning("Failed to clear files")
log.warning("Failed to clear files")
pass
clear_vector_stores()

View file

@ -6,7 +6,7 @@
import asyncio
import json
import logging
import logging # allow-direct-logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer