diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 30843173c..0bad75523 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -145,6 +145,24 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^.github/workflows/.*$
+      - id: check-log-usage
+        name: Check for proper log usage (use llama_stack.log instead)
+        entry: bash
+        language: system
+        types: [python]
+        pass_filenames: true
+        args:
+          - -c
+          - |
+            matches=$(grep -EnH '^[^#]*\b(import logging|from logging\b)' "$0" "$@" | grep -v '# allow-direct-logging' || true)
+            if [ -n "$matches" ]; then
+              # GitHub Actions annotation format
+              while IFS=: read -r file line_num rest; do
+                echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging ...' in $file. Use the project logger instead: from llama_stack.log import get_logger; log = get_logger(name=__name__, category=...). If direct use of the logging module is truly needed, add a trailing '# allow-direct-logging' comment."
+              done <<< "$matches"
+              exit 1
+            fi
+            exit 0
 
 ci:
     autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py
index c8ffce034..192a85609 100644
--- a/llama_stack/cli/stack/run.py
+++ b/llama_stack/cli/stack/run.py
@@ -15,7 +15,7 @@ from llama_stack.log import get_logger
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
 
-logger = get_logger(name=__name__, category="server")
+log = get_logger(name=__name__, category="server")
 
 
 class StackRun(Subcommand):
@@ -126,7 +126,7 @@ class StackRun(Subcommand):
             self.parser.error("Config file is required for venv environment")
 
         if config_file:
-            logger.info(f"Using run configuration: {config_file}")
+            log.info(f"Using run configuration: {config_file}")
 
             try:
                 config_dict = yaml.safe_load(config_file.read_text())
@@ -145,7 +145,7 @@ class StackRun(Subcommand):
         # If neither image type nor image name is provided, assume the server should be run directly
         # using the current environment packages.
         if not image_type and not image_name:
-            logger.info("No image type or image name provided. Assuming environment packages.")
+            log.info("No image type or image name provided. Assuming environment packages.")
             from llama_stack.core.server.server import main as server_main
 
             # Build the server args from the current args passed to the CLI
@@ -185,11 +185,11 @@ class StackRun(Subcommand):
             run_command(run_args)
 
     def _start_ui_development_server(self, stack_server_port: int):
-        logger.info("Attempting to start UI development server...")
+        log.info("Attempting to start UI development server...")
         # Check if npm is available
         npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
         if npm_check.returncode != 0:
-            logger.warning(
+            log.warning(
                 f"'npm' command not found or not executable. UI development server will not be started. 
Error: {npm_check.stderr}" ) return @@ -214,13 +214,13 @@ class StackRun(Subcommand): stderr=stderr_log_file, env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"}, ) - logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.") - logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}") - logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}") + log.info(f"UI development server process started in {ui_dir} with PID {process.pid}.") + log.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}") + log.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}") except FileNotFoundError: - logger.error( + log.error( "Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH." ) except Exception as e: - logger.error(f"Failed to start UI development server in {ui_dir}: {e}") + log.error(f"Failed to start UI development server in {ui_dir}: {e}") diff --git a/llama_stack/cli/utils.py b/llama_stack/cli/utils.py index c9c51d933..f216a800a 100644 --- a/llama_stack/cli/utils.py +++ b/llama_stack/cli/utils.py @@ -8,7 +8,7 @@ import argparse from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="cli") +log = get_logger(name=__name__, category="cli") # TODO: this can probably just be inlined now? diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index b3e35ecef..659bc7a7b 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import importlib.resources -import logging import sys from pydantic import BaseModel @@ -17,10 +16,9 @@ from llama_stack.core.external import load_external_apis from llama_stack.core.utils.exec import run_command from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.distributions.template import DistributionTemplate +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api -log = logging.getLogger(__name__) - # These are the dependencies needed by the distribution server. # `llama-stack` is automatically installed by the installation script. SERVER_DEPENDENCIES = [ @@ -33,6 +31,8 @@ SERVER_DEPENDENCIES = [ "opentelemetry-exporter-otlp-proto-http", ] +log = get_logger(name=__name__, category="core") + class ApiInput(BaseModel): api: Api diff --git a/llama_stack/core/configure.py b/llama_stack/core/configure.py index 9e18b438c..0cf506314 100644 --- a/llama_stack/core/configure.py +++ b/llama_stack/core/configure.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging import textwrap from typing import Any @@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.prompt_for_config import prompt_for_config +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, ProviderSpec -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider: @@ -49,7 +49,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec is_nux = len(config.providers) == 0 if is_nux: - logger.info( + log.info( textwrap.dedent( """ Llama Stack is composed of several APIs working together. For each API served by the Stack, @@ -75,12 +75,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec existing_providers = config.providers.get(api_str, []) if existing_providers: - logger.info(f"Re-configuring existing providers for API `{api_str}`...") + log.info(f"Re-configuring existing providers for API `{api_str}`...") updated_providers = [] for p in existing_providers: - logger.info(f"> Configuring provider `({p.provider_type})`") + log.info(f"> Configuring provider `({p.provider_type})`") updated_providers.append(configure_single_provider(provider_registry[api], p)) - logger.info("") + log.info("") else: # we are newly configuring this API plist = build_spec.providers.get(api_str, []) @@ -89,17 +89,17 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec if not plist: raise ValueError(f"No provider configured for API {api_str}?") - logger.info(f"Configuring API `{api_str}`...") + log.info(f"Configuring API `{api_str}`...") updated_providers = [] for i, provider in enumerate(plist): if i >= 1: others = ", ".join(p.provider_type for p in plist[i:]) - logger.info( + log.info( f"Not configuring other providers ({others}) interactively. 
Please edit the resulting YAML directly.\n" ) break - logger.info(f"> Configuring provider `({provider.provider_type})`") + log.info(f"> Configuring provider `({provider.provider_type})`") pid = provider.provider_type.split("::")[-1] updated_providers.append( configure_single_provider( @@ -111,7 +111,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec ), ) ) - logger.info("") + log.info("") config.providers[api_str] = updated_providers @@ -169,7 +169,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi return StackRunConfig(**cast_image_name_to_string(processed_config_dict)) if "routing_table" in config_dict: - logger.info("Upgrading config...") + log.info("Upgrading config...") config_dict = upgrade_from_routing_table(config_dict) config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py index 977eb5393..8daf677cd 100644 --- a/llama_stack/core/distribution.py +++ b/llama_stack/core/distribution.py @@ -23,7 +23,7 @@ from llama_stack.providers.datatypes import ( remote_provider_spec, ) -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") def stack_apis() -> list[Api]: @@ -141,18 +141,18 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]: registry: dict[Api, dict[str, ProviderSpec]] = {} for api in providable_apis(): name = api.name.lower() - logger.debug(f"Importing module {name}") + log.debug(f"Importing module {name}") try: module = importlib.import_module(f"llama_stack.providers.registry.{name}") registry[api] = {a.provider_type: a for a in module.available_providers()} except ImportError as e: - logger.warning(f"Failed to import module {name}: {e}") + log.warning(f"Failed to import module {name}: {e}") # Refresh providable APIs with external APIs if any external_apis = load_external_apis(config) for api, api_spec in external_apis.items(): name = api_spec.name.lower() - logger.info(f"Importing external API {name} module {api_spec.module}") + log.info(f"Importing external API {name} module {api_spec.module}") try: module = importlib.import_module(api_spec.module) registry[api] = {a.provider_type: a for a in module.available_providers()} @@ -161,7 +161,7 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]: # This assume that the in-tree provider(s) are not available for this API which means # that users will need to use external providers for this API. registry[api] = {} - logger.error( + log.error( f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n" "Install the API package to load any in-tree providers for this API." ) @@ -183,13 +183,13 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]: def get_external_providers_from_dir( registry: dict[Api, dict[str, ProviderSpec]], config ) -> dict[Api, dict[str, ProviderSpec]]: - logger.warning( + log.warning( "Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead." 
) external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir)) if not os.path.exists(external_providers_dir): raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}") - logger.info(f"Loading external providers from {external_providers_dir}") + log.info(f"Loading external providers from {external_providers_dir}") for api in providable_apis(): api_name = api.name.lower() @@ -198,13 +198,13 @@ def get_external_providers_from_dir( for provider_type in ["remote", "inline"]: api_dir = os.path.join(external_providers_dir, provider_type, api_name) if not os.path.exists(api_dir): - logger.debug(f"No {provider_type} provider directory found for {api_name}") + log.debug(f"No {provider_type} provider directory found for {api_name}") continue # Look for provider spec files in the API directory for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")): provider_name = os.path.splitext(os.path.basename(spec_path))[0] - logger.info(f"Loading {provider_type} provider spec from {spec_path}") + log.info(f"Loading {provider_type} provider spec from {spec_path}") try: with open(spec_path) as f: @@ -217,16 +217,16 @@ def get_external_providers_from_dir( spec = _load_inline_provider_spec(spec_data, api, provider_name) provider_type_key = f"inline::{provider_name}" - logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}") + log.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}") if provider_type_key in registry[api]: - logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") + log.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") registry[api][provider_type_key] = spec - logger.info(f"Successfully loaded external provider {provider_type_key}") + log.info(f"Successfully loaded external provider {provider_type_key}") except yaml.YAMLError as yaml_err: - logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") + log.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") raise yaml_err except Exception as e: - logger.error(f"Failed to load provider spec from {spec_path}: {e}") + log.error(f"Failed to load provider spec from {spec_path}: {e}") raise e return registry @@ -241,7 +241,7 @@ def get_external_providers_from_module( else: provider_list = config.providers.items() if provider_list is None: - logger.warning("Could not get list of providers from config") + log.warning("Could not get list of providers from config") return registry for provider_api, providers in provider_list: for provider in providers: @@ -272,6 +272,6 @@ def get_external_providers_from_module( "get_provider_spec not found. 
If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available" ) from exc except Exception as e: - logger.error(f"Failed to load provider spec from module {provider.module}: {e}") + log.error(f"Failed to load provider spec from module {provider.module}: {e}") raise e return registry diff --git a/llama_stack/core/external.py b/llama_stack/core/external.py index 12e9824ad..9eca1ee16 100644 --- a/llama_stack/core/external.py +++ b/llama_stack/core/external.py @@ -11,7 +11,7 @@ from llama_stack.apis.datatypes import Api, ExternalApiSpec from llama_stack.core.datatypes import BuildConfig, StackRunConfig from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]: @@ -28,10 +28,10 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, external_apis_dir = config.external_apis_dir.expanduser().resolve() if not external_apis_dir.is_dir(): - logger.error(f"External APIs directory is not a directory: {external_apis_dir}") + log.error(f"External APIs directory is not a directory: {external_apis_dir}") return {} - logger.info(f"Loading external APIs from {external_apis_dir}") + log.info(f"Loading external APIs from {external_apis_dir}") external_apis: dict[Api, ExternalApiSpec] = {} # Look for YAML files in the external APIs directory @@ -42,13 +42,13 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, spec = ExternalApiSpec(**spec_data) api = Api.add(spec.name) - logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}") + log.info(f"Loaded external API spec for {spec.name} from {yaml_path}") external_apis[api] = spec except yaml.YAMLError as yaml_err: - logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}") + log.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}") raise except Exception: - logger.exception(f"Failed to load external API spec from {yaml_path}") + log.exception(f"Failed to load external API spec from {yaml_path}") raise return external_apis diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index 5fbbf1aff..e4a771af7 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -7,7 +7,6 @@ import asyncio import inspect import json -import logging import os import sys from concurrent.futures import ThreadPoolExecutor @@ -48,6 +47,7 @@ from llama_stack.core.stack import ( from llama_stack.core.utils.config import redact_sensitive_fields from llama_stack.core.utils.context import preserve_contexts_async_generator from llama_stack.core.utils.exec import in_notebook +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry.tracing import ( CURRENT_TRACE_CONTEXT, end_trace, @@ -55,7 +55,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") T = TypeVar("T") @@ -84,7 +84,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: try: return [convert_to_pydantic(item_type, item) for item in value] except Exception: - logger.error(f"Error converting list {value} into {item_type}") + log.error(f"Error converting list {value} into {item_type}") return value elif origin is dict: @@ -92,7 +92,7 @@ def convert_to_pydantic(annotation: Any, 
value: Any) -> Any: try: return {k: convert_to_pydantic(val_type, v) for k, v in value.items()} except Exception: - logger.error(f"Error converting dict {value} into {val_type}") + log.error(f"Error converting dict {value} into {val_type}") return value try: @@ -108,7 +108,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: return convert_to_pydantic(union_type, value) except Exception: continue - logger.warning( + log.warning( f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", ) raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e @@ -171,13 +171,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient): def _remove_root_logger_handlers(self): """ - Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + Remove all handlers from the root log. Needed to avoid polluting the console with logs. """ + import logging # allow-direct-logging + root_logger = logging.getLogger() for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) - logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + log.info(f"Removed handler {handler.__class__.__name__} from root log") def request(self, *args, **kwargs): loop = self.loop diff --git a/llama_stack/core/providers.py b/llama_stack/core/providers.py index 7095ffd18..84b67d8fc 100644 --- a/llama_stack/core/providers.py +++ b/llama_stack/core/providers.py @@ -16,7 +16,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus from .datatypes import StackRunConfig from .utils.config import redact_sensitive_fields -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class ProviderImplConfig(BaseModel): @@ -38,7 +38,7 @@ class ProviderImpl(Providers): pass async def shutdown(self) -> None: - logger.debug("ProviderImpl.shutdown") + log.debug("ProviderImpl.shutdown") pass async def list_providers(self) -> ListProvidersResponse: diff --git a/llama_stack/core/request_headers.py b/llama_stack/core/request_headers.py index 35ac72775..e8728cc50 100644 --- a/llama_stack/core/request_headers.py +++ b/llama_stack/core/request_headers.py @@ -6,19 +6,19 @@ import contextvars import json -import logging from contextlib import AbstractContextManager from typing import Any from llama_stack.core.datatypes import User +from llama_stack.log import get_logger from .utils.dynamic import instantiate_class_type -log = logging.getLogger(__name__) - # Context variable for request provider data and auth attributes PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) +log = get_logger(name=__name__, category="core") + class RequestProviderDataContext(AbstractContextManager): """Context manager for request provider data""" diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 70c78fb01..c2a540428 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -54,7 +54,7 @@ from llama_stack.providers.datatypes import ( VectorDBsProtocolPrivate, ) -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class InvalidProviderError(Exception): @@ -101,7 +101,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> protocols[api] = api_class except (ImportError, AttributeError): - logger.exception(f"Failed to load external API {api_spec.name}") + log.exception(f"Failed to load external API {api_spec.name}") return protocols @@ -223,7 
+223,7 @@ def validate_and_prepare_providers( specs = {} for provider in providers: if not provider.provider_id or provider.provider_id == "__disabled__": - logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled") + log.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled") continue validate_provider(provider, api, provider_registry) @@ -245,10 +245,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR p = provider_registry[api][provider.provider_type] if p.deprecation_error: - logger.error(p.deprecation_error) + log.error(p.deprecation_error) raise InvalidProviderError(p.deprecation_error) elif p.deprecation_warning: - logger.warning( + log.warning( f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", ) @@ -261,9 +261,9 @@ def sort_providers_by_deps( {k: list(v.values()) for k, v in providers_with_specs.items()} ) - logger.debug(f"Resolved {len(sorted_providers)} providers") + log.debug(f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: - logger.debug(f" {api_str} => {provider.provider_id}") + log.debug(f" {api_str} => {provider.provider_id}") return sorted_providers @@ -348,7 +348,7 @@ async def instantiate_provider( if not hasattr(provider_spec, "module") or provider_spec.module is None: raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") - logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}") + log.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}") module = importlib.import_module(provider_spec.module) args = [] if isinstance(provider_spec, RemoteProviderSpec): @@ -418,7 +418,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: obj_params = set(obj_sig.parameters) obj_params.discard("self") if not (proto_params <= obj_params): - logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") + log.error(f"Method {name} incompatible proto: {proto_params} vs. 
obj: {obj_params}") missing_methods.append((name, "signature_mismatch")) else: # Check if the method has a concrete implementation (not just a protocol stub) diff --git a/llama_stack/core/routers/datasets.py b/llama_stack/core/routers/datasets.py index d7984f729..2e02a9d7a 100644 --- a/llama_stack/core/routers/datasets.py +++ b/llama_stack/core/routers/datasets.py @@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class DatasetIORouter(DatasetIO): @@ -20,15 +20,15 @@ class DatasetIORouter(DatasetIO): self, routing_table: RoutingTable, ) -> None: - logger.debug("Initializing DatasetIORouter") + log.debug("Initializing DatasetIORouter") self.routing_table = routing_table async def initialize(self) -> None: - logger.debug("DatasetIORouter.initialize") + log.debug("DatasetIORouter.initialize") pass async def shutdown(self) -> None: - logger.debug("DatasetIORouter.shutdown") + log.debug("DatasetIORouter.shutdown") pass async def register_dataset( @@ -38,7 +38,7 @@ class DatasetIORouter(DatasetIO): metadata: dict[str, Any] | None = None, dataset_id: str | None = None, ) -> None: - logger.debug( + log.debug( f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", ) await self.routing_table.register_dataset( @@ -54,7 +54,7 @@ class DatasetIORouter(DatasetIO): start_index: int | None = None, limit: int | None = None, ) -> PaginatedResponse: - logger.debug( + log.debug( f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", ) provider = await self.routing_table.get_provider_impl(dataset_id) @@ -65,7 +65,7 @@ class DatasetIORouter(DatasetIO): ) async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: - logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") + log.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") provider = await self.routing_table.get_provider_impl(dataset_id) return await provider.append_rows( dataset_id=dataset_id, diff --git a/llama_stack/core/routers/eval_scoring.py b/llama_stack/core/routers/eval_scoring.py index f7a17eecf..c68ba5504 100644 --- a/llama_stack/core/routers/eval_scoring.py +++ b/llama_stack/core/routers/eval_scoring.py @@ -16,7 +16,7 @@ from llama_stack.apis.scoring import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class ScoringRouter(Scoring): @@ -24,15 +24,15 @@ class ScoringRouter(Scoring): self, routing_table: RoutingTable, ) -> None: - logger.debug("Initializing ScoringRouter") + log.debug("Initializing ScoringRouter") self.routing_table = routing_table async def initialize(self) -> None: - logger.debug("ScoringRouter.initialize") + log.debug("ScoringRouter.initialize") pass async def shutdown(self) -> None: - logger.debug("ScoringRouter.shutdown") + log.debug("ScoringRouter.shutdown") pass async def score_batch( @@ -41,7 +41,7 @@ class ScoringRouter(Scoring): scoring_functions: dict[str, ScoringFnParams | None] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: - logger.debug(f"ScoringRouter.score_batch: {dataset_id}") + log.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): provider = await 
self.routing_table.get_provider_impl(fn_identifier) @@ -63,7 +63,7 @@ class ScoringRouter(Scoring): input_rows: list[dict[str, Any]], scoring_functions: dict[str, ScoringFnParams | None] = None, ) -> ScoreResponse: - logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") + log.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): @@ -82,15 +82,15 @@ class EvalRouter(Eval): self, routing_table: RoutingTable, ) -> None: - logger.debug("Initializing EvalRouter") + log.debug("Initializing EvalRouter") self.routing_table = routing_table async def initialize(self) -> None: - logger.debug("EvalRouter.initialize") + log.debug("EvalRouter.initialize") pass async def shutdown(self) -> None: - logger.debug("EvalRouter.shutdown") + log.debug("EvalRouter.shutdown") pass async def run_eval( @@ -98,7 +98,7 @@ class EvalRouter(Eval): benchmark_id: str, benchmark_config: BenchmarkConfig, ) -> Job: - logger.debug(f"EvalRouter.run_eval: {benchmark_id}") + log.debug(f"EvalRouter.run_eval: {benchmark_id}") provider = await self.routing_table.get_provider_impl(benchmark_id) return await provider.run_eval( benchmark_id=benchmark_id, @@ -112,7 +112,7 @@ class EvalRouter(Eval): scoring_functions: list[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") + log.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") provider = await self.routing_table.get_provider_impl(benchmark_id) return await provider.evaluate_rows( benchmark_id=benchmark_id, @@ -126,7 +126,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> Job: - logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") + log.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") provider = await self.routing_table.get_provider_impl(benchmark_id) return await provider.job_status(benchmark_id, job_id) @@ -135,7 +135,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> None: - logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") + log.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") provider = await self.routing_table.get_provider_impl(benchmark_id) await provider.job_cancel( benchmark_id, @@ -147,7 +147,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> EvaluateResponse: - logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") + log.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") provider = await self.routing_table.get_provider_impl(benchmark_id) return await provider.job_result( benchmark_id, diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 6152acd57..45768423a 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion from llama_stack.providers.utils.telemetry.tracing import get_current_span -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class InferenceRouter(Inference): @@ -70,7 +70,7 @@ class InferenceRouter(Inference): telemetry: Telemetry | None = None, store: InferenceStore | None = None, ) -> None: - logger.debug("Initializing 
InferenceRouter") + log.debug("Initializing InferenceRouter") self.routing_table = routing_table self.telemetry = telemetry self.store = store @@ -79,10 +79,10 @@ class InferenceRouter(Inference): self.formatter = ChatFormat(self.tokenizer) async def initialize(self) -> None: - logger.debug("InferenceRouter.initialize") + log.debug("InferenceRouter.initialize") async def shutdown(self) -> None: - logger.debug("InferenceRouter.shutdown") + log.debug("InferenceRouter.shutdown") async def register_model( self, @@ -92,7 +92,7 @@ class InferenceRouter(Inference): metadata: dict[str, Any] | None = None, model_type: ModelType | None = None, ) -> None: - logger.debug( + log.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) @@ -117,7 +117,7 @@ class InferenceRouter(Inference): """ span = get_current_span() if span is None: - logger.warning("No span found for token usage metrics") + log.warning("No span found for token usage metrics") return [] metrics = [ ("prompt_tokens", prompt_tokens), @@ -182,7 +182,7 @@ class InferenceRouter(Inference): logprobs: LogProbConfig | None = None, tool_config: ToolConfig | None = None, ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: - logger.debug( + log.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) if sampling_params is None: @@ -288,7 +288,7 @@ class InferenceRouter(Inference): response_format: ResponseFormat | None = None, logprobs: LogProbConfig | None = None, ) -> BatchChatCompletionResponse: - logger.debug( + log.debug( f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", ) provider = await self.routing_table.get_provider_impl(model_id) @@ -313,7 +313,7 @@ class InferenceRouter(Inference): ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() - logger.debug( + log.debug( f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}", ) model = await self.routing_table.get_model(model_id) @@ -374,7 +374,7 @@ class InferenceRouter(Inference): response_format: ResponseFormat | None = None, logprobs: LogProbConfig | None = None, ) -> BatchCompletionResponse: - logger.debug( + log.debug( f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", ) provider = await self.routing_table.get_provider_impl(model_id) @@ -388,7 +388,7 @@ class InferenceRouter(Inference): output_dimension: int | None = None, task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: - logger.debug(f"InferenceRouter.embeddings: {model_id}") + log.debug(f"InferenceRouter.embeddings: {model_id}") model = await self.routing_table.get_model(model_id) if model is None: raise ModelNotFoundError(model_id) @@ -426,7 +426,7 @@ class InferenceRouter(Inference): prompt_logprobs: int | None = None, suffix: str | None = None, ) -> OpenAICompletion: - logger.debug( + log.debug( f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", ) model_obj = await self.routing_table.get_model(model) @@ -487,7 +487,7 @@ class InferenceRouter(Inference): top_p: float | None = None, user: str | None = None, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - logger.debug( + 
log.debug( f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", ) model_obj = await self.routing_table.get_model(model) @@ -558,7 +558,7 @@ class InferenceRouter(Inference): dimensions: int | None = None, user: str | None = None, ) -> OpenAIEmbeddingsResponse: - logger.debug( + log.debug( f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}", ) model_obj = await self.routing_table.get_model(model) diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index f4273c7b5..6b85c1ae2 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -14,7 +14,7 @@ from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class SafetyRouter(Safety): @@ -22,15 +22,15 @@ class SafetyRouter(Safety): self, routing_table: RoutingTable, ) -> None: - logger.debug("Initializing SafetyRouter") + log.debug("Initializing SafetyRouter") self.routing_table = routing_table async def initialize(self) -> None: - logger.debug("SafetyRouter.initialize") + log.debug("SafetyRouter.initialize") pass async def shutdown(self) -> None: - logger.debug("SafetyRouter.shutdown") + log.debug("SafetyRouter.shutdown") pass async def register_shield( @@ -40,7 +40,7 @@ class SafetyRouter(Safety): provider_id: str | None = None, params: dict[str, Any] | None = None, ) -> Shield: - logger.debug(f"SafetyRouter.register_shield: {shield_id}") + log.debug(f"SafetyRouter.register_shield: {shield_id}") return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) async def unregister_shield(self, identifier: str) -> None: @@ -53,7 +53,7 @@ class SafetyRouter(Safety): messages: list[Message], params: dict[str, Any] = None, ) -> RunShieldResponse: - logger.debug(f"SafetyRouter.run_shield: {shield_id}") + log.debug(f"SafetyRouter.run_shield: {shield_id}") provider = await self.routing_table.get_provider_impl(shield_id) return await provider.run_shield( shield_id=shield_id, diff --git a/llama_stack/core/routers/tool_runtime.py b/llama_stack/core/routers/tool_runtime.py index 5a40bc0c5..17de0592f 100644 --- a/llama_stack/core/routers/tool_runtime.py +++ b/llama_stack/core/routers/tool_runtime.py @@ -22,7 +22,7 @@ from llama_stack.log import get_logger from ..routing_tables.toolgroups import ToolGroupsRoutingTable -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class ToolRuntimeRouter(ToolRuntime): @@ -31,7 +31,7 @@ class ToolRuntimeRouter(ToolRuntime): self, routing_table: ToolGroupsRoutingTable, ) -> None: - logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") + log.debug("Initializing ToolRuntimeRouter.RagToolImpl") self.routing_table = routing_table async def query( @@ -40,7 +40,7 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_ids: list[str], query_config: RAGQueryConfig | None = None, ) -> RAGQueryResult: - logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") + log.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") provider = await self.routing_table.get_provider_impl("knowledge_search") return await provider.query(content, vector_db_ids, query_config) @@ -50,7 +50,7 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: - logger.debug( + 
log.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) provider = await self.routing_table.get_provider_impl("insert_into_memory") @@ -60,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime): self, routing_table: ToolGroupsRoutingTable, ) -> None: - logger.debug("Initializing ToolRuntimeRouter") + log.debug("Initializing ToolRuntimeRouter") self.routing_table = routing_table # HACK ALERT this should be in sync with "get_all_api_endpoints()" @@ -69,15 +69,15 @@ class ToolRuntimeRouter(ToolRuntime): setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) async def initialize(self) -> None: - logger.debug("ToolRuntimeRouter.initialize") + log.debug("ToolRuntimeRouter.initialize") pass async def shutdown(self) -> None: - logger.debug("ToolRuntimeRouter.shutdown") + log.debug("ToolRuntimeRouter.shutdown") pass async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: - logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") + log.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") provider = await self.routing_table.get_provider_impl(tool_name) return await provider.invoke_tool( tool_name=tool_name, @@ -87,5 +87,5 @@ class ToolRuntimeRouter(ToolRuntime): async def list_runtime_tools( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolsResponse: - logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") + log.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") return await self.routing_table.list_tools(tool_group_id) diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 3d0996c49..07b3dc84b 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class VectorIORouter(VectorIO): @@ -40,15 +40,15 @@ class VectorIORouter(VectorIO): self, routing_table: RoutingTable, ) -> None: - logger.debug("Initializing VectorIORouter") + log.debug("Initializing VectorIORouter") self.routing_table = routing_table async def initialize(self) -> None: - logger.debug("VectorIORouter.initialize") + log.debug("VectorIORouter.initialize") pass async def shutdown(self) -> None: - logger.debug("VectorIORouter.shutdown") + log.debug("VectorIORouter.shutdown") pass async def _get_first_embedding_model(self) -> tuple[str, int] | None: @@ -70,10 +70,10 @@ class VectorIORouter(VectorIO): raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension") return embedding_models[0].identifier, dimension else: - logger.warning("No embedding models found in the routing table") + log.warning("No embedding models found in the routing table") return None except Exception as e: - logger.error(f"Error getting embedding models: {e}") + log.error(f"Error getting embedding models: {e}") return None async def register_vector_db( @@ -85,7 +85,7 @@ class VectorIORouter(VectorIO): vector_db_name: str | None = None, provider_vector_db_id: str | None = None, ) -> None: - logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") + log.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") await self.routing_table.register_vector_db( 
vector_db_id, embedding_model, @@ -101,7 +101,7 @@ class VectorIORouter(VectorIO): chunks: list[Chunk], ttl_seconds: int | None = None, ) -> None: - logger.debug( + log.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", ) provider = await self.routing_table.get_provider_impl(vector_db_id) @@ -113,7 +113,7 @@ class VectorIORouter(VectorIO): query: InterleavedContent, params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") + log.debug(f"VectorIORouter.query_chunks: {vector_db_id}") provider = await self.routing_table.get_provider_impl(vector_db_id) return await provider.query_chunks(vector_db_id, query, params) @@ -129,7 +129,7 @@ class VectorIORouter(VectorIO): embedding_dimension: int | None = None, provider_id: str | None = None, ) -> VectorStoreObject: - logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}") + log.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}") # If no embedding model is provided, use the first available one if embedding_model is None: @@ -137,7 +137,7 @@ class VectorIORouter(VectorIO): if embedding_model_info is None: raise ValueError("No embedding model provided and no embedding models available in the system") embedding_model, embedding_dimension = embedding_model_info - logger.info(f"No embedding model specified, using first available: {embedding_model}") + log.info(f"No embedding model specified, using first available: {embedding_model}") vector_db_id = f"vs_{uuid.uuid4()}" registered_vector_db = await self.routing_table.register_vector_db( @@ -168,7 +168,7 @@ class VectorIORouter(VectorIO): after: str | None = None, before: str | None = None, ) -> VectorStoreListResponse: - logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}") + log.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}") # Route to default provider for now - could aggregate from all providers in the future # call retrieve on each vector dbs to get list of vector stores vector_dbs = await self.routing_table.get_all_with_type("vector_db") @@ -179,7 +179,7 @@ class VectorIORouter(VectorIO): vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier) all_stores.append(vector_store) except Exception as e: - logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}") + log.error(f"Error retrieving vector store {vector_db.identifier}: {e}") continue # Sort by created_at @@ -215,7 +215,7 @@ class VectorIORouter(VectorIO): self, vector_store_id: str, ) -> VectorStoreObject: - logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") + log.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") return await self.routing_table.openai_retrieve_vector_store(vector_store_id) async def openai_update_vector_store( @@ -225,7 +225,7 @@ class VectorIORouter(VectorIO): expires_after: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: - logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") + log.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") return await self.routing_table.openai_update_vector_store( vector_store_id=vector_store_id, name=name, @@ -237,7 +237,7 @@ class VectorIORouter(VectorIO): self, 
vector_store_id: str, ) -> VectorStoreDeleteResponse: - logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") + log.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") return await self.routing_table.openai_delete_vector_store(vector_store_id) async def openai_search_vector_store( @@ -250,7 +250,7 @@ class VectorIORouter(VectorIO): rewrite_query: bool | None = False, search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: - logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") + log.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") return await self.routing_table.openai_search_vector_store( vector_store_id=vector_store_id, query=query, @@ -268,7 +268,7 @@ class VectorIORouter(VectorIO): attributes: dict[str, Any] | None = None, chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: - logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") + log.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") return await self.routing_table.openai_attach_file_to_vector_store( vector_store_id=vector_store_id, file_id=file_id, @@ -285,7 +285,7 @@ class VectorIORouter(VectorIO): before: str | None = None, filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: - logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") + log.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") return await self.routing_table.openai_list_files_in_vector_store( vector_store_id=vector_store_id, limit=limit, @@ -300,7 +300,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, file_id: str, ) -> VectorStoreFileObject: - logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") + log.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") return await self.routing_table.openai_retrieve_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, @@ -311,7 +311,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, file_id: str, ) -> VectorStoreFileContentsResponse: - logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") + log.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") return await self.routing_table.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, @@ -323,7 +323,7 @@ class VectorIORouter(VectorIO): file_id: str, attributes: dict[str, Any], ) -> VectorStoreFileObject: - logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") + log.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") return await self.routing_table.openai_update_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, @@ -335,7 +335,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, file_id: str, ) -> VectorStoreFileDeleteResponse: - logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") + log.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") return await self.routing_table.openai_delete_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, diff --git a/llama_stack/core/routing_tables/benchmarks.py b/llama_stack/core/routing_tables/benchmarks.py index 
74bee8040..08e0d4aa8 100644 --- a/llama_stack/core/routing_tables/benchmarks.py +++ b/llama_stack/core/routing_tables/benchmarks.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 339ff6da4..cb1f2c9e9 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") def get_impl_api(p: Any) -> Api: @@ -177,7 +177,7 @@ class CommonRoutingTableImpl(RoutingTable): # Check if user has permission to access this object if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()): - logger.debug(f"Access denied to {type} '{identifier}'") + log.debug(f"Access denied to {type} '{identifier}'") return None return obj @@ -205,7 +205,7 @@ class CommonRoutingTableImpl(RoutingTable): raise AccessDeniedError("create", obj, creator) if creator: obj.owner = creator - logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}") + log.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}") registered_obj = await register_object_with_provider(obj, p) # TODO: This needs to be fixed for all APIs once they return the registered object @@ -250,7 +250,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> if model is not None: return model - logger.warning( + log.warning( f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to " "searching in all providers. This is only for backwards compatibility and will stop working " "soon. Migrate your calls to use fully scoped `provider_id/model_id` names." 
diff --git a/llama_stack/core/routing_tables/datasets.py b/llama_stack/core/routing_tables/datasets.py index fc6a75df4..480c58c83 100644 --- a/llama_stack/core/routing_tables/datasets.py +++ b/llama_stack/core/routing_tables/datasets.py @@ -26,7 +26,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index c76619271..fd26afa3c 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class ModelsRoutingTable(CommonRoutingTableImpl, Models): @@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): try: models = await provider.list_models() except Exception as e: - logger.exception(f"Model refresh failed for provider {provider_id}: {e}") + log.exception(f"Model refresh failed for provider {provider_id}: {e}") continue self.listed_providers.add(provider_id) @@ -132,7 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): model_ids[model.provider_resource_id] = model.identifier continue - logger.debug(f"unregistering model {model.identifier}") + log.debug(f"unregistering model {model.identifier}") await self.unregister_object(model) for model in models: @@ -143,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if model.identifier == model.provider_resource_id: model.identifier = f"{provider_id}/{model.provider_resource_id}" - logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") + log.debug(f"registering model {model.identifier} ({model.provider_resource_id})") await self.register_object( ModelWithOwner( identifier=model.identifier, diff --git a/llama_stack/core/routing_tables/scoring_functions.py b/llama_stack/core/routing_tables/scoring_functions.py index 5874ba941..35d3ef91d 100644 --- a/llama_stack/core/routing_tables/scoring_functions.py +++ b/llama_stack/core/routing_tables/scoring_functions.py @@ -19,7 +19,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): diff --git a/llama_stack/core/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py index e08f35bfc..2e505f502 100644 --- a/llama_stack/core/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index e172af991..27f61946d 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +log = 
get_logger(name=__name__, category="core") def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index c81a27a3b..32b83f44e 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -30,7 +30,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): @@ -57,7 +57,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] if len(self.impls_by_provider_id) > 1: - logger.warning( + log.warning( f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." ) else: diff --git a/llama_stack/core/server/auth.py b/llama_stack/core/server/auth.py index e4fb4ff2b..b9e263f4c 100644 --- a/llama_stack/core/server/auth.py +++ b/llama_stack/core/server/auth.py @@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +log = get_logger(name=__name__, category="auth") class AuthenticationMiddleware: @@ -105,13 +105,13 @@ class AuthenticationMiddleware: try: validation_result = await self.auth_provider.validate_token(token, scope) except httpx.TimeoutException: - logger.exception("Authentication request timed out") + log.exception("Authentication request timed out") return await self._send_auth_error(send, "Authentication service timeout") except ValueError as e: - logger.exception("Error during authentication") + log.exception("Error during authentication") return await self._send_auth_error(send, str(e)) except Exception: - logger.exception("Error during authentication") + log.exception("Error during authentication") return await self._send_auth_error(send, "Authentication service error") # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) @@ -122,7 +122,7 @@ class AuthenticationMiddleware: scope["principal"] = validation_result.principal if validation_result.attributes: scope["user_attributes"] = validation_result.attributes - logger.debug( + log.debug( f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes" ) diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 73d5581c2..e82640ca8 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -23,7 +23,7 @@ from llama_stack.core.datatypes import ( ) from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +log = get_logger(name=__name__, category="auth") class AuthResponse(BaseModel): @@ -163,7 +163,7 @@ class OAuth2TokenAuthProvider(AuthProvider): timeout=10.0, # Add a reasonable timeout ) if response.status_code != 200: - logger.warning(f"Token introspection failed with status code: {response.status_code}") + log.warning(f"Token introspection failed with status code: {response.status_code}") raise ValueError(f"Token introspection failed: 
{response.status_code}") fields = response.json() @@ -176,13 +176,13 @@ class OAuth2TokenAuthProvider(AuthProvider): attributes=access_attributes, ) except httpx.TimeoutException: - logger.exception("Token introspection request timed out") + log.exception("Token introspection request timed out") raise except ValueError: # Re-raise ValueError exceptions to preserve their message raise except Exception as e: - logger.exception("Error during token introspection") + log.exception("Error during token introspection") raise ValueError("Token introspection error") from e async def close(self): @@ -273,7 +273,7 @@ class CustomAuthProvider(AuthProvider): timeout=10.0, # Add a reasonable timeout ) if response.status_code != 200: - logger.warning(f"Authentication failed with status code: {response.status_code}") + log.warning(f"Authentication failed with status code: {response.status_code}") raise ValueError(f"Authentication failed: {response.status_code}") # Parse and validate the auth response @@ -282,17 +282,17 @@ class CustomAuthProvider(AuthProvider): auth_response = AuthResponse(**response_data) return User(principal=auth_response.principal, attributes=auth_response.attributes) except Exception as e: - logger.exception("Error parsing authentication response") + log.exception("Error parsing authentication response") raise ValueError("Invalid authentication response format") from e except httpx.TimeoutException: - logger.exception("Authentication request timed out") + log.exception("Authentication request timed out") raise except ValueError: # Re-raise ValueError exceptions to preserve their message raise except Exception as e: - logger.exception("Error during authentication") + log.exception("Error during authentication") raise ValueError("Authentication service error") from e async def close(self): @@ -329,7 +329,7 @@ class GitHubTokenAuthProvider(AuthProvider): try: user_info = await _get_github_user_info(token, self.config.github_api_base_url) except httpx.HTTPStatusError as e: - logger.warning(f"GitHub token validation failed: {e}") + log.warning(f"GitHub token validation failed: {e}") raise ValueError("GitHub token validation failed. Please check your token and try again.") from e principal = user_info["user"]["login"] diff --git a/llama_stack/core/server/quota.py b/llama_stack/core/server/quota.py index 1cb850cde..09d5c21cd 100644 --- a/llama_stack/core/server/quota.py +++ b/llama_stack/core/server/quota.py @@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl -logger = get_logger(name=__name__, category="quota") +log = get_logger(name=__name__, category="quota") class QuotaMiddleware: @@ -46,7 +46,7 @@ class QuotaMiddleware: self.window_seconds = window_seconds if isinstance(self.kv_config, SqliteKVStoreConfig): - logger.warning( + log.warning( "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. 
" f"window_seconds={self.window_seconds}" ) @@ -84,11 +84,11 @@ class QuotaMiddleware: else: await kv.set(key, str(count)) except Exception: - logger.exception("Failed to access KV store for quota") + log.exception("Failed to access KV store for quota") return await self._send_error(send, 500, "Quota service error") if count > limit: - logger.warning( + log.warning( "Quota exceeded for client %s: %d/%d", key_id, count, diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index fe5cc68d7..c337c4ddd 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -9,7 +9,7 @@ import asyncio import functools import inspect import json -import logging +import logging # allow-direct-logging import os import ssl import sys @@ -80,7 +80,7 @@ from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +log = get_logger(name=__name__, category="server") def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -157,9 +157,9 @@ async def shutdown(app): @asynccontextmanager async def lifespan(app: FastAPI): - logger.info("Starting up") + log.info("Starting up") yield - logger.info("Shutting down") + log.info("Shutting down") await shutdown(app) @@ -182,11 +182,11 @@ async def sse_generator(event_gen_coroutine): yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: - logger.info("Generator cancelled") + log.info("Generator cancelled") if event_gen: await event_gen.aclose() except Exception as e: - logger.exception("Error in sse_generator") + log.exception("Error in sse_generator") yield create_sse_event( { "error": { @@ -206,11 +206,11 @@ async def log_request_pre_validation(request: Request): log_output = rich.pretty.pretty_repr(parsed_body) except (json.JSONDecodeError, UnicodeDecodeError): log_output = repr(body_bytes) - logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}") + log.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}") else: - logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.") + log.debug(f"Incoming {request.method} {request.url.path} request with empty body.") except Exception as e: - logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}") + log.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}") def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: @@ -238,10 +238,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: result.url = route return result except Exception as e: - if logger.isEnabledFor(logging.DEBUG): - logger.exception(f"Error executing endpoint {route=} {method=}") + if log.isEnabledFor(logging.DEBUG): + log.exception(f"Error executing endpoint {route=} {method=}") else: - logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}") + log.error(f"Error executing endpoint {route=} {method=}: {str(e)}") raise translate_exception(e) from e sig = inspect.signature(func) @@ -291,7 +291,7 @@ class TracingMiddleware: # Check if the path is a FastAPI built-in path if path.startswith(self.fastapi_paths): # Pass through to FastAPI's built-in handlers - logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") + log.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") return await 
self.app(scope, receive, send) if not hasattr(self, "route_impls"): @@ -303,7 +303,7 @@ class TracingMiddleware: ) except ValueError: # If no matching endpoint is found, pass through to FastAPI - logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") + log.debug(f"No matching route found for path: {path}, falling back to FastAPI") return await self.app(scope, receive, send) trace_attributes = {"__location__": "server", "raw_path": path} @@ -404,15 +404,15 @@ def main(args: argparse.Namespace | None = None): config_contents = yaml.safe_load(fp) if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): logger_config = LoggingConfig(**cfg) - logger = get_logger(name=__name__, category="server", config=logger_config) + log = get_logger(name=__name__, category="server", config=logger_config) if args.env: for env_pair in args.env: try: key, value = validate_env_pair(env_pair) - logger.info(f"Setting CLI environment variable {key} => {value}") + log.info(f"Setting CLI environment variable {key} => {value}") os.environ[key] = value except ValueError as e: - logger.error(f"Error: {str(e)}") + log.error(f"Error: {str(e)}") sys.exit(1) config = replace_env_vars(config_contents) config = StackRunConfig(**cast_image_name_to_string(config)) @@ -438,16 +438,16 @@ def main(args: argparse.Namespace | None = None): impls = loop.run_until_complete(construct_stack(config)) except InvalidProviderError as e: - logger.error(f"Error: {str(e)}") + log.error(f"Error: {str(e)}") sys.exit(1) if config.server.auth: - logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") + log.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls) else: if config.server.quota: quota = config.server.quota - logger.warning( + log.warning( "Configured authenticated_max_requests (%d) but no auth is enabled; " "falling back to anonymous_max_requests (%d) for all the requests", quota.authenticated_max_requests, @@ -455,7 +455,7 @@ def main(args: argparse.Namespace | None = None): ) if config.server.quota: - logger.info("Enabling quota middleware for authenticated and anonymous clients") + log.info("Enabling quota middleware for authenticated and anonymous clients") quota = config.server.quota anonymous_max_requests = quota.anonymous_max_requests @@ -516,7 +516,7 @@ def main(args: argparse.Namespace | None = None): if not available_methods: raise ValueError(f"No methods found for {route.name} on {impl}") method = available_methods[0] - logger.debug(f"{method} {route.path}") + log.debug(f"{method} {route.path}") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") @@ -528,7 +528,7 @@ def main(args: argparse.Namespace | None = None): ) ) - logger.debug(f"serving APIs: {apis_to_serve}") + log.debug(f"serving APIs: {apis_to_serve}") app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) @@ -553,21 +553,21 @@ def main(args: argparse.Namespace | None = None): if config.server.tls_cafile: ssl_config["ssl_ca_certs"] = config.server.tls_cafile ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED - logger.info( + log.info( f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}" ) else: - logger.info(f"HTTPS enabled with 
certificates:\n Key: {keyfile}\n Cert: {certfile}") + log.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") listen_host = config.server.host or ["::", "0.0.0.0"] - logger.info(f"Listening on {listen_host}:{port}") + log.info(f"Listening on {listen_host}:{port}") uvicorn_config = { "app": app, "host": listen_host, "port": port, "lifespan": "on", - "log_level": logger.getEffectiveLevel(), + "log_level": log.getEffectiveLevel(), "log_config": logger_config, } if ssl_config: @@ -586,19 +586,19 @@ def main(args: argparse.Namespace | None = None): try: loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) except (KeyboardInterrupt, SystemExit): - logger.info("Received interrupt signal, shutting down gracefully...") + log.info("Received interrupt signal, shutting down gracefully...") finally: if not loop.is_closed(): - logger.debug("Closing event loop") + log.debug("Closing event loop") loop.close() def _log_run_config(run_config: StackRunConfig): """Logs the run config with redacted fields and disabled providers removed.""" - logger.info("Run configuration:") + log.info("Run configuration:") safe_config = redact_sensitive_fields(run_config.model_dump(mode="json")) clean_config = remove_disabled_providers(safe_config) - logger.info(yaml.dump(clean_config, indent=2)) + log.info(yaml.dump(clean_config, indent=2)) def extract_path_params(route: str) -> list[str]: diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 87a3978c1..89652e353 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -45,7 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class LlamaStack( @@ -105,11 +105,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): method = getattr(impls[api], register_method) for obj in objects: - logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}") + log.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}") # Do not register models on disabled providers if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"): - logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.") + log.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.") continue # we want to maintain the type information in arguments to method. 
@@ -123,7 +123,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): objects_to_process = response.data if hasattr(response, "data") else response for obj in objects_to_process: - logger.debug( + log.debug( f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}", ) @@ -160,7 +160,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any: try: resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id") if resolved_provider_id == "__disabled__": - logger.debug( + log.debug( f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}" ) # Create a copy with resolved provider_id but original config @@ -315,7 +315,7 @@ async def construct_stack( TEST_RECORDING_CONTEXT = setup_inference_recording() if TEST_RECORDING_CONTEXT: TEST_RECORDING_CONTEXT.__enter__() - logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") + log.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) policy = run_config.server.auth.access_policy if run_config.server.auth else [] @@ -337,12 +337,12 @@ async def construct_stack( import traceback if task.cancelled(): - logger.error("Model refresh task cancelled") + log.error("Model refresh task cancelled") elif task.exception(): - logger.error(f"Model refresh task failed: {task.exception()}") + log.error(f"Model refresh task failed: {task.exception()}") traceback.print_exception(task.exception()) else: - logger.debug("Model refresh task completed") + log.debug("Model refresh task completed") REGISTRY_REFRESH_TASK.add_done_callback(cb) return impls @@ -351,23 +351,23 @@ async def construct_stack( async def shutdown_stack(impls: dict[Api, Any]): for impl in impls.values(): impl_name = impl.__class__.__name__ - logger.info(f"Shutting down {impl_name}") + log.info(f"Shutting down {impl_name}") try: if hasattr(impl, "shutdown"): await asyncio.wait_for(impl.shutdown(), timeout=5) else: - logger.warning(f"No shutdown method for {impl_name}") + log.warning(f"No shutdown method for {impl_name}") except TimeoutError: - logger.exception(f"Shutdown timeout for {impl_name}") + log.exception(f"Shutdown timeout for {impl_name}") except (Exception, asyncio.CancelledError) as e: - logger.exception(f"Failed to shutdown {impl_name}: {e}") + log.exception(f"Failed to shutdown {impl_name}: {e}") global TEST_RECORDING_CONTEXT if TEST_RECORDING_CONTEXT: try: TEST_RECORDING_CONTEXT.__exit__(None, None, None) except Exception as e: - logger.error(f"Error during inference recording cleanup: {e}") + log.error(f"Error during inference recording cleanup: {e}") global REGISTRY_REFRESH_TASK if REGISTRY_REFRESH_TASK: @@ -375,14 +375,14 @@ async def shutdown_stack(impls: dict[Api, Any]): async def refresh_registry_once(impls: dict[Api, Any]): - logger.debug("refreshing registry") + log.debug("refreshing registry") routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] for routing_table in routing_tables: await routing_table.refresh() async def refresh_registry_task(impls: dict[Api, Any]): - logger.info("starting registry refresh task") + log.info("starting registry refresh task") while True: await refresh_registry_once(impls) diff --git a/llama_stack/core/utils/config_resolution.py b/llama_stack/core/utils/config_resolution.py index 30cd71e15..5e7f3d1da 100644 --- 
a/llama_stack/core/utils/config_resolution.py +++ b/llama_stack/core/utils/config_resolution.py @@ -10,7 +10,7 @@ from pathlib import Path from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="config_resolution") +log = get_logger(name=__name__, category="config_resolution") DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" @@ -42,25 +42,25 @@ def resolve_config_or_distro( # Strategy 1: Try as file path first config_path = Path(config_or_distro) if config_path.exists() and config_path.is_file(): - logger.info(f"Using file path: {config_path}") + log.info(f"Using file path: {config_path}") return config_path.resolve() # Strategy 2: Try as distribution name (if no .yaml extension) if not config_or_distro.endswith(".yaml"): distro_config = _get_distro_config_path(config_or_distro, mode) if distro_config.exists(): - logger.info(f"Using distribution: {distro_config}") + log.info(f"Using distribution: {distro_config}") return distro_config # Strategy 3: Try as built distribution name distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml" if distrib_config.exists(): - logger.info(f"Using built distribution: {distrib_config}") + log.info(f"Using built distribution: {distrib_config}") return distrib_config distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml" if distrib_config.exists(): - logger.info(f"Using built distribution: {distrib_config}") + log.info(f"Using built distribution: {distrib_config}") return distrib_config # Strategy 4: Failed - provide helpful error diff --git a/llama_stack/core/utils/exec.py b/llama_stack/core/utils/exec.py index 1b2b782fe..12fb82d01 100644 --- a/llama_stack/core/utils/exec.py +++ b/llama_stack/core/utils/exec.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging +import importlib import os import signal import subprocess @@ -12,9 +12,9 @@ import sys from termcolor import cprint -log = logging.getLogger(__name__) +from llama_stack.log import get_logger -import importlib +log = get_logger(name=__name__, category="core") def formulate_run_args(image_type: str, image_name: str) -> list: diff --git a/llama_stack/core/utils/prompt_for_config.py b/llama_stack/core/utils/prompt_for_config.py index 26f6920e0..bac0531ed 100644 --- a/llama_stack/core/utils/prompt_for_config.py +++ b/llama_stack/core/utils/prompt_for_config.py @@ -6,7 +6,6 @@ import inspect import json -import logging from enum import Enum from typing import Annotated, Any, Literal, Union, get_args, get_origin @@ -14,7 +13,9 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefinedType -log = logging.getLogger(__name__) +from llama_stack.log import get_logger + +log = get_logger(name=__name__, category="core") def is_list_of_primitives(field_type): diff --git a/llama_stack/log.py b/llama_stack/log.py index ab53e08c0..ce3d77f57 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -4,11 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging +import logging # allow-direct-logging import os import re import sys -from logging.config import dictConfig +from logging.config import dictConfig # allow-direct-logging from rich.console import Console from rich.errors import MarkupError diff --git a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py index 5b5969d89..99e2af8c7 100644 --- a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py +++ b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py @@ -13,7 +13,7 @@ # Copyright (c) Meta Platforms, Inc. and its affiliates. import math -from logging import getLogger +from logging import getLogger # allow-direct-logging import torch import torch.nn.functional as F diff --git a/llama_stack/models/llama/llama3/multimodal/image_transform.py b/llama_stack/models/llama/llama3/multimodal/image_transform.py index f2761ee47..db89048dc 100644 --- a/llama_stack/models/llama/llama3/multimodal/image_transform.py +++ b/llama_stack/models/llama/llama3/multimodal/image_transform.py @@ -13,7 +13,7 @@ import math from collections import defaultdict -from logging import getLogger +from logging import getLogger # allow-direct-logging from typing import Any import torch diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py index 5f1c3605c..7817ecd39 100644 --- a/llama_stack/models/llama/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import math from collections.abc import Callable from functools import partial @@ -22,6 +21,8 @@ from PIL import Image as PIL_Image from torch import Tensor, nn from torch.distributed import _functional_collectives as funcol +from llama_stack.log import get_logger + from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis from .encoder_utils import ( build_encoder_attention_mask, @@ -34,9 +35,10 @@ from .encoder_utils import ( from .image_transform import VariableSizeImageTransform from .utils import get_negative_inf_value, to_2tuple -logger = logging.getLogger(__name__) MP_SCALE = 8 +log = get_logger(name=__name__, category="core") + def reduce_from_tensor_model_parallel_region(input_): """All-reduce the input tensor across model parallel group.""" @@ -415,7 +417,7 @@ class VisionEncoder(nn.Module): ) state_dict[prefix + "gated_positional_embedding"] = global_pos_embed state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype) - logger.info(f"Initialized global positional embedding with size {global_pos_embed.size()}") + log.info(f"Initialized global positional embedding with size {global_pos_embed.size()}") else: global_pos_embed = resize_global_position_embedding( state_dict[prefix + "gated_positional_embedding"], @@ -423,7 +425,7 @@ class VisionEncoder(nn.Module): self.max_num_tiles, self.max_num_tiles, ) - logger.info( + log.info( f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}" ) state_dict[prefix + "gated_positional_embedding"] = global_pos_embed @@ -771,7 +773,7 @@ class TilePositionEmbedding(nn.Module): if embed is not None: # reshape the weights to the correct shape nt_old, nt_old, _, w = embed.shape - logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to 
{self.num_tiles}x{self.num_tiles}") + log.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) # assign the weights to the module state_dict[prefix + "embedding"] = embed_new diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py index e47b579e3..623d0b607 100644 --- a/llama_stack/models/llama/llama3/tokenizer.py +++ b/llama_stack/models/llama/llama3/tokenizer.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from collections.abc import Collection, Iterator, Sequence, Set -from logging import getLogger +from logging import getLogger # allow-direct-logging from pathlib import Path from typing import ( Literal, diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 574080184..f62660041 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -11,7 +11,7 @@ from llama_stack.log import get_logger from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat -logger = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference") BUILTIN_TOOL_PATTERN = r'\b(?P\w+)\.call\(query="(?P[^"]*)"\)' CUSTOM_TOOL_CALL_PATTERN = re.compile(r"[^}]+)>(?P{.*?})") @@ -215,7 +215,7 @@ class ToolUtils: # FIXME: Enable multiple tool calls return function_calls[0] else: - logger.debug(f"Did not parse tool call from message body: {message_body}") + log.debug(f"Did not parse tool call from message body: {message_body}") return None @staticmethod diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index 223744a5f..44125f1a4 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging import os from collections.abc import Callable @@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank from torch import Tensor, nn from torch.nn import functional as F +from llama_stack.log import get_logger + from ...datatypes import QuantizationMode from ..model import Transformer, TransformerBlock from ..moe import MoE -log = logging.getLogger(__name__) +logger = get_logger(__name__, category="core") def swiglu_wrapper_no_reduce( @@ -186,7 +187,7 @@ def logging_callbacks( if use_rich_progress: console.print(message) elif rank == 0: # Only log from rank 0 for non-rich logging - log.info(message) + logger.info(message) total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block)) progress = None @@ -220,6 +221,6 @@ def logging_callbacks( if completed is not None: progress.update(task_id, completed=completed) elif rank == 0 and completed and completed % 10 == 0: - log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed") + logger.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed") return progress, log_status, update_status diff --git a/llama_stack/models/llama/llama4/tokenizer.py b/llama_stack/models/llama/llama4/tokenizer.py index e12b2cae0..4078d3d70 100644 --- a/llama_stack/models/llama/llama4/tokenizer.py +++ b/llama_stack/models/llama/llama4/tokenizer.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from collections.abc import Collection, Iterator, Sequence, Set -from logging import getLogger +from logging import getLogger # allow-direct-logging from pathlib import Path from typing import ( Literal, diff --git a/llama_stack/models/llama/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py index a6400c5c9..73349c575 100644 --- a/llama_stack/models/llama/quantize_impls.py +++ b/llama_stack/models/llama/quantize_impls.py @@ -6,16 +6,17 @@ # type: ignore import collections -import logging -log = logging.getLogger(__name__) +from llama_stack.log import get_logger + +logger = get_logger(__name__, category="core") try: import fbgemm_gpu.experimental.gen_ai # noqa: F401 - log.info("Using efficient FP8 or INT4 operators in FBGEMM.") + logger.info("Using efficient FP8 or INT4 operators in FBGEMM.") except ImportError: - log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.") + logger.error("No efficient FP8 or INT4 operators. 
Please install FBGEMM.") raise import torch diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 5f7c90879..4579157a2 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" -logger = get_logger(name=__name__, category="agents") +log = get_logger(name=__name__, category="agents") class ChatAgent(ShieldRunnerMixin): @@ -612,7 +612,7 @@ class ChatAgent(ShieldRunnerMixin): ) if n_iter >= self.agent_config.max_infer_iters: - logger.info(f"done with MAX iterations ({n_iter}), exiting.") + log.info(f"done with MAX iterations ({n_iter}), exiting.") # NOTE: mark end_of_turn to indicate to client that we are done with the turn # Do not continue the tool call loop after this point message.stop_reason = StopReason.end_of_turn @@ -620,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin): break if stop_reason == StopReason.out_of_tokens: - logger.info("out of token budget, exiting.") + log.info("out of token budget, exiting.") yield message break @@ -634,7 +634,7 @@ class ChatAgent(ShieldRunnerMixin): message.content = [message.content] + output_attachments yield message else: - logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") + log.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") input_messages = input_messages + [message] else: input_messages = input_messages + [message] @@ -889,7 +889,7 @@ class ChatAgent(ShieldRunnerMixin): else: tool_name_str = tool_name - logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") + log.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") result = await self.tool_runtime_api.invoke_tool( tool_name=tool_name_str, kwargs={ @@ -899,7 +899,7 @@ class ChatAgent(ShieldRunnerMixin): **self.tool_name_to_args.get(tool_name_str, {}), }, ) - logger.debug(f"tool call {tool_name_str} completed with result: {result}") + log.debug(f"tool call {tool_name_str} completed with result: {result}") return result diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 15695ec48..661f7770e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime @@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.core.datatypes import AccessRule +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.responses.responses_store import ResponsesStore @@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig from .openai_responses import OpenAIResponsesImpl from .persistence import AgentInfo -logger = logging.getLogger() +log = get_logger(name=__name__, category="agents") class MetaReferenceAgentsImpl(Agents): @@ -268,7 +268,7 @@ class MetaReferenceAgentsImpl(Agents): # Get the agent info using the key agent_info_json = await self.persistence_store.get(agent_key) if not agent_info_json: - logger.error(f"Could not find agent info for key {agent_key}") + log.error(f"Could not find agent info for key {agent_key}") continue try: @@ -281,7 +281,7 @@ class MetaReferenceAgentsImpl(Agents): ) ) except Exception as e: - logger.error(f"Error parsing agent info for {agent_id}: {e}") + log.error(f"Error parsing agent info for {agent_id}: {e}") continue # Convert Agent objects to dictionaries diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 7eb2b3897..b3712626d 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -75,7 +75,7 @@ from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefiniti from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool from llama_stack.providers.utils.responses.responses_store import ResponsesStore -logger = get_logger(name=__name__, category="openai_responses") +log = get_logger(name=__name__, category="openai_responses") OPENAI_RESPONSES_PREFIX = "openai_responses:" @@ -544,12 +544,12 @@ class OpenAIResponsesImpl: break if function_tool_calls: - logger.info("Exiting inference loop since there is a function (client-side) tool call") + log.info("Exiting inference loop since there is a function (client-side) tool call") break n_iter += 1 if n_iter >= max_infer_iters: - logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}") + log.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}") break messages = next_turn_messages @@ -698,7 +698,7 @@ class OpenAIResponsesImpl: ) return search_response.data except Exception as e: - logger.warning(f"Failed to search vector store {vector_store_id}: {e}") + log.warning(f"Failed to search vector store {vector_store_id}: {e}") return [] # Run all searches in parallel using gather diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 7a8d99b78..c88b7b892 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import json -import logging import uuid from datetime import UTC, datetime @@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.datatypes import User from llama_stack.core.request_headers import get_authenticated_user +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="agents") class AgentSessionInfo(Session): diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 605f387b7..b8a5d8a95 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -5,13 +5,13 @@ # the root directory of this source tree. import asyncio -import logging from llama_stack.apis.inference import Message from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry import tracing -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="agents") class SafetyException(Exception): # noqa: N818 diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 88d7a98ec..eea5b8353 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -73,11 +73,12 @@ from .config import MetaReferenceInferenceConfig from .generators import LlamaGenerator from .model_parallel import LlamaModelParallelGenerator -log = get_logger(__name__, category="inference") # there's a single model parallel process running serving the model. for now, # we don't support multiple concurrent requests to this process. 
SEMAPHORE = asyncio.Semaphore(1) +logger = get_logger(__name__, category="inference") + def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: return LlamaGenerator(config, model_id, llama_model) @@ -144,7 +145,7 @@ class MetaReferenceInferenceImpl( return model async def load_model(self, model_id, llama_model) -> None: - log.info(f"Loading model `{model_id}`") + logger.info(f"Loading model `{model_id}`") builder_params = [self.config, model_id, llama_model] @@ -166,7 +167,7 @@ class MetaReferenceInferenceImpl( self.model_id = model_id self.llama_model = llama_model - log.info("Warming up...") + logger.info("Warming up...") await self.completion( model_id=model_id, content="Hello, world!", @@ -177,7 +178,7 @@ class MetaReferenceInferenceImpl( messages=[UserMessage(content="Hi how are you?")], sampling_params=SamplingParams(max_tokens=20), ) - log.info("Warmed up!") + logger.info("Warmed up!") def check_model(self, request) -> None: if self.model_id is None or self.llama_model is None: diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 7ade75032..015a6e83f 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -12,7 +12,6 @@ import copy import json -import logging import multiprocessing import os import tempfile @@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import ( from pydantic import BaseModel, Field from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, ) -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class ProcessingMessageName(str, Enum): @@ -236,7 +236,7 @@ def worker_process_entrypoint( except StopIteration: break - log.info("[debug] worker process done") + log.info("[debug] worker process done") def launch_dist_group( diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index fea8a8189..a1503e777 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging from collections.abc import AsyncGenerator from llama_stack.apis.inference import ( @@ -32,8 +31,6 @@ from llama_stack.providers.utils.inference.openai_compat import ( from .config import SentenceTransformersInferenceConfig -log = logging.getLogger(__name__) - class SentenceTransformersInferenceImpl( OpenAIChatCompletionToLlamaStackMixin, diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py index 2574b995b..3ac968cfc 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -6,7 +6,6 @@ import gc import json -import logging import multiprocessing from pathlib import Path from typing import Any @@ -28,6 +27,7 @@ from llama_stack.apis.post_training import ( LoraFinetuningConfig, TrainingConfig, ) +from llama_stack.log import get_logger from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig @@ -44,7 +44,7 @@ from ..utils import ( split_dataset, ) -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") class HFFinetuningSingleDevice: @@ -69,14 +69,14 @@ class HFFinetuningSingleDevice: try: messages = json.loads(row["chat_completion_input"]) if not isinstance(messages, list) or len(messages) != 1: - logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}") + log.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}") return None, None if "content" not in messages[0]: - logger.warning(f"Message missing content: {messages[0]}") + log.warning(f"Message missing content: {messages[0]}") return None, None return messages[0]["content"], row["expected_answer"] except json.JSONDecodeError: - logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}") + log.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}") return None, None return None, None @@ -86,13 +86,13 @@ class HFFinetuningSingleDevice: try: dialog = json.loads(row["dialog"]) if not isinstance(dialog, list) or len(dialog) < 2: - logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}") + log.warning(f"Dialog must have at least 2 messages: {row['dialog']}") return None, None if dialog[0].get("role") != "user": - logger.warning(f"First message must be from user: {dialog[0]}") + log.warning(f"First message must be from user: {dialog[0]}") return None, None if not any(msg.get("role") == "assistant" for msg in dialog): - logger.warning("Dialog must have at least one assistant message") + log.warning("Dialog must have at least one assistant message") return None, None # Convert to human/gpt format @@ -100,14 +100,14 @@ class HFFinetuningSingleDevice: conversations = [] for msg in dialog: if "role" not in msg or "content" not in msg: - logger.warning(f"Message missing role or content: {msg}") + log.warning(f"Message missing role or content: {msg}") continue conversations.append({"from": role_map[msg["role"]], "value": msg["content"]}) # Format as a single conversation return conversations[0]["value"], conversations[1]["value"] except json.JSONDecodeError: - logger.warning(f"Failed to parse dialog: {row['dialog']}") + log.warning(f"Failed to parse dialog: {row['dialog']}") return None, None return None, None @@ 
-198,7 +198,7 @@ class HFFinetuningSingleDevice: """ import asyncio - logger.info("Starting training process with async wrapper") + log.info("Starting training process with async wrapper") asyncio.run( self._run_training( model=model, @@ -228,14 +228,14 @@ class HFFinetuningSingleDevice: raise ValueError("DataConfig is required for training") # Load dataset - logger.info(f"Loading dataset: {config.data_config.dataset_id}") + log.info(f"Loading dataset: {config.data_config.dataset_id}") rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id) if not self.validate_dataset_format(rows): raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") - logger.info(f"Loaded {len(rows)} rows from dataset") + log.info(f"Loaded {len(rows)} rows from dataset") # Initialize tokenizer - logger.info(f"Initializing tokenizer for model: {model}") + log.info(f"Initializing tokenizer for model: {model}") try: tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config) @@ -257,16 +257,16 @@ class HFFinetuningSingleDevice: # This ensures consistent sequence lengths across the training process tokenizer.model_max_length = provider_config.max_seq_length - logger.info("Tokenizer initialized successfully") + log.info("Tokenizer initialized successfully") except Exception as e: raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e # Create and preprocess dataset - logger.info("Creating and preprocessing dataset") + log.info("Creating and preprocessing dataset") try: ds = self._create_dataset(rows, config, provider_config) ds = self._preprocess_dataset(ds, tokenizer, provider_config) - logger.info(f"Dataset created with {len(ds)} examples") + log.info(f"Dataset created with {len(ds)} examples") except Exception as e: raise ValueError(f"Failed to create dataset: {str(e)}") from e @@ -293,11 +293,11 @@ class HFFinetuningSingleDevice: Returns: Configured SFTConfig object """ - logger.info("Configuring training arguments") + log.info("Configuring training arguments") lr = 2e-5 if config.optimizer_config: lr = config.optimizer_config.lr - logger.info(f"Using custom learning rate: {lr}") + log.info(f"Using custom learning rate: {lr}") # Validate data config if not config.data_config: @@ -350,17 +350,17 @@ class HFFinetuningSingleDevice: peft_config: Optional LoRA configuration output_dir_path: Path to save the model """ - logger.info("Saving final model") + log.info("Saving final model") model_obj.config.use_cache = True if peft_config: - logger.info("Merging LoRA weights with base model") + log.info("Merging LoRA weights with base model") model_obj = trainer.model.merge_and_unload() else: model_obj = trainer.model save_path = output_dir_path / "merged_model" - logger.info(f"Saving model to {save_path}") + log.info(f"Saving model to {save_path}") model_obj.save_pretrained(save_path) async def _run_training( @@ -380,13 +380,13 @@ class HFFinetuningSingleDevice: setup_signal_handlers() # Convert config dicts back to objects - logger.info("Initializing configuration objects") + log.info("Initializing configuration objects") provider_config_obj = HuggingFacePostTrainingConfig(**provider_config) config_obj = TrainingConfig(**config) # Initialize and validate device device = setup_torch_device(provider_config_obj.device) - logger.info(f"Using device '{device}'") + log.info(f"Using device '{device}'") # Load dataset and tokenizer train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, 
config_obj, provider_config_obj) @@ -409,7 +409,7 @@ class HFFinetuningSingleDevice: model_obj = load_model(model, device, provider_config_obj) # Initialize trainer - logger.info("Initializing SFTTrainer") + log.info("Initializing SFTTrainer") trainer = SFTTrainer( model=model_obj, train_dataset=train_dataset, @@ -420,9 +420,9 @@ class HFFinetuningSingleDevice: try: # Train - logger.info("Starting training") + log.info("Starting training") trainer.train() - logger.info("Training completed successfully") + log.info("Training completed successfully") # Save final model if output directory is provided if output_dir_path: @@ -430,12 +430,12 @@ class HFFinetuningSingleDevice: finally: # Clean up resources - logger.info("Cleaning up resources") + log.info("Cleaning up resources") if hasattr(trainer, "model"): evacuate_model_from_device(trainer.model, device.type) del trainer gc.collect() - logger.info("Cleanup completed") + log.info("Cleanup completed") async def train( self, @@ -449,7 +449,7 @@ class HFFinetuningSingleDevice: """Train a model using HuggingFace's SFTTrainer""" # Initialize and validate device device = setup_torch_device(provider_config.device) - logger.info(f"Using device '{device}'") + log.info(f"Using device '{device}'") output_dir_path = None if output_dir: @@ -479,7 +479,7 @@ class HFFinetuningSingleDevice: raise ValueError("DataConfig is required for training") # Train in a separate process - logger.info("Starting training in separate process") + log.info("Starting training in separate process") try: # Setup multiprocessing for device if device.type in ["cuda", "mps"]: diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py index a7c19faac..cede29edd 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import gc -import logging import multiprocessing from pathlib import Path from typing import Any @@ -24,6 +23,7 @@ from llama_stack.apis.post_training import ( DPOAlignmentConfig, TrainingConfig, ) +from llama_stack.log import get_logger from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig @@ -40,7 +40,7 @@ from ..utils import ( split_dataset, ) -logger = logging.getLogger(__name__) +logger = get_logger(__name__, category="core") class HFDPOAlignmentSingleDevice: diff --git a/llama_stack/providers/inline/post_training/huggingface/utils.py b/llama_stack/providers/inline/post_training/huggingface/utils.py index 3147c19ab..53c93fb7b 100644 --- a/llama_stack/providers/inline/post_training/huggingface/utils.py +++ b/llama_stack/providers/inline/post_training/huggingface/utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging import os import signal import sys @@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.post_training import Checkpoint, TrainingConfig +from llama_stack.log import get_logger from .config import HuggingFacePostTrainingConfig -logger = logging.getLogger(__name__) +logger = get_logger(__name__, category="core") def setup_environment(): diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 49e1c95b8..95887196e 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import os import time from datetime import UTC, datetime @@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchtune import modules, training from torchtune import utils as torchtune_utils from torchtune.data import padded_collate_sft +from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.peft import ( get_adapter_params, @@ -45,6 +45,7 @@ from llama_stack.apis.post_training import ( ) from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.core.utils.model_utils import model_local_dir +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.torchtune.common import utils @@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import ( ) from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset -log = logging.getLogger(__name__) - -from torchtune.models.llama3._tokenizer import Llama3Tokenizer +log = get_logger(name=__name__, category="core") class LoraFinetuningSingleDevice: diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index be05ee436..1b9397a4d 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging from typing import Any from llama_stack.apis.inference import Message @@ -15,13 +14,14 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) from .config import CodeScannerConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") ALLOWED_CODE_SCANNER_MODEL_IDS = [ "CodeScanner", diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 796771ee1..3138834a0 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from typing import Any import torch @@ -19,6 +18,7 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.shields import Shield from llama_stack.core.utils.model_utils import model_local_dir +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -26,10 +26,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import PromptGuardConfig, PromptGuardType -log = logging.getLogger(__name__) - PROMPT_GUARD_MODEL = "Prompt-Guard-86M" +log = get_logger(name=__name__, category="safety") + class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): def __init__(self, config: PromptGuardConfig, _deps) -> None: diff --git a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py index b74c3826e..eb4d20012 100644 --- a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +++ b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py @@ -7,7 +7,6 @@ import collections import functools import json -import logging import random import re import string @@ -20,7 +19,9 @@ import nltk from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai from pythainlp.tokenize import word_tokenize as word_tokenize_thai -logger = logging.getLogger() +from llama_stack.log import get_logger + +log = get_logger(name=__name__, category="core") WORD_LIST = [ "western", @@ -1726,7 +1727,7 @@ def get_langid(text: str, lid_path: str | None = None) -> str: try: line_langs.append(langdetect.detect(line)) except langdetect.LangDetectException as e: - logger.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037 + log.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037 if len(line_langs) == 0: return "en" @@ -1885,7 +1886,7 @@ class ResponseLanguageChecker(Instruction): return langdetect.detect(value) == self._language except langdetect.LangDetectException as e: # Count as instruction is followed. 
- logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 + log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 return True @@ -3110,7 +3111,7 @@ class CapitalLettersEnglishChecker(Instruction): return value.isupper() and langdetect.detect(value) == "en" except langdetect.LangDetectException as e: # Count as instruction is followed. - logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 + log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 return True @@ -3139,7 +3140,7 @@ class LowercaseLettersEnglishChecker(Instruction): return value.islower() and langdetect.detect(value) == "en" except langdetect.LangDetectException as e: # Count as instruction is followed. - logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 + log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 return True diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 6a7c7885c..c8cd49cf6 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import secrets import string from typing import Any @@ -32,6 +31,7 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( @@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import RagToolRuntimeConfig from .context_retriever import generate_rag_query -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="tools") def make_random_string(length: int = 8): diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 7a5373726..84913fe27 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -8,7 +8,6 @@ import asyncio import base64 import io import json -import logging from typing import Any import faiss @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( HealthResponse, HealthStatus, @@ -39,7 +39,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import FaissVectorIOConfig -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::" @@ -83,7 +83,7 @@ class FaissIndex(EmbeddingIndex): self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False)) self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()] except Exception as e: - logger.debug(e, exc_info=True) + log.debug(e, exc_info=True) raise ValueError( "Error deserializing Faiss index from storage. 
If you recently upgraded your Llama Stack, Faiss, " "or NumPy versions, you may need to delete the index and re-create it again or downgrade versions.\n" @@ -262,7 +262,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr assert self.kvstore is not None if vector_db_id not in self.cache: - logger.warning(f"Vector DB {vector_db_id} not found") + log.warning(f"Vector DB {vector_db_id} not found") return await self.cache[vector_db_id].index.delete() diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 1fff7b484..451aa86dc 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import re import sqlite3 import struct @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -35,7 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import ( VectorDBWithIndex, ) -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") # Specifying search mode is dependent on the VectorIO provider. VECTOR_SEARCH = "vector" @@ -257,7 +257,7 @@ class SQLiteVecIndex(EmbeddingIndex): except sqlite3.Error as e: connection.rollback() - logger.error(f"Error inserting into {self.vector_table}: {e}") + log.error(f"Error inserting into {self.vector_table}: {e}") raise finally: @@ -306,7 +306,7 @@ class SQLiteVecIndex(EmbeddingIndex): try: chunk = Chunk.model_validate_json(chunk_json) except Exception as e: - logger.error(f"Error parsing chunk JSON for id {_id}: {e}") + log.error(f"Error parsing chunk JSON for id {_id}: {e}") continue chunks.append(chunk) scores.append(score) @@ -352,7 +352,7 @@ class SQLiteVecIndex(EmbeddingIndex): try: chunk = Chunk.model_validate_json(chunk_json) except Exception as e: - logger.error(f"Error parsing chunk JSON for id {_id}: {e}") + log.error(f"Error parsing chunk JSON for id {_id}: {e}") continue chunks.append(chunk) scores.append(score) @@ -447,7 +447,7 @@ class SQLiteVecIndex(EmbeddingIndex): connection.commit() except Exception as e: connection.rollback() - logger.error(f"Error deleting chunk {chunk_id}: {e}") + log.error(f"Error deleting chunk {chunk_id}: {e}") raise finally: cur.close() @@ -530,7 +530,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc async def unregister_vector_db(self, vector_db_id: str) -> None: if vector_db_id not in self.cache: - logger.warning(f"Vector DB {vector_db_id} not found") + log.warning(f"Vector DB {vector_db_id} not found") return await self.cache[vector_db_id].index.delete() del self.cache[vector_db_id] diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index ca4c7b578..88dab3d95 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, 
category="inference") +log = get_logger(name=__name__, category="inference") class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): @@ -256,7 +256,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv "stream": bool(request.stream), **self._build_options(request.sampling_params, request.response_format, request.logprobs), } - logger.debug(f"params to fireworks: {params}") + log.debug(f"params to fireworks: {params}") return params diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 4857c6723..cd3e55b7f 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin @@ -11,8 +10,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .models import MODEL_ENTRIES -logger = logging.getLogger(__name__) - class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): """ diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 7bc3fd0c9..fdb678e3e 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import warnings from collections.abc import AsyncIterator @@ -33,6 +32,7 @@ from llama_stack.apis.inference import ( ToolChoice, ToolConfig, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, @@ -54,7 +54,7 @@ from .openai_utils import ( ) from .utils import _is_nvidia_hosted -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): @@ -75,7 +75,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) - logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...") + log.info(f"Initializing NVIDIAInferenceAdapter({config.url})...") if _is_nvidia_hosted(config): if not config.api_key: diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 74019999e..5912e72ef 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -4,13 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import httpx +from llama_stack.log import get_logger + from . 
import NVIDIAConfig -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: @@ -44,7 +45,7 @@ async def check_health(config: NVIDIAConfig) -> None: RuntimeError: If the server is not running or ready """ if not _is_nvidia_hosted(config): - logger.info("Checking NVIDIA NIM health...") + log.info("Checking NVIDIA NIM health...") try: is_live, is_ready = await _get_health(config.url) if not is_live: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 26b4dec76..49ad3f7ef 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference") class OllamaInferenceAdapter( @@ -117,10 +117,10 @@ class OllamaInferenceAdapter( return self._openai_client async def initialize(self) -> None: - logger.info(f"checking connectivity to Ollama at `{self.config.url}`...") + log.info(f"checking connectivity to Ollama at `{self.config.url}`...") health_response = await self.health() if health_response["status"] == HealthStatus.ERROR: - logger.warning( + log.warning( "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal" ) @@ -339,7 +339,7 @@ class OllamaInferenceAdapter( "options": sampling_options, "stream": request.stream, } - logger.debug(f"params to ollama: {params}") + log.debug(f"params to ollama: {params}") return params @@ -437,7 +437,7 @@ class OllamaInferenceAdapter( if provider_resource_id not in available_models: available_models_latest = [m.model.split(":latest")[0] for m in response.models] if provider_resource_id in available_models_latest: - logger.warning( + log.warning( f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'" ) return model diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 865258559..bd794c257 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -12,8 +11,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig from .models import MODEL_ENTRIES -logger = logging.getLogger(__name__) - # # This OpenAI adapter implements Inference methods using two mixins - diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index a5bb079ef..2b3e7fa72 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
-import logging from collections.abc import AsyncGenerator from huggingface_hub import AsyncInferenceClient, HfApi @@ -34,6 +33,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( @@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig -log = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") def build_hf_repo_model_entries(): @@ -307,7 +307,7 @@ class TGIAdapter(_HfAdapter): async def initialize(self, config: TGIImplConfig) -> None: if not config.url: raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.") - log.info(f"Initializing TGI client with url={config.url}") + logger.info(f"Initializing TGI client with url={config.url}") self.client = AsyncInferenceClient( model=config.url, ) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index a06e4173b..8377cfe4c 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference") class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): @@ -232,7 +232,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi "stream": request.stream, **self._build_options(request.sampling_params, request.logprobs, request.response_format), } - logger.debug(f"params to together: {params}") + log.debug(f"params to together: {params}") return params async def embeddings( diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py index d6e1016b2..bd6b35a4c 100644 --- a/llama_stack/providers/remote/post_training/nvidia/utils.py +++ b/llama_stack/providers/remote/post_training/nvidia/utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import warnings from typing import Any @@ -15,8 +14,6 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa from .config import NvidiaPostTrainingConfig -logger = logging.getLogger(__name__) - def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None: keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys() diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 1895e7507..94b30ed40 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -5,7 +5,6 @@ # the root directory of this source tree.
import json -import logging from typing import Any from llama_stack.apis.inference import Message @@ -16,12 +15,13 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): @@ -76,13 +76,13 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): """ shield_params = shield.params - logger.debug(f"run_shield::{shield_params}::messages={messages}") + log.debug(f"run_shield::{shield_params}::messages={messages}") # - convert the messages into format Bedrock expects content_messages = [] for message in messages: content_messages.append({"text": {"text": message.content}}) - logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") response = self.bedrock_runtime_client.apply_guardrail( guardrailIdentifier=shield.provider_resource_id, diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 7f17b1cb6..5e567017b 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from typing import Any import requests @@ -17,8 +16,6 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_ from .config import NVIDIASafetyConfig -logger = logging.getLogger(__name__) - class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): def __init__(self, config: NVIDIASafetyConfig) -> None: diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 6c7190afe..7cac365d5 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging from typing import Any import litellm @@ -20,12 +19,13 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.shields import Shield from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new from .config import SambaNovaSafetyConfig -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" 
@@ -66,7 +66,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide "guard" not in shield.provider_resource_id.lower() or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models ): - logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") + log.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") async def unregister_shield(self, identifier: str) -> None: pass @@ -79,9 +79,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide raise ValueError(f"Shield {shield_id} not found") shield_params = shield.params - logger.debug(f"run_shield::{shield_params}::messages={messages}") + log.debug(f"run_shield::{shield_params}::messages={messages}") content_messages = [await convert_message_to_openai_dict_new(m) for m in messages] - logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") response = litellm.completion( model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key() diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 26aeaedfb..154a19146 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio import json -import logging from typing import Any from urllib.parse import urlparse @@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig from llama_stack.providers.utils.kvstore import kvstore_impl @@ -32,8 +32,6 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig -log = logging.getLogger(__name__) - ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI VERSION = "v3" @@ -43,6 +41,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::" OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::" +logger = get_logger(__name__, category="core") + # this is a helper to allow us to use async and non-async chroma clients interchangeably async def maybe_await(result): @@ -92,7 +92,7 @@ class ChromaIndex(EmbeddingIndex): doc = json.loads(doc) chunk = Chunk(**doc) except Exception: - log.exception(f"Failed to parse document: {doc}") + logger.exception(f"Failed to parse document: {doc}") continue score = 1.0 / float(dist) if dist != 0 else float("inf") @@ -137,7 +137,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP inference_api: Api.inference, files_api: Files | None, ) -> None: - log.info(f"Initializing ChromaVectorIOAdapter with url: {config}") + logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}") self.config = config self.inference_api = inference_api self.client = None @@ -150,7 +150,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.vector_db_store = self.kvstore if 
isinstance(self.config, RemoteChromaVectorIOConfig): - log.info(f"Connecting to Chroma server at: {self.config.url}") + logger.info(f"Connecting to Chroma server at: {self.config.url}") url = self.config.url.rstrip("/") parsed = urlparse(url) @@ -159,7 +159,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port) else: - log.info(f"Connecting to Chroma local db at: {self.config.db_path}") + logger.info(f"Connecting to Chroma local db at: {self.config.db_path}") self.client = chromadb.PersistentClient(path=self.config.db_path) self.openai_vector_stores = await self._load_openai_vector_stores() @@ -182,7 +182,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def unregister_vector_db(self, vector_db_id: str) -> None: if vector_db_id not in self.cache: - log.warning(f"Vector DB {vector_db_id} not found") + logger.warning(f"Vector DB {vector_db_id} not found") return await self.cache[vector_db_id].index.delete() diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index db58bf6d3..4399301b7 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import os from typing import Any @@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig from llama_stack.providers.utils.kvstore import kvstore_impl @@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::" @@ -68,7 +68,7 @@ class MilvusIndex(EmbeddingIndex): ) if not await asyncio.to_thread(self.client.has_collection, self.collection_name): - logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") + log.info(f"Creating new collection {self.collection_name} with nullable sparse field") # Create schema for vector search schema = self.client.create_schema() schema.add_field( @@ -147,7 +147,7 @@ class MilvusIndex(EmbeddingIndex): data=data, ) except Exception as e: - logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") + log.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") raise e async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: @@ -203,7 +203,7 @@ class MilvusIndex(EmbeddingIndex): return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) except Exception as e: - logger.error(f"Error performing BM25 search: {e}") + log.error(f"Error performing BM25 search: {e}") # Fallback to simple text search return await self._fallback_keyword_search(query_string, k, score_threshold) @@ -247,7 +247,7 @@ class MilvusIndex(EmbeddingIndex): self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"' ) except Exception as e: - logger.error(f"Error deleting chunk 
{chunk_id} from Milvus collection {self.collection_name}: {e}") + log.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}") raise @@ -288,10 +288,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) self.cache[vector_db.identifier] = index if isinstance(self.config, RemoteMilvusVectorIOConfig): - logger.info(f"Connecting to Milvus server at {self.config.uri}") + log.info(f"Connecting to Milvus server at {self.config.uri}") self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) else: - logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") + log.info(f"Connecting to Milvus Lite at: {self.config.db_path}") uri = os.path.expanduser(self.config.db_path) self.client = MilvusClient(uri=uri) diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index b1645ac5a..f8b57451a 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from typing import Any import psycopg2 @@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -33,8 +33,6 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import PGVectorVectorIOConfig -log = logging.getLogger(__name__) - VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::" @@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::" OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::" +logger = get_logger(__name__, category="core") + def check_extension_version(cur): cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'") @@ -187,7 +187,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco self.metadatadata_collection_name = "openai_vector_stores_metadata" async def initialize(self) -> None: - log.info(f"Initializing PGVector memory adapter with config: {self.config}") + logger.info(f"Initializing PGVector memory adapter with config: {self.config}") self.kvstore = await kvstore_impl(self.config.kvstore) await self.initialize_openai_vector_stores() @@ -203,7 +203,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: version = check_extension_version(cur) if version: - log.info(f"Vector extension version: {version}") + logger.info(f"Vector extension version: {version}") else: raise RuntimeError("Vector extension is not installed.") @@ -216,13 +216,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco """ ) except Exception as e: - log.exception("Could not connect to PGVector database server") + logger.exception("Could not connect to PGVector database server") raise RuntimeError("Could not connect to PGVector database server") from e 
async def shutdown(self) -> None: if self.conn is not None: self.conn.close() - log.info("Connection to PGVector database server closed") + logger.info("Connection to PGVector database server closed") async def register_vector_db(self, vector_db: VectorDB) -> None: # Persist vector DB metadata in the KV store diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 144da0f4f..833f64a5c 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import uuid from typing import Any @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( VectorStoreChunkingStrategy, VectorStoreFileObject, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl @@ -35,13 +35,14 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig -log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" # KV store prefixes for vector databases VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::" +logger = get_logger(__name__, category="core") + def convert_id(_id: str) -> str: """ @@ -96,7 +97,7 @@ class QdrantIndex(EmbeddingIndex): points_selector=models.PointIdsList(points=[convert_id(chunk_id)]), ) except Exception as e: - log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}") + logger.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}") raise async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: @@ -118,7 +119,7 @@ class QdrantIndex(EmbeddingIndex): try: chunk = Chunk(**point.payload["chunk_content"]) except Exception: - log.exception("Failed to parse chunk") + logger.exception("Failed to parse chunk") continue chunks.append(chunk) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 11da8902c..4ad5750a5 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
import json -import logging from typing import Any import weaviate @@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -33,8 +33,6 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import WeaviateVectorIOConfig -log = logging.getLogger(__name__) - VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::" @@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::" OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::" +logger = get_logger(__name__, category="core") + class WeaviateIndex(EmbeddingIndex): def __init__( @@ -102,7 +102,7 @@ class WeaviateIndex(EmbeddingIndex): chunk_dict = json.loads(chunk_json) chunk = Chunk(**chunk_dict) except Exception: - log.exception(f"Failed to parse document: {chunk_json}") + logger.exception(f"Failed to parse document: {chunk_json}") continue score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf") @@ -171,7 +171,7 @@ class WeaviateVectorIOAdapter( def _get_client(self) -> weaviate.Client: if "localhost" in self.config.weaviate_cluster_url: - log.info("using Weaviate locally in container") + logger.info("using Weaviate locally in container") host, port = self.config.weaviate_cluster_url.split(":") key = "local_test" client = weaviate.connect_to_local( @@ -179,7 +179,7 @@ class WeaviateVectorIOAdapter( port=port, ) else: - log.info("Using Weaviate remote cluster with URL") + logger.info("Using Weaviate remote cluster with URL") key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}" if key in self.client_cache: return self.client_cache[key] @@ -197,7 +197,7 @@ class WeaviateVectorIOAdapter( self.kvstore = await kvstore_impl(self.config.kvstore) else: self.kvstore = None - log.info("No kvstore configured, registry will not persist across restarts") + logger.info("No kvstore configured, registry will not persist across restarts") # Load existing vector DB definitions if self.kvstore is not None: @@ -254,7 +254,7 @@ class WeaviateVectorIOAdapter( client = self._get_client() sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False: - log.warning(f"Vector DB {sanitized_collection_name} not found") + logger.warning(f"Vector DB {sanitized_collection_name} not found") return client.collections.delete(sanitized_collection_name) await self.cache[sanitized_collection_name].index.delete() diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 32e89f987..bb0af0787 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import base64 -import logging import struct from typing import TYPE_CHECKING @@ -27,7 +26,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con +from llama_stack.log import get_logger + EMBEDDING_MODELS = {} -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class SentenceTransformerEmbeddingMixin: diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index da2e634f6..7f1cf3686 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -logger = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference") class LiteLLMOpenAIMixin( @@ -157,7 +157,7 @@ class LiteLLMOpenAIMixin( params = await self._get_params(request) params["model"] = self.get_litellm_model_name(params["model"]) - logger.debug(f"params to litellm (openai compat): {params}") + log.debug(f"params to litellm (openai compat): {params}") # see https://docs.litellm.ai/docs/completion/stream#async-completion response = await litellm.acompletion(**params) if stream: @@ -460,7 +460,7 @@ class LiteLLMOpenAIMixin( :return: True if the model is available dynamically, False otherwise. """ if self.litellm_provider_name not in litellm.models_by_provider: - logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.") + log.error(f"Provider {self.litellm_provider_name} is not registered in litellm.") return False return model in litellm.models_by_provider[self.litellm_provider_name] diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index ddb3bda8c..b01243582 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class RemoteInferenceProviderConfig(BaseModel): @@ -135,7 +135,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate): :param model: The model identifier to check. :return: True if the model is available dynamically, False otherwise. """ - logger.info( + log.info( f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default." ) return False diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index e6e5ccc8a..a6082e7f0 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -5,7 +5,6 @@ # the root directory of this source tree.
import base64 import json -import logging import struct import time import uuid @@ -116,6 +115,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.inference import ( OpenAIChoice as OpenAIChatCompletionChoice, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason, @@ -128,7 +128,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( decode_assistant_message, ) -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class OpenAICompatCompletionChoiceDelta(BaseModel): @@ -316,7 +316,7 @@ def process_chat_completion_response( if t.tool_name in request_tools: new_tool_calls.append(t) else: - logger.warning(f"Tool {t.tool_name} not found in request tools") + log.warning(f"Tool {t.tool_name} not found in request tools") if len(new_tool_calls) < len(raw_message.tool_calls): raw_message.tool_calls = new_tool_calls @@ -477,7 +477,7 @@ async def process_chat_completion_stream_response( ) ) else: - logger.warning(f"Tool {tool_call.tool_name} not found in request tools") + log.warning(f"Tool {tool_call.tool_name} not found in request tools") yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, @@ -1198,7 +1198,7 @@ async def convert_openai_chat_completion_stream( ) for idx, buffer in tool_call_idx_to_buffer.items(): - logger.debug(f"toolcall_buffer[{idx}]: {buffer}") + log.debug(f"toolcall_buffer[{idx}]: {buffer}") if buffer["name"]: delta = ")" buffer["content"] += delta diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 72286dffb..72b937b44 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -25,7 +25,7 @@ from llama_stack.apis.inference import ( from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params -logger = get_logger(name=__name__, category="core") +log = get_logger(name=__name__, category="core") class OpenAIMixin(ABC): @@ -125,9 +125,9 @@ class OpenAIMixin(ABC): Direct OpenAI completion API call. """ if guided_choice is not None: - logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.") + log.warning("guided_choice is not supported by the OpenAI API. Ignoring.") if prompt_logprobs is not None: - logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") + log.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") # TODO: fix openai_completion to return type compatible with OpenAI's API response return await self.client.completions.create( # type: ignore[no-any-return] @@ -267,6 +267,6 @@ class OpenAIMixin(ABC): pass except Exception as e: # All other errors (auth, rate limit, network, etc.) - logger.warning(f"Failed to check model availability for {model}: {e}") + log.warning(f"Failed to check model availability for {model}: {e}") return False diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index 3842773d9..639735b11 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -4,16 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging from datetime import datetime from pymongo import AsyncMongoClient +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore from ..config import MongoDBKVStoreConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") class MongoDBKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index bd35decfc..605917fb2 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from datetime import datetime import psycopg2 from psycopg2.extras import DictCursor +from llama_stack.log import get_logger + from ..api import KVStore from ..config import PostgresKVStoreConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") class PostgresKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 7b6e69df1..30293dabb 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -6,7 +6,6 @@ import asyncio import json -import logging import mimetypes import time import uuid @@ -37,10 +36,11 @@ from llama_stack.apis.vector_io import ( VectorStoreSearchResponse, VectorStoreSearchResponsePage, ) +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") # Constants for OpenAI vector stores CHUNK_MULTIPLIER = 5 @@ -378,7 +378,7 @@ class OpenAIVectorStoreMixin(ABC): try: await self.unregister_vector_db(vector_store_id) except Exception as e: - logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") + log.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") return VectorStoreDeleteResponse( id=vector_store_id, @@ -460,7 +460,7 @@ class OpenAIVectorStoreMixin(ABC): ) except Exception as e: - logger.error(f"Error searching vector store {vector_store_id}: {e}") + log.error(f"Error searching vector store {vector_store_id}: {e}") # Return empty results on error return VectorStoreSearchResponsePage( search_query=search_query, @@ -614,7 +614,7 @@ class OpenAIVectorStoreMixin(ABC): ) vector_store_file_object.status = "completed" except Exception as e: - logger.error(f"Error attaching file to vector store: {e}") + log.error(f"Error attaching file to vector store: {e}") vector_store_file_object.status = "failed" vector_store_file_object.last_error = VectorStoreFileLastError( code="server_error", diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 484475e9d..73c3c1fac 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import base64 import io -import logging import re import time from abc import ABC, abstractmethod @@ -25,6 +24,7 @@ from llama_stack.apis.common.content_types import ( from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse +from llama_stack.log import get_logger from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -32,12 +32,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id -log = logging.getLogger(__name__) - # Constants for reranker types RERANKER_TYPE_RRF = "rrf" RERANKER_TYPE_WEIGHTED = "weighted" +log = get_logger(name=__name__, category="memory") + def parse_pdf(data: bytes) -> str: # For PDF and DOC/DOCX files, we can't reliably convert to string diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py index 65c3d2898..460e86fb4 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -17,7 +17,7 @@ from pydantic import BaseModel from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="scheduler") +log = get_logger(name=__name__, category="scheduler") # TODO: revisit the list of possible statuses when defining a more coherent @@ -186,7 +186,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend): except Exception as e: on_log_message_cb(str(e)) job.status = JobStatus.failed - logger.exception(f"Job {job.id} failed.") + log.exception(f"Job {job.id} failed.") asyncio.run_coroutine_threadsafe(do(), self._loop) @@ -222,7 +222,7 @@ class Scheduler: msg = (datetime.now(UTC), message) # At least for the time being, until there's a better way to expose # logs to users, log messages on console - logger.info(f"Job {job.id}: {message}") + log.info(f"Job {job.id}: {message}") job.append_log(msg) self._backend.on_log_message_cb(job, msg) diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index ccc835768..03380eb47 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore from .sqlstore import SqlStoreType -logger = get_logger(name=__name__, category="authorized_sqlstore") +log = get_logger(name=__name__, category="authorized_sqlstore") # Hardcoded copy of the default policy that our SQL filtering implements # WARNING: If default_policy() changes, this constant must be updated accordingly @@ -81,7 +81,7 @@ class AuthorizedSqlStore: actual_default = default_policy() if SQL_OPTIMIZED_POLICY != actual_default: - logger.warning( + log.warning( f"SQL_OPTIMIZED_POLICY does not match default_policy(). " f"SQL filtering will use conservative mode. 
" f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}", diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 6414929db..2aaa050df 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -29,7 +29,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, SqlStore from .sqlstore import SqlAlchemySqlStoreConfig -logger = get_logger(name=__name__, category="sqlstore") +log = get_logger(name=__name__, category="sqlstore") TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, @@ -280,5 +280,5 @@ class SqlAlchemySqlStoreImpl(SqlStore): except Exception as e: # If any error occurs during migration, log it but don't fail # The table creation will handle adding the column - logger.error(f"Error adding column {column_name} to table {table}: {e}") + log.error(f"Error adding column {column_name} to table {table}: {e}") pass diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index c85722bdc..bffa5f45a 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -6,7 +6,7 @@ import asyncio import contextvars -import logging +import logging # allow-direct-logging import queue import random import threading diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py index f9c797593..b4b82a670 100644 --- a/tests/integration/post_training/test_post_training.py +++ b/tests/integration/post_training/test_post_training.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging import sys import time import uuid @@ -19,10 +18,9 @@ from llama_stack.apis.post_training import ( LoraFinetuningConfig, TrainingConfig, ) +from llama_stack.log import get_logger -# Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") skip_because_resource_intensive = pytest.mark.skip( @@ -71,14 +69,14 @@ class TestPostTraining: ) @pytest.mark.timeout(360) # 6 minutes timeout def test_supervised_fine_tune(self, llama_stack_client, purpose, source): - logger.info("Starting supervised fine-tuning test") + log.info("Starting supervised fine-tuning test") # register dataset to train dataset = llama_stack_client.datasets.register( purpose=purpose, source=source, ) - logger.info(f"Registered dataset with ID: {dataset.identifier}") + log.info(f"Registered dataset with ID: {dataset.identifier}") algorithm_config = LoraFinetuningConfig( type="LoRA", @@ -105,7 +103,7 @@ class TestPostTraining: ) job_uuid = f"test-job{uuid.uuid4()}" - logger.info(f"Starting training job with UUID: {job_uuid}") + log.info(f"Starting training job with UUID: {job_uuid}") # train with HF trl SFTTrainer as the default _ = llama_stack_client.post_training.supervised_fine_tune( @@ -121,21 +119,21 @@ class TestPostTraining: while True: status = llama_stack_client.post_training.job.status(job_uuid=job_uuid) if not status: - logger.error("Job not found") + log.error("Job not found") break - logger.info(f"Current status: {status}") + log.info(f"Current status: {status}") assert status.status in ["scheduled", "in_progress", "completed"] if status.status == "completed": break - logger.info("Waiting for job to complete...") + log.info("Waiting for job to complete...") time.sleep(10) # Increased sleep time to reduce polling frequency artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid) - logger.info(f"Job artifacts: {artifacts}") + log.info(f"Job artifacts: {artifacts}") - logger.info(f"Registered dataset with ID: {dataset.identifier}") + log.info(f"Registered dataset with ID: {dataset.identifier}") # TODO: Fix these tests to properly represent the Jobs API in training # @@ -181,17 +179,21 @@ class TestPostTraining: ) @pytest.mark.timeout(360) def test_preference_optimize(self, llama_stack_client, purpose, source): - logger.info("Starting DPO preference optimization test") + log.info("Starting DPO preference optimization test") # register preference dataset to train dataset = llama_stack_client.datasets.register( purpose=purpose, source=source, ) - logger.info(f"Registered preference dataset with ID: {dataset.identifier}") + log.info(f"Registered preference dataset with ID: {dataset.identifier}") # DPO algorithm configuration algorithm_config = DPOAlignmentConfig( + reward_scale=1.0, + reward_clip=10.0, + epsilon=1e-8, + gamma=0.99, beta=0.1, loss_type=DPOLossType.sigmoid, # Default loss type ) @@ -211,7 +213,7 @@ class TestPostTraining: ) job_uuid = f"test-dpo-job-{uuid.uuid4()}" - logger.info(f"Starting DPO training job with UUID: {job_uuid}") + log.info(f"Starting DPO training job with UUID: {job_uuid}") # train with HuggingFace DPO implementation _ = llama_stack_client.post_training.preference_optimize( @@ -226,15 +228,15 @@ class TestPostTraining: while True: status = llama_stack_client.post_training.job.status(job_uuid=job_uuid) if not status: - logger.error("DPO job not found") + log.error("DPO job not found") break - 
logger.info(f"Current DPO status: {status}") + log.info(f"Current DPO status: {status}") if status.status == "completed": break - logger.info("Waiting for DPO job to complete...") + log.info("Waiting for DPO job to complete...") time.sleep(10) # Increased sleep time to reduce polling frequency artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid) - logger.info(f"DPO job artifacts: {artifacts}") + log.info(f"DPO job artifacts: {artifacts}") diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 1c9ef92b6..9ae25cbb5 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import time from io import BytesIO @@ -13,8 +12,10 @@ from llama_stack_client import BadRequestError, LlamaStackClient from openai import BadRequestError as OpenAIBadRequestError from llama_stack.apis.vector_io import Chunk +from llama_stack.core.library_client import LlamaStackAsLibraryClient +from llama_stack.log import get_logger -logger = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector-io") def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): @@ -99,7 +100,7 @@ def compat_client_with_empty_stores(compat_client): compat_client.vector_stores.delete(vector_store_id=store.id) except Exception: # If the API is not available or fails, just continue - logger.warning("Failed to clear vector stores") + log.warning("Failed to clear vector stores") pass def clear_files(): @@ -109,7 +110,7 @@ def compat_client_with_empty_stores(compat_client): compat_client.files.delete(file_id=file.id) except Exception: # If the API is not available or fails, just continue - logger.warning("Failed to clear files") + log.warning("Failed to clear files") pass clear_vector_stores() diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 5c2ad03ab..ce0e930b1 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -6,7 +6,7 @@ import asyncio import json -import logging +import logging # allow-direct-logging import threading import time from http.server import BaseHTTPRequestHandler, HTTPServer