feat: add auto-generated CI documentation pre-commit hook (#2890)

Our CI is entirely undocumented; this commit adds a README.md file with
a table of the current CI and what it does

---------

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
Authored by Nathan Weinberg on 2025-07-25 11:57:01 -04:00; committed by Mustafa Elbehery
parent 7f834339ba
commit b381ed6d64
93 changed files with 495 additions and 477 deletions


@@ -145,6 +145,24 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^.github/workflows/.*$
+      - id: check-log-usage
+        name: Check for proper log usage (use llama_stack.log instead)
+        entry: bash
+        language: system
+        types: [python]
+        pass_filenames: true
+        args:
+          - -c
+          - |
+            matches=$(grep -EnH '^[^#]*\b(import logging|from logging\b)' "$@" | grep -v '# allow-direct-logging' || true)
+            if [ -n "$matches" ]; then
+              # GitHub Actions annotation format
+              while IFS=: read -r file line_num rest; do
+                echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging"
+              done <<< "$matches"
+              exit 1
+            fi
+            exit 0
 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
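
For reference, a minimal sketch of the module-level pattern this hook expects, written as an illustration rather than taken from the repository: the module and function names and the "core" category below are assumptions (the diff shows categories such as "core", "cli", and "server"), and the escape-hatch comment is the one the hook's second grep filters out.

    # hypothetical_module.py - illustrates the pattern enforced by check-log-usage
    from llama_stack.log import get_logger

    log = get_logger(name=__name__, category="core")  # category value is an assumption


    def do_work() -> None:
        # Passes the hook: no direct `import logging` anywhere in the file.
        log.debug("starting work")

    # If the stdlib logging module is genuinely needed, the hook can be silenced per line:
    # import logging  # allow-direct-logging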


@@ -15,7 +15,7 @@ from llama_stack.log import get_logger
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
-logger = get_logger(name=__name__, category="server")
+log = get_logger(name=__name__, category="server")
 class StackRun(Subcommand):
@@ -126,7 +126,7 @@ class StackRun(Subcommand):
 self.parser.error("Config file is required for venv environment")
 if config_file:
-logger.info(f"Using run configuration: {config_file}")
+log.info(f"Using run configuration: {config_file}")
 try:
 config_dict = yaml.safe_load(config_file.read_text())
@@ -145,7 +145,7 @@ class StackRun(Subcommand):
 # If neither image type nor image name is provided, assume the server should be run directly
 # using the current environment packages.
 if not image_type and not image_name:
-logger.info("No image type or image name provided. Assuming environment packages.")
+log.info("No image type or image name provided. Assuming environment packages.")
 from llama_stack.core.server.server import main as server_main
 # Build the server args from the current args passed to the CLI
@@ -185,11 +185,11 @@ class StackRun(Subcommand):
 run_command(run_args)
 def _start_ui_development_server(self, stack_server_port: int):
-logger.info("Attempting to start UI development server...")
+log.info("Attempting to start UI development server...")
 # Check if npm is available
 npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
 if npm_check.returncode != 0:
-logger.warning(
+log.warning(
 f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
 )
 return
@@ -214,13 +214,13 @@ class StackRun(Subcommand):
 stderr=stderr_log_file,
 env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
 )
-logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
-logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
-logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
+log.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
+log.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
+log.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
 except FileNotFoundError:
-logger.error(
+log.error(
 "Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
 )
 except Exception as e:
-logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
+log.error(f"Failed to start UI development server in {ui_dir}: {e}")


@@ -8,7 +8,7 @@ import argparse
 from llama_stack.log import get_logger
-logger = get_logger(name=__name__, category="cli")
+log = get_logger(name=__name__, category="cli")
 # TODO: this can probably just be inlined now?


@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import importlib.resources
-import logging
 import sys
 from pydantic import BaseModel
@@ -17,10 +16,9 @@ from llama_stack.core.external import load_external_apis
 from llama_stack.core.utils.exec import run_command
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.distributions.template import DistributionTemplate
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
-log = logging.getLogger(__name__)
 # These are the dependencies needed by the distribution server.
 # `llama-stack` is automatically installed by the installation script.
 SERVER_DEPENDENCIES = [
@@ -33,6 +31,8 @@ SERVER_DEPENDENCIES = [
 "opentelemetry-exporter-otlp-proto-http",
 ]
+log = get_logger(name=__name__, category="core")
 class ApiInput(BaseModel):
 api: Api


@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 import textwrap
 from typing import Any
@@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
 from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.prompt_for_config import prompt_for_config
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, ProviderSpec
-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")
 def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
@@ -49,7 +49,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
 is_nux = len(config.providers) == 0
 if is_nux:
-logger.info(
+log.info(
 textwrap.dedent(
 """
 Llama Stack is composed of several APIs working together. For each API served by the Stack,
@@ -75,12 +75,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
 existing_providers = config.providers.get(api_str, [])
 if existing_providers:
-logger.info(f"Re-configuring existing providers for API `{api_str}`...")
+log.info(f"Re-configuring existing providers for API `{api_str}`...")
 updated_providers = []
 for p in existing_providers:
-logger.info(f"> Configuring provider `({p.provider_type})`")
+log.info(f"> Configuring provider `({p.provider_type})`")
 updated_providers.append(configure_single_provider(provider_registry[api], p))
-logger.info("")
+log.info("")
 else:
 # we are newly configuring this API
 plist = build_spec.providers.get(api_str, [])
@@ -89,17 +89,17 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
 if not plist:
 raise ValueError(f"No provider configured for API {api_str}?")
-logger.info(f"Configuring API `{api_str}`...")
+log.info(f"Configuring API `{api_str}`...")
 updated_providers = []
 for i, provider in enumerate(plist):
 if i >= 1:
 others = ", ".join(p.provider_type for p in plist[i:])
-logger.info(
+log.info(
 f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
 )
 break
-logger.info(f"> Configuring provider `({provider.provider_type})`")
+log.info(f"> Configuring provider `({provider.provider_type})`")
 pid = provider.provider_type.split("::")[-1]
 updated_providers.append(
 configure_single_provider(
@@ -111,7 +111,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
 ),
 )
 )
-logger.info("")
+log.info("")
 config.providers[api_str] = updated_providers
@@ -169,7 +169,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
 return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
 if "routing_table" in config_dict:
-logger.info("Upgrading config...")
+log.info("Upgrading config...")
 config_dict = upgrade_from_routing_table(config_dict)
 config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION


@@ -23,7 +23,7 @@ from llama_stack.providers.datatypes import (
 remote_provider_spec,
 )
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 def stack_apis() -> list[Api]:
@@ -141,18 +141,18 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
 registry: dict[Api, dict[str, ProviderSpec]] = {}
 for api in providable_apis():
 name = api.name.lower()
-logger.debug(f"Importing module {name}")
+log.debug(f"Importing module {name}")
 try:
 module = importlib.import_module(f"llama_stack.providers.registry.{name}")
 registry[api] = {a.provider_type: a for a in module.available_providers()}
 except ImportError as e:
-logger.warning(f"Failed to import module {name}: {e}")
+log.warning(f"Failed to import module {name}: {e}")
 # Refresh providable APIs with external APIs if any
 external_apis = load_external_apis(config)
 for api, api_spec in external_apis.items():
 name = api_spec.name.lower()
-logger.info(f"Importing external API {name} module {api_spec.module}")
+log.info(f"Importing external API {name} module {api_spec.module}")
 try:
 module = importlib.import_module(api_spec.module)
 registry[api] = {a.provider_type: a for a in module.available_providers()}
@@ -161,7 +161,7 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
 # This assume that the in-tree provider(s) are not available for this API which means
 # that users will need to use external providers for this API.
 registry[api] = {}
-logger.error(
+log.error(
 f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
 "Install the API package to load any in-tree providers for this API."
 )
@@ -183,13 +183,13 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
 def get_external_providers_from_dir(
 registry: dict[Api, dict[str, ProviderSpec]], config
 ) -> dict[Api, dict[str, ProviderSpec]]:
-logger.warning(
+log.warning(
 "Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
 )
 external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
 if not os.path.exists(external_providers_dir):
 raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
-logger.info(f"Loading external providers from {external_providers_dir}")
+log.info(f"Loading external providers from {external_providers_dir}")
 for api in providable_apis():
 api_name = api.name.lower()
@@ -198,13 +198,13 @@ def get_external_providers_from_dir(
 for provider_type in ["remote", "inline"]:
 api_dir = os.path.join(external_providers_dir, provider_type, api_name)
 if not os.path.exists(api_dir):
-logger.debug(f"No {provider_type} provider directory found for {api_name}")
+log.debug(f"No {provider_type} provider directory found for {api_name}")
 continue
 # Look for provider spec files in the API directory
 for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
 provider_name = os.path.splitext(os.path.basename(spec_path))[0]
-logger.info(f"Loading {provider_type} provider spec from {spec_path}")
+log.info(f"Loading {provider_type} provider spec from {spec_path}")
 try:
 with open(spec_path) as f:
@@ -217,16 +217,16 @@ def get_external_providers_from_dir(
 spec = _load_inline_provider_spec(spec_data, api, provider_name)
 provider_type_key = f"inline::{provider_name}"
-logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
+log.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
 if provider_type_key in registry[api]:
-logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
+log.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
 registry[api][provider_type_key] = spec
-logger.info(f"Successfully loaded external provider {provider_type_key}")
+log.info(f"Successfully loaded external provider {provider_type_key}")
 except yaml.YAMLError as yaml_err:
-logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
+log.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
 raise yaml_err
 except Exception as e:
-logger.error(f"Failed to load provider spec from {spec_path}: {e}")
+log.error(f"Failed to load provider spec from {spec_path}: {e}")
 raise e
 return registry
@@ -241,7 +241,7 @@ def get_external_providers_from_module(
 else:
 provider_list = config.providers.items()
 if provider_list is None:
-logger.warning("Could not get list of providers from config")
+log.warning("Could not get list of providers from config")
 return registry
 for provider_api, providers in provider_list:
 for provider in providers:
@@ -272,6 +272,6 @@ def get_external_providers_from_module(
 "get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
 ) from exc
 except Exception as e:
-logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
+log.error(f"Failed to load provider spec from module {provider.module}: {e}")
 raise e
 return registry


@@ -11,7 +11,7 @@ from llama_stack.apis.datatypes import Api, ExternalApiSpec
 from llama_stack.core.datatypes import BuildConfig, StackRunConfig
 from llama_stack.log import get_logger
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
@@ -28,10 +28,10 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,
 external_apis_dir = config.external_apis_dir.expanduser().resolve()
 if not external_apis_dir.is_dir():
-logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
+log.error(f"External APIs directory is not a directory: {external_apis_dir}")
 return {}
-logger.info(f"Loading external APIs from {external_apis_dir}")
+log.info(f"Loading external APIs from {external_apis_dir}")
 external_apis: dict[Api, ExternalApiSpec] = {}
 # Look for YAML files in the external APIs directory
@@ -42,13 +42,13 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,
 spec = ExternalApiSpec(**spec_data)
 api = Api.add(spec.name)
-logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
+log.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
 external_apis[api] = spec
 except yaml.YAMLError as yaml_err:
-logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
+log.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
 raise
 except Exception:
-logger.exception(f"Failed to load external API spec from {yaml_path}")
+log.exception(f"Failed to load external API spec from {yaml_path}")
 raise
 return external_apis


@@ -7,7 +7,6 @@
 import asyncio
 import inspect
 import json
-import logging
 import os
 import sys
 from concurrent.futures import ThreadPoolExecutor
@@ -48,6 +47,7 @@ from llama_stack.core.stack import (
 from llama_stack.core.utils.config import redact_sensitive_fields
 from llama_stack.core.utils.context import preserve_contexts_async_generator
 from llama_stack.core.utils.exec import in_notebook
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry.tracing import (
 CURRENT_TRACE_CONTEXT,
 end_trace,
@@ -55,7 +55,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
 start_trace,
 )
-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")
 T = TypeVar("T")
@@ -84,7 +84,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
 try:
 return [convert_to_pydantic(item_type, item) for item in value]
 except Exception:
-logger.error(f"Error converting list {value} into {item_type}")
+log.error(f"Error converting list {value} into {item_type}")
 return value
 elif origin is dict:
@@ -92,7 +92,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
 try:
 return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
 except Exception:
-logger.error(f"Error converting dict {value} into {val_type}")
+log.error(f"Error converting dict {value} into {val_type}")
 return value
 try:
@@ -108,7 +108,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
 return convert_to_pydantic(union_type, value)
 except Exception:
 continue
-logger.warning(
+log.warning(
 f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
 )
 raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
@@ -171,13 +171,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
 def _remove_root_logger_handlers(self):
 """
-Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+Remove all handlers from the root log. Needed to avoid polluting the console with logs.
 """
+import logging # allow-direct-logging
 root_logger = logging.getLogger()
 for handler in root_logger.handlers[:]:
 root_logger.removeHandler(handler)
-logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+log.info(f"Removed handler {handler.__class__.__name__} from root log")
 def request(self, *args, **kwargs):
 loop = self.loop


@@ -16,7 +16,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus
 from .datatypes import StackRunConfig
 from .utils.config import redact_sensitive_fields
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 class ProviderImplConfig(BaseModel):
@@ -38,7 +38,7 @@ class ProviderImpl(Providers):
 pass
 async def shutdown(self) -> None:
-logger.debug("ProviderImpl.shutdown")
+log.debug("ProviderImpl.shutdown")
 pass
 async def list_providers(self) -> ListProvidersResponse:


@@ -6,19 +6,19 @@
 import contextvars
 import json
-import logging
 from contextlib import AbstractContextManager
 from typing import Any
 from llama_stack.core.datatypes import User
+from llama_stack.log import get_logger
 from .utils.dynamic import instantiate_class_type
-log = logging.getLogger(__name__)
 # Context variable for request provider data and auth attributes
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
+log = get_logger(name=__name__, category="core")
 class RequestProviderDataContext(AbstractContextManager):
 """Context manager for request provider data"""


@@ -54,7 +54,7 @@ from llama_stack.providers.datatypes import (
 VectorDBsProtocolPrivate,
 )
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 class InvalidProviderError(Exception):
@@ -101,7 +101,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
 protocols[api] = api_class
 except (ImportError, AttributeError):
-logger.exception(f"Failed to load external API {api_spec.name}")
+log.exception(f"Failed to load external API {api_spec.name}")
 return protocols
@@ -223,7 +223,7 @@ def validate_and_prepare_providers(
 specs = {}
 for provider in providers:
 if not provider.provider_id or provider.provider_id == "__disabled__":
-logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+log.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
 continue
 validate_provider(provider, api, provider_registry)
@@ -245,10 +245,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
 p = provider_registry[api][provider.provider_type]
 if p.deprecation_error:
-logger.error(p.deprecation_error)
+log.error(p.deprecation_error)
 raise InvalidProviderError(p.deprecation_error)
 elif p.deprecation_warning:
-logger.warning(
+log.warning(
 f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
 )
@@ -261,9 +261,9 @@ def sort_providers_by_deps(
 {k: list(v.values()) for k, v in providers_with_specs.items()}
 )
-logger.debug(f"Resolved {len(sorted_providers)} providers")
+log.debug(f"Resolved {len(sorted_providers)} providers")
 for api_str, provider in sorted_providers:
-logger.debug(f" {api_str} => {provider.provider_id}")
+log.debug(f" {api_str} => {provider.provider_id}")
 return sorted_providers
@@ -348,7 +348,7 @@ async def instantiate_provider(
 if not hasattr(provider_spec, "module") or provider_spec.module is None:
 raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
-logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
+log.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
 module = importlib.import_module(provider_spec.module)
 args = []
 if isinstance(provider_spec, RemoteProviderSpec):
@@ -418,7 +418,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
 obj_params = set(obj_sig.parameters)
 obj_params.discard("self")
 if not (proto_params <= obj_params):
-logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
 missing_methods.append((name, "signature_mismatch"))
 else:
 # Check if the method has a concrete implementation (not just a protocol stub)


@@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 class DatasetIORouter(DatasetIO):
@@ -20,15 +20,15 @@ class DatasetIORouter(DatasetIO):
 self,
 routing_table: RoutingTable,
 ) -> None:
-logger.debug("Initializing DatasetIORouter")
+log.debug("Initializing DatasetIORouter")
 self.routing_table = routing_table
 async def initialize(self) -> None:
-logger.debug("DatasetIORouter.initialize")
+log.debug("DatasetIORouter.initialize")
 pass
 async def shutdown(self) -> None:
-logger.debug("DatasetIORouter.shutdown")
+log.debug("DatasetIORouter.shutdown")
 pass
 async def register_dataset(
@@ -38,7 +38,7 @@ class DatasetIORouter(DatasetIO):
 metadata: dict[str, Any] | None = None,
 dataset_id: str | None = None,
 ) -> None:
-logger.debug(
+log.debug(
 f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
 )
 await self.routing_table.register_dataset(
@@ -54,7 +54,7 @@ class DatasetIORouter(DatasetIO):
 start_index: int | None = None,
 limit: int | None = None,
 ) -> PaginatedResponse:
-logger.debug(
+log.debug(
 f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
 )
 provider = await self.routing_table.get_provider_impl(dataset_id)
@@ -65,7 +65,7 @@ class DatasetIORouter(DatasetIO):
 )
 async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
-logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
+log.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
 provider = await self.routing_table.get_provider_impl(dataset_id)
 return await provider.append_rows(
 dataset_id=dataset_id,


@@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 class ScoringRouter(Scoring):
@@ -24,15 +24,15 @@ class ScoringRouter(Scoring):
 self,
 routing_table: RoutingTable,
 ) -> None:
-logger.debug("Initializing ScoringRouter")
+log.debug("Initializing ScoringRouter")
 self.routing_table = routing_table
 async def initialize(self) -> None:
-logger.debug("ScoringRouter.initialize")
+log.debug("ScoringRouter.initialize")
 pass
 async def shutdown(self) -> None:
-logger.debug("ScoringRouter.shutdown")
+log.debug("ScoringRouter.shutdown")
 pass
 async def score_batch(
@@ -41,7 +41,7 @@ class ScoringRouter(Scoring):
 scoring_functions: dict[str, ScoringFnParams | None] = None,
 save_results_dataset: bool = False,
 ) -> ScoreBatchResponse:
-logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
+log.debug(f"ScoringRouter.score_batch: {dataset_id}")
 res = {}
 for fn_identifier in scoring_functions.keys():
 provider = await self.routing_table.get_provider_impl(fn_identifier)
@@ -63,7 +63,7 @@ class ScoringRouter(Scoring):
 input_rows: list[dict[str, Any]],
 scoring_functions: dict[str, ScoringFnParams | None] = None,
 ) -> ScoreResponse:
-logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+log.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
 res = {}
 # look up and map each scoring function to its provider impl
 for fn_identifier in scoring_functions.keys():
@@ -82,15 +82,15 @@ class EvalRouter(Eval):
 self,
 routing_table: RoutingTable,
 ) -> None:
-logger.debug("Initializing EvalRouter")
+log.debug("Initializing EvalRouter")
 self.routing_table = routing_table
 async def initialize(self) -> None:
-logger.debug("EvalRouter.initialize")
+log.debug("EvalRouter.initialize")
 pass
 async def shutdown(self) -> None:
-logger.debug("EvalRouter.shutdown")
+log.debug("EvalRouter.shutdown")
 pass
 async def run_eval(
@@ -98,7 +98,7 @@ class EvalRouter(Eval):
 benchmark_id: str,
 benchmark_config: BenchmarkConfig,
 ) -> Job:
-logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
+log.debug(f"EvalRouter.run_eval: {benchmark_id}")
 provider = await self.routing_table.get_provider_impl(benchmark_id)
 return await provider.run_eval(
 benchmark_id=benchmark_id,
@@ -112,7 +112,7 @@ class EvalRouter(Eval):
 scoring_functions: list[str],
 benchmark_config: BenchmarkConfig,
 ) -> EvaluateResponse:
-logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+log.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
 provider = await self.routing_table.get_provider_impl(benchmark_id)
 return await provider.evaluate_rows(
 benchmark_id=benchmark_id,
@@ -126,7 +126,7 @@ class EvalRouter(Eval):
 benchmark_id: str,
 job_id: str,
 ) -> Job:
-logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
+log.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
 provider = await self.routing_table.get_provider_impl(benchmark_id)
 return await provider.job_status(benchmark_id, job_id)
@@ -135,7 +135,7 @@ class EvalRouter(Eval):
 benchmark_id: str,
 job_id: str,
 ) -> None:
-logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
+log.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
 provider = await self.routing_table.get_provider_impl(benchmark_id)
 await provider.job_cancel(
 benchmark_id,
@@ -147,7 +147,7 @@ class EvalRouter(Eval):
 benchmark_id: str,
 job_id: str,
 ) -> EvaluateResponse:
-logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
+log.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
 provider = await self.routing_table.get_provider_impl(benchmark_id)
 return await provider.job_result(
 benchmark_id,


@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
 from llama_stack.providers.utils.telemetry.tracing import get_current_span
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 class InferenceRouter(Inference):
@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
 telemetry: Telemetry | None = None,
 store: InferenceStore | None = None,
 ) -> None:
-logger.debug("Initializing InferenceRouter")
+log.debug("Initializing InferenceRouter")
 self.routing_table = routing_table
 self.telemetry = telemetry
 self.store = store
@@ -79,10 +79,10 @@ class InferenceRouter(Inference):
 self.formatter = ChatFormat(self.tokenizer)
 async def initialize(self) -> None:
-logger.debug("InferenceRouter.initialize")
+log.debug("InferenceRouter.initialize")
 async def shutdown(self) -> None:
-logger.debug("InferenceRouter.shutdown")
+log.debug("InferenceRouter.shutdown")
 async def register_model(
 self,
@@ -92,7 +92,7 @@ class InferenceRouter(Inference):
 metadata: dict[str, Any] | None = None,
 model_type: ModelType | None = None,
 ) -> None:
-logger.debug(
+log.debug(
 f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
 )
 await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
@@ -117,7 +117,7 @@ class InferenceRouter(Inference):
 """
 span = get_current_span()
 if span is None:
-logger.warning("No span found for token usage metrics")
+log.warning("No span found for token usage metrics")
 return []
 metrics = [
 ("prompt_tokens", prompt_tokens),
@@ -182,7 +182,7 @@ class InferenceRouter(Inference):
 logprobs: LogProbConfig | None = None,
 tool_config: ToolConfig | None = None,
 ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
-logger.debug(
+log.debug(
 f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
 )
 if sampling_params is None:
@@ -288,7 +288,7 @@ class InferenceRouter(Inference):
 response_format: ResponseFormat | None = None,
 logprobs: LogProbConfig | None = None,
 ) -> BatchChatCompletionResponse:
-logger.debug(
+log.debug(
 f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
 )
 provider = await self.routing_table.get_provider_impl(model_id)
@@ -313,7 +313,7 @@ class InferenceRouter(Inference):
 ) -> AsyncGenerator:
 if sampling_params is None:
 sampling_params = SamplingParams()
-logger.debug(
+log.debug(
 f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
 )
 model = await self.routing_table.get_model(model_id)
@@ -374,7 +374,7 @@ class InferenceRouter(Inference):
 response_format: ResponseFormat | None = None,
 logprobs: LogProbConfig | None = None,
 ) -> BatchCompletionResponse:
-logger.debug(
+log.debug(
 f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
 )
 provider = await self.routing_table.get_provider_impl(model_id)
@@ -388,7 +388,7 @@ class InferenceRouter(Inference):
 output_dimension: int | None = None,
 task_type: EmbeddingTaskType | None = None,
 ) -> EmbeddingsResponse:
-logger.debug(f"InferenceRouter.embeddings: {model_id}")
+log.debug(f"InferenceRouter.embeddings: {model_id}")
 model = await self.routing_table.get_model(model_id)
 if model is None:
 raise ModelNotFoundError(model_id)
@@ -426,7 +426,7 @@ class InferenceRouter(Inference):
 prompt_logprobs: int | None = None,
 suffix: str | None = None,
 ) -> OpenAICompletion:
-logger.debug(
+log.debug(
 f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
 )
 model_obj = await self.routing_table.get_model(model)
@@ -487,7 +487,7 @@ class InferenceRouter(Inference):
 top_p: float | None = None,
 user: str | None = None,
 ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-logger.debug(
+log.debug(
 f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
 )
 model_obj = await self.routing_table.get_model(model)
@@ -558,7 +558,7 @@ class InferenceRouter(Inference):
 dimensions: int | None = None,
 user: str | None = None,
 ) -> OpenAIEmbeddingsResponse:
-logger.debug(
+log.debug(
 f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
 )
 model_obj = await self.routing_table.get_model(model)


@@ -14,7 +14,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 class SafetyRouter(Safety):
@@ -22,15 +22,15 @@ class SafetyRouter(Safety):
 self,
 routing_table: RoutingTable,
 ) -> None:
-logger.debug("Initializing SafetyRouter")
+log.debug("Initializing SafetyRouter")
 self.routing_table = routing_table
 async def initialize(self) -> None:
-logger.debug("SafetyRouter.initialize")
+log.debug("SafetyRouter.initialize")
 pass
 async def shutdown(self) -> None:
-logger.debug("SafetyRouter.shutdown")
+log.debug("SafetyRouter.shutdown")
 pass
 async def register_shield(
@@ -40,7 +40,7 @@ class SafetyRouter(Safety):
 provider_id: str | None = None,
 params: dict[str, Any] | None = None,
 ) -> Shield:
-logger.debug(f"SafetyRouter.register_shield: {shield_id}")
+log.debug(f"SafetyRouter.register_shield: {shield_id}")
 return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
 async def unregister_shield(self, identifier: str) -> None:
@@ -53,7 +53,7 @@ class SafetyRouter(Safety):
 messages: list[Message],
 params: dict[str, Any] = None,
 ) -> RunShieldResponse:
-logger.debug(f"SafetyRouter.run_shield: {shield_id}")
+log.debug(f"SafetyRouter.run_shield: {shield_id}")
 provider = await self.routing_table.get_provider_impl(shield_id)
 return await provider.run_shield(
 shield_id=shield_id,


@@ -22,7 +22,7 @@ from llama_stack.log import get_logger
 from ..routing_tables.toolgroups import ToolGroupsRoutingTable
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 class ToolRuntimeRouter(ToolRuntime):
@@ -31,7 +31,7 @@ class ToolRuntimeRouter(ToolRuntime):
 self,
 routing_table: ToolGroupsRoutingTable,
 ) -> None:
-logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
+log.debug("Initializing ToolRuntimeRouter.RagToolImpl")
 self.routing_table = routing_table
 async def query(
@@ -40,7 +40,7 @@ class ToolRuntimeRouter(ToolRuntime):
 vector_db_ids: list[str],
 query_config: RAGQueryConfig | None = None,
 ) -> RAGQueryResult:
-logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
+log.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
 provider = await self.routing_table.get_provider_impl("knowledge_search")
 return await provider.query(content, vector_db_ids, query_config)
@@ -50,7 +50,7 @@ class ToolRuntimeRouter(ToolRuntime):
 vector_db_id: str,
 chunk_size_in_tokens: int = 512,
 ) -> None:
-logger.debug(
+log.debug(
 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
 )
 provider = await self.routing_table.get_provider_impl("insert_into_memory")
@@ -60,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
 self,
 routing_table: ToolGroupsRoutingTable,
 ) -> None:
-logger.debug("Initializing ToolRuntimeRouter")
+log.debug("Initializing ToolRuntimeRouter")
 self.routing_table = routing_table
 # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -69,15 +69,15 @@ class ToolRuntimeRouter(ToolRuntime):
 setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
 async def initialize(self) -> None:
-logger.debug("ToolRuntimeRouter.initialize")
+log.debug("ToolRuntimeRouter.initialize")
 pass
 async def shutdown(self) -> None:
-logger.debug("ToolRuntimeRouter.shutdown")
+log.debug("ToolRuntimeRouter.shutdown")
 pass
 async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
-logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
+log.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
 provider = await self.routing_table.get_provider_impl(tool_name)
 return await provider.invoke_tool(
 tool_name=tool_name,
@@ -87,5 +87,5 @@ class ToolRuntimeRouter(ToolRuntime):
 async def list_runtime_tools(
 self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
 ) -> ListToolsResponse:
-logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
+log.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
 return await self.routing_table.list_tools(tool_group_id)


@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")
 class VectorIORouter(VectorIO):
@@ -40,15 +40,15 @@ class VectorIORouter(VectorIO):
 self,
 routing_table: RoutingTable,
 ) -> None:
-logger.debug("Initializing VectorIORouter")
+log.debug("Initializing VectorIORouter")
 self.routing_table = routing_table
 async def initialize(self) -> None:
-logger.debug("VectorIORouter.initialize")
+log.debug("VectorIORouter.initialize")
 pass
 async def shutdown(self) -> None:
-logger.debug("VectorIORouter.shutdown")
+log.debug("VectorIORouter.shutdown")
 pass
 async def _get_first_embedding_model(self) -> tuple[str, int] | None:
@@ -70,10 +70,10 @@ class VectorIORouter(VectorIO):
 raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
 return embedding_models[0].identifier, dimension
 else:
-logger.warning("No embedding models found in the routing table")
+log.warning("No embedding models found in the routing table")
 return None
 except Exception as e:
-logger.error(f"Error getting embedding models: {e}")
+log.error(f"Error getting embedding models: {e}")
 return None
 async def register_vector_db(
@@ -85,7 +85,7 @@ class VectorIORouter(VectorIO):
 vector_db_name: str | None = None,
 provider_vector_db_id: str | None = None,
 ) -> None:
-logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+log.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
 await self.routing_table.register_vector_db(
 vector_db_id,
 embedding_model,
@@ -101,7 +101,7 @@ class VectorIORouter(VectorIO):
 chunks: list[Chunk],
 ttl_seconds: int | None = None,
 ) -> None:
-logger.debug(
+log.debug(
 f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
 )
 provider = await self.routing_table.get_provider_impl(vector_db_id)
@@ -113,7 +113,7 @@ class VectorIORouter(VectorIO):
 query: InterleavedContent,
 params: dict[str, Any] | None = None,
 ) -> QueryChunksResponse:
-logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
+log.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
 provider = await self.routing_table.get_provider_impl(vector_db_id)
 return await provider.query_chunks(vector_db_id, query, params)
@@ -129,7 +129,7 @@ class VectorIORouter(VectorIO):
 embedding_dimension: int | None = None,
 provider_id: str | None = None,
 ) -> VectorStoreObject:
-logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
+log.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
 # If no embedding model is provided, use the first available one
 if embedding_model is None:
@@ -137,7 +137,7 @@ class VectorIORouter(VectorIO):
 if embedding_model_info is None:
 raise ValueError("No embedding model provided and no embedding models available in the system")
 embedding_model, embedding_dimension = embedding_model_info
-logger.info(f"No embedding model specified, using first available: {embedding_model}")
+log.info(f"No embedding model specified, using first available: {embedding_model}")
 vector_db_id = f"vs_{uuid.uuid4()}"
 registered_vector_db = await self.routing_table.register_vector_db(
@@ -168,7 +168,7 @@ class VectorIORouter(VectorIO):
 after: str | None = None,
 before: str | None = None,
 ) -> VectorStoreListResponse:
-logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
+log.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
 # Route to default provider for now - could aggregate from all providers in the future
 # call retrieve on each vector dbs to get list of vector stores
 vector_dbs = await self.routing_table.get_all_with_type("vector_db")
@@ -179,7 +179,7 @@ class VectorIORouter(VectorIO):
 vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
 all_stores.append(vector_store)
 except Exception as e:
-logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
+log.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
 continue
 # Sort by created_at
@@ -215,7 +215,7 @@ class VectorIORouter(VectorIO):
 self,
 vector_store_id: str,
 ) -> VectorStoreObject:
-logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
+log.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store(vector_store_id) return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store( async def openai_update_vector_store(
@ -225,7 +225,7 @@ class VectorIORouter(VectorIO):
expires_after: dict[str, Any] | None = None, expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> VectorStoreObject: ) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") log.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
return await self.routing_table.openai_update_vector_store( return await self.routing_table.openai_update_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
name=name, name=name,
@ -237,7 +237,7 @@ class VectorIORouter(VectorIO):
self, self,
vector_store_id: str, vector_store_id: str,
) -> VectorStoreDeleteResponse: ) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") log.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
return await self.routing_table.openai_delete_vector_store(vector_store_id) return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store( async def openai_search_vector_store(
@ -250,7 +250,7 @@ class VectorIORouter(VectorIO):
rewrite_query: bool | None = False, rewrite_query: bool | None = False,
search_mode: str | None = "vector", search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage: ) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") log.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
return await self.routing_table.openai_search_vector_store( return await self.routing_table.openai_search_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
query=query, query=query,
@ -268,7 +268,7 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None, chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") log.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
return await self.routing_table.openai_attach_file_to_vector_store( return await self.routing_table.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
@ -285,7 +285,7 @@ class VectorIORouter(VectorIO):
before: str | None = None, before: str | None = None,
filter: VectorStoreFileStatus | None = None, filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]: ) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") log.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store( return await self.routing_table.openai_list_files_in_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
limit=limit, limit=limit,
@ -300,7 +300,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") log.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file( return await self.routing_table.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
@ -311,7 +311,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileContentsResponse: ) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") log.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file_contents( return await self.routing_table.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
@ -323,7 +323,7 @@ class VectorIORouter(VectorIO):
file_id: str, file_id: str,
attributes: dict[str, Any], attributes: dict[str, Any],
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") log.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_update_vector_store_file( return await self.routing_table.openai_update_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
@ -335,7 +335,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str, vector_store_id: str,
file_id: str, file_id: str,
) -> VectorStoreFileDeleteResponse: ) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") log.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_delete_vector_store_file( return await self.routing_table.openai_delete_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
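
Every VectorIORouter method in the hunks above follows the same shape: emit a debug log, resolve the provider that owns the vector store through the routing table, and forward the call unchanged. Below is a minimal, self-contained sketch of that delegation pattern; the routing table and provider classes are stand-ins, not the library's real types, and the stdlib logger is only there so the sketch runs outside the tree.

import asyncio
import logging  # allow-direct-logging: standalone sketch, not in-tree code

log = logging.getLogger("vector_io_router_sketch")


class InMemoryRoutingTable:
    """Maps a vector_db_id to the provider implementation that owns it."""

    def __init__(self) -> None:
        self._providers: dict[str, "FakeVectorIOProvider"] = {}

    def register(self, vector_db_id: str, provider: "FakeVectorIOProvider") -> None:
        self._providers[vector_db_id] = provider

    async def get_provider_impl(self, vector_db_id: str) -> "FakeVectorIOProvider":
        return self._providers[vector_db_id]


class FakeVectorIOProvider:
    async def query_chunks(self, vector_db_id: str, query: str, params: dict | None) -> list[str]:
        return [f"chunk matching {query!r} in {vector_db_id}"]


class VectorIORouterSketch:
    """Thin router: log, look up the owning provider, delegate."""

    def __init__(self, routing_table: InMemoryRoutingTable) -> None:
        self.routing_table = routing_table

    async def query_chunks(self, vector_db_id: str, query: str, params: dict | None = None) -> list[str]:
        log.debug("query_chunks: %s", vector_db_id)
        provider = await self.routing_table.get_provider_impl(vector_db_id)
        return await provider.query_chunks(vector_db_id, query, params)


async def _demo() -> None:
    table = InMemoryRoutingTable()
    table.register("vs_123", FakeVectorIOProvider())
    router = VectorIORouterSketch(table)
    print(await router.query_chunks("vs_123", "what is a router?"))


if __name__ == "__main__":
    asyncio.run(_demo())

The real router adds the same one-line debug log at the top of every method, which is why the logger rename touches so many call sites in this file.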


@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):


@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
def get_impl_api(p: Any) -> Api: def get_impl_api(p: Any) -> Api:
@ -177,7 +177,7 @@ class CommonRoutingTableImpl(RoutingTable):
# Check if user has permission to access this object # Check if user has permission to access this object
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()): if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}'") log.debug(f"Access denied to {type} '{identifier}'")
return None return None
return obj return obj
@ -205,7 +205,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise AccessDeniedError("create", obj, creator) raise AccessDeniedError("create", obj, creator)
if creator: if creator:
obj.owner = creator obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}") log.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p) registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object # TODO: This needs to be fixed for all APIs once they return the registered object
@ -250,7 +250,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
if model is not None: if model is not None:
return model return model
logger.warning( log.warning(
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to " f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
"searching in all providers. This is only for backwards compatibility and will stop working " "searching in all providers. This is only for backwards compatibility and will stop working "
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names." "soon. Migrate your calls to use fully scoped `provider_id/model_id` names."


@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):


@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
try: try:
models = await provider.list_models() models = await provider.list_models()
except Exception as e: except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}") log.exception(f"Model refresh failed for provider {provider_id}: {e}")
continue continue
self.listed_providers.add(provider_id) self.listed_providers.add(provider_id)
@ -132,7 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_ids[model.provider_resource_id] = model.identifier model_ids[model.provider_resource_id] = model.identifier
continue continue
logger.debug(f"unregistering model {model.identifier}") log.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model) await self.unregister_object(model)
for model in models: for model in models:
@ -143,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model.identifier == model.provider_resource_id: if model.identifier == model.provider_resource_id:
model.identifier = f"{provider_id}/{model.provider_resource_id}" model.identifier = f"{provider_id}/{model.provider_resource_id}"
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") log.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object( await self.register_object(
ModelWithOwner( ModelWithOwner(
identifier=model.identifier, identifier=model.identifier,
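
The refresh hunks above unregister models a provider no longer reports and register the newly reported ones, scoping bare resource ids as provider_id/resource_id. A simplified reconcile step that captures just that bookkeeping, not the real method's registry calls:

def reconcile_provider_models(provider_id: str, listed: list[str], registered: dict[str, str]) -> dict[str, str]:
    """Toy reconcile step: `registered` maps identifier -> provider_resource_id."""
    listed_set = set(listed)

    # Drop entries the provider no longer reports.
    kept = {ident: rid for ident, rid in registered.items() if rid in listed_set}

    # Add newly reported models, scoping bare ids as `provider_id/resource_id`.
    known_resource_ids = set(kept.values())
    for resource_id in listed:
        if resource_id in known_resource_ids:
            continue
        kept[f"{provider_id}/{resource_id}"] = resource_id
    return kept


if __name__ == "__main__":
    current = {"ollama/llama3.2:3b": "llama3.2:3b", "ollama/old-model": "old-model"}
    print(reconcile_provider_models("ollama", ["llama3.2:3b", "new-model"], current))
    # {'ollama/llama3.2:3b': 'llama3.2:3b', 'ollama/new-model': 'new-model'}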


@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):


@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):


@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:


@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
@ -57,7 +57,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if len(self.impls_by_provider_id) > 0: if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1: if len(self.impls_by_provider_id) > 1:
logger.warning( log.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
) )
else: else:
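
The vector-db registration hunk resolves a missing provider_id like this: if exactly one provider is configured use it, if several are configured take the first and warn. The same rule as a compact standalone function:

import logging  # allow-direct-logging: standalone sketch

log = logging.getLogger("provider_pick_sketch")


def pick_provider(requested: str | None, impls_by_provider_id: dict[str, object]) -> str:
    if requested is not None:
        return requested
    if not impls_by_provider_id:
        raise ValueError("No vector_io providers are configured")
    provider_id = next(iter(impls_by_provider_id))
    if len(impls_by_provider_id) > 1:
        log.warning(
            "No provider specified and multiple providers available. "
            "Arbitrarily selected the first provider %s.",
            provider_id,
        )
    return provider_id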


@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") log = get_logger(name=__name__, category="auth")
class AuthenticationMiddleware: class AuthenticationMiddleware:
@ -105,13 +105,13 @@ class AuthenticationMiddleware:
try: try:
validation_result = await self.auth_provider.validate_token(token, scope) validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException: except httpx.TimeoutException:
logger.exception("Authentication request timed out") log.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout") return await self._send_auth_error(send, "Authentication service timeout")
except ValueError as e: except ValueError as e:
logger.exception("Error during authentication") log.exception("Error during authentication")
return await self._send_auth_error(send, str(e)) return await self._send_auth_error(send, str(e))
except Exception: except Exception:
logger.exception("Error during authentication") log.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error") return await self._send_auth_error(send, "Authentication service error")
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware) # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
@ -122,7 +122,7 @@ class AuthenticationMiddleware:
scope["principal"] = validation_result.principal scope["principal"] = validation_result.principal
if validation_result.attributes: if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes scope["user_attributes"] = validation_result.attributes
logger.debug( log.debug(
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes" f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
) )
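
The middleware hunks separate three failure modes around validate_token (a timeout, a ValueError that carries a user-facing message, and everything else) and on success stash the principal and attributes on the ASGI scope. A stripped-down sketch of that control flow; the auth provider is a fake, and asyncio.TimeoutError stands in for the httpx timeout the real middleware catches.

import asyncio
import logging  # allow-direct-logging: standalone sketch

log = logging.getLogger("auth_middleware_sketch")


class FakeAuthProvider:
    async def validate_token(self, token: str, scope: dict):
        if token == "expired":
            raise ValueError("Token has expired")
        if token != "good-token":
            raise RuntimeError("introspection backend unreachable")
        return {"principal": "alice", "attributes": {"roles": ["admin"]}}


async def authenticate(scope: dict, token: str, provider: FakeAuthProvider) -> str | None:
    """Returns an error message, or None (and mutates `scope`) on success."""
    try:
        result = await provider.validate_token(token, scope)
    except asyncio.TimeoutError:
        log.exception("Authentication request timed out")
        return "Authentication service timeout"
    except ValueError as e:
        log.exception("Error during authentication")
        return str(e)
    except Exception:
        log.exception("Error during authentication")
        return "Authentication service error"

    scope["principal"] = result["principal"]
    scope["user_attributes"] = result["attributes"]
    log.debug("Authentication successful: %s with %d attributes",
              result["principal"], len(result["attributes"]))
    return None


if __name__ == "__main__":
    scope: dict = {}
    print(asyncio.run(authenticate(scope, "good-token", FakeAuthProvider())), scope)
    print(asyncio.run(authenticate({}, "expired", FakeAuthProvider())))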


@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") log = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel): class AuthResponse(BaseModel):
@ -163,7 +163,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
timeout=10.0, # Add a reasonable timeout timeout=10.0, # Add a reasonable timeout
) )
if response.status_code != 200: if response.status_code != 200:
logger.warning(f"Token introspection failed with status code: {response.status_code}") log.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}") raise ValueError(f"Token introspection failed: {response.status_code}")
fields = response.json() fields = response.json()
@ -176,13 +176,13 @@ class OAuth2TokenAuthProvider(AuthProvider):
attributes=access_attributes, attributes=access_attributes,
) )
except httpx.TimeoutException: except httpx.TimeoutException:
logger.exception("Token introspection request timed out") log.exception("Token introspection request timed out")
raise raise
except ValueError: except ValueError:
# Re-raise ValueError exceptions to preserve their message # Re-raise ValueError exceptions to preserve their message
raise raise
except Exception as e: except Exception as e:
logger.exception("Error during token introspection") log.exception("Error during token introspection")
raise ValueError("Token introspection error") from e raise ValueError("Token introspection error") from e
async def close(self): async def close(self):
@ -273,7 +273,7 @@ class CustomAuthProvider(AuthProvider):
timeout=10.0, # Add a reasonable timeout timeout=10.0, # Add a reasonable timeout
) )
if response.status_code != 200: if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}") log.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}") raise ValueError(f"Authentication failed: {response.status_code}")
# Parse and validate the auth response # Parse and validate the auth response
@ -282,17 +282,17 @@ class CustomAuthProvider(AuthProvider):
auth_response = AuthResponse(**response_data) auth_response = AuthResponse(**response_data)
return User(principal=auth_response.principal, attributes=auth_response.attributes) return User(principal=auth_response.principal, attributes=auth_response.attributes)
except Exception as e: except Exception as e:
logger.exception("Error parsing authentication response") log.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e raise ValueError("Invalid authentication response format") from e
except httpx.TimeoutException: except httpx.TimeoutException:
logger.exception("Authentication request timed out") log.exception("Authentication request timed out")
raise raise
except ValueError: except ValueError:
# Re-raise ValueError exceptions to preserve their message # Re-raise ValueError exceptions to preserve their message
raise raise
except Exception as e: except Exception as e:
logger.exception("Error during authentication") log.exception("Error during authentication")
raise ValueError("Authentication service error") from e raise ValueError("Authentication service error") from e
async def close(self): async def close(self):
@ -329,7 +329,7 @@ class GitHubTokenAuthProvider(AuthProvider):
try: try:
user_info = await _get_github_user_info(token, self.config.github_api_base_url) user_info = await _get_github_user_info(token, self.config.github_api_base_url)
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.warning(f"GitHub token validation failed: {e}") log.warning(f"GitHub token validation failed: {e}")
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
principal = user_info["user"]["login"] principal = user_info["user"]["login"]
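
For the OAuth2TokenAuthProvider hunks, the surrounding call is a token-introspection POST. A hedged sketch of such a call with httpx: the status-code guard and 10-second timeout mirror the diff, while the endpoint URL, the client-credential auth, and the "active" check are assumptions added for illustration.

import httpx


async def introspect_token(token: str, introspection_url: str, client_id: str, client_secret: str) -> dict:
    """POST the token to an OAuth2 introspection endpoint and return its claims."""
    async with httpx.AsyncClient() as client:
        response = await client.post(
            introspection_url,
            data={"token": token},
            auth=(client_id, client_secret),
            timeout=10.0,  # reasonable timeout, as in the diff
        )
    if response.status_code != 200:
        raise ValueError(f"Token introspection failed: {response.status_code}")
    fields = response.json()
    if not fields.get("active", False):
        raise ValueError("Token is not active")
    return fields

Running this requires a live introspection endpoint; the credential handling shown is a placeholder, not the provider's documented request shape.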


@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota") log = get_logger(name=__name__, category="quota")
class QuotaMiddleware: class QuotaMiddleware:
@ -46,7 +46,7 @@ class QuotaMiddleware:
self.window_seconds = window_seconds self.window_seconds = window_seconds
if isinstance(self.kv_config, SqliteKVStoreConfig): if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning( log.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. " "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}" f"window_seconds={self.window_seconds}"
) )
@ -84,11 +84,11 @@ class QuotaMiddleware:
else: else:
await kv.set(key, str(count)) await kv.set(key, str(count))
except Exception: except Exception:
logger.exception("Failed to access KV store for quota") log.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error") return await self._send_error(send, 500, "Quota service error")
if count > limit: if count > limit:
logger.warning( log.warning(
"Quota exceeded for client %s: %d/%d", "Quota exceeded for client %s: %d/%d",
key_id, key_id,
count, count,
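
QuotaMiddleware counts requests per client key and rejects once the count passes the configured limit within a time window. A self-contained fixed-window counter that captures the same idea, with an in-memory dict standing in for the KV store:

import time


class FixedWindowQuota:
    def __init__(self, limit: int, window_seconds: int) -> None:
        self.limit = limit
        self.window_seconds = window_seconds
        self._counts: dict[tuple[str, int], int] = {}

    def allow(self, client_id: str) -> bool:
        # Bucket requests by (client, current window) and count them.
        window = int(time.time()) // self.window_seconds
        key = (client_id, window)
        self._counts[key] = self._counts.get(key, 0) + 1
        return self._counts[key] <= self.limit


if __name__ == "__main__":
    quota = FixedWindowQuota(limit=3, window_seconds=60)
    print([quota.allow("client-a") for _ in range(5)])  # [True, True, True, False, False]

The real middleware persists counts in a KV store (hence the SQLite TTL caveat logged above) and applies separate authenticated and anonymous limits.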


@ -9,7 +9,7 @@ import asyncio
import functools import functools
import inspect import inspect
import json import json
import logging import logging # allow-direct-logging
import os import os
import ssl import ssl
import sys import sys
@ -80,7 +80,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server") log = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None): def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -157,9 +157,9 @@ async def shutdown(app):
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logger.info("Starting up") log.info("Starting up")
yield yield
logger.info("Shutting down") log.info("Shutting down")
await shutdown(app) await shutdown(app)
@ -182,11 +182,11 @@ async def sse_generator(event_gen_coroutine):
yield create_sse_event(item) yield create_sse_event(item)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Generator cancelled") log.info("Generator cancelled")
if event_gen: if event_gen:
await event_gen.aclose() await event_gen.aclose()
except Exception as e: except Exception as e:
logger.exception("Error in sse_generator") log.exception("Error in sse_generator")
yield create_sse_event( yield create_sse_event(
{ {
"error": { "error": {
@ -206,11 +206,11 @@ async def log_request_pre_validation(request: Request):
log_output = rich.pretty.pretty_repr(parsed_body) log_output = rich.pretty.pretty_repr(parsed_body)
except (json.JSONDecodeError, UnicodeDecodeError): except (json.JSONDecodeError, UnicodeDecodeError):
log_output = repr(body_bytes) log_output = repr(body_bytes)
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}") log.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
else: else:
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.") log.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
except Exception as e: except Exception as e:
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}") log.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@ -238,10 +238,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
result.url = route result.url = route
return result return result
except Exception as e: except Exception as e:
if logger.isEnabledFor(logging.DEBUG): if log.isEnabledFor(logging.DEBUG):
logger.exception(f"Error executing endpoint {route=} {method=}") log.exception(f"Error executing endpoint {route=} {method=}")
else: else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}") log.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e raise translate_exception(e) from e
sig = inspect.signature(func) sig = inspect.signature(func)
@ -291,7 +291,7 @@ class TracingMiddleware:
# Check if the path is a FastAPI built-in path # Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths): if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers # Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") log.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"): if not hasattr(self, "route_impls"):
@ -303,7 +303,7 @@ class TracingMiddleware:
) )
except ValueError: except ValueError:
# If no matching endpoint is found, pass through to FastAPI # If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") log.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path} trace_attributes = {"__location__": "server", "raw_path": path}
@ -404,15 +404,15 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp) config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg) logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config) log = get_logger(name=__name__, category="server", config=logger_config)
if args.env: if args.env:
for env_pair in args.env: for env_pair in args.env:
try: try:
key, value = validate_env_pair(env_pair) key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}") log.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value os.environ[key] = value
except ValueError as e: except ValueError as e:
logger.error(f"Error: {str(e)}") log.error(f"Error: {str(e)}")
sys.exit(1) sys.exit(1)
config = replace_env_vars(config_contents) config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config)) config = StackRunConfig(**cast_image_name_to_string(config))
@ -438,16 +438,16 @@ def main(args: argparse.Namespace | None = None):
impls = loop.run_until_complete(construct_stack(config)) impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e: except InvalidProviderError as e:
logger.error(f"Error: {str(e)}") log.error(f"Error: {str(e)}")
sys.exit(1) sys.exit(1)
if config.server.auth: if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") log.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls) app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else: else:
if config.server.quota: if config.server.quota:
quota = config.server.quota quota = config.server.quota
logger.warning( log.warning(
"Configured authenticated_max_requests (%d) but no auth is enabled; " "Configured authenticated_max_requests (%d) but no auth is enabled; "
"falling back to anonymous_max_requests (%d) for all the requests", "falling back to anonymous_max_requests (%d) for all the requests",
quota.authenticated_max_requests, quota.authenticated_max_requests,
@ -455,7 +455,7 @@ def main(args: argparse.Namespace | None = None):
) )
if config.server.quota: if config.server.quota:
logger.info("Enabling quota middleware for authenticated and anonymous clients") log.info("Enabling quota middleware for authenticated and anonymous clients")
quota = config.server.quota quota = config.server.quota
anonymous_max_requests = quota.anonymous_max_requests anonymous_max_requests = quota.anonymous_max_requests
@ -516,7 +516,7 @@ def main(args: argparse.Namespace | None = None):
if not available_methods: if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}") raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0] method = available_methods[0]
logger.debug(f"{method} {route.path}") log.debug(f"{method} {route.path}")
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
@ -528,7 +528,7 @@ def main(args: argparse.Namespace | None = None):
) )
) )
logger.debug(f"serving APIs: {apis_to_serve}") log.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
@ -553,21 +553,21 @@ def main(args: argparse.Namespace | None = None):
if config.server.tls_cafile: if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info( log.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}" f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
) )
else: else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") log.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = config.server.host or ["::", "0.0.0.0"] listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}") log.info(f"Listening on {listen_host}:{port}")
uvicorn_config = { uvicorn_config = {
"app": app, "app": app,
"host": listen_host, "host": listen_host,
"port": port, "port": port,
"lifespan": "on", "lifespan": "on",
"log_level": logger.getEffectiveLevel(), "log_level": log.getEffectiveLevel(),
"log_config": logger_config, "log_config": logger_config,
} }
if ssl_config: if ssl_config:
@ -586,19 +586,19 @@ def main(args: argparse.Namespace | None = None):
try: try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...") log.info("Received interrupt signal, shutting down gracefully...")
finally: finally:
if not loop.is_closed(): if not loop.is_closed():
logger.debug("Closing event loop") log.debug("Closing event loop")
loop.close() loop.close()
def _log_run_config(run_config: StackRunConfig): def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed.""" """Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:") log.info("Run configuration:")
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json")) safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
clean_config = remove_disabled_providers(safe_config) clean_config = remove_disabled_providers(safe_config)
logger.info(yaml.dump(clean_config, indent=2)) log.info(yaml.dump(clean_config, indent=2))
def extract_path_params(route: str) -> list[str]: def extract_path_params(route: str) -> list[str]:
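
One pattern worth noting in the server hunks: endpoint failures are logged with a full traceback only when the logger is at DEBUG, otherwise as a single concise line, before the exception is translated for the HTTP response. A small sketch of that guard; translate_exception and HTTPError here are placeholders, not the server's real helpers.

import logging  # allow-direct-logging: standalone sketch

log = logging.getLogger("endpoint_sketch")


class HTTPError(Exception):
    def __init__(self, status: int, detail: str) -> None:
        super().__init__(detail)
        self.status = status


def translate_exception(exc: Exception) -> HTTPError:
    # Placeholder: the real server maps many exception types to HTTP statuses.
    return HTTPError(500, str(exc))


def run_endpoint(handler, *args, route: str = "/v1/example", method: str = "POST"):
    try:
        return handler(*args)
    except Exception as e:
        if log.isEnabledFor(logging.DEBUG):
            # Verbose mode: keep the stack trace in the server log.
            log.exception("Error executing endpoint route=%s method=%s", route, method)
        else:
            # Normal mode: one concise line, no traceback noise.
            log.error("Error executing endpoint route=%s method=%s: %s", route, method, e)
        raise translate_exception(e) from e


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    try:
        run_endpoint(lambda: 1 / 0)
    except HTTPError as err:
        print(err.status, err)

Needing the logging.DEBUG constant for this check is also why server.py keeps its import logging line, now tagged # allow-direct-logging.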


@ -45,7 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class LlamaStack( class LlamaStack(
@ -105,11 +105,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method) method = getattr(impls[api], register_method)
for obj in objects: for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}") log.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# Do not register models on disabled providers # Do not register models on disabled providers
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"): if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.") log.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue continue
# we want to maintain the type information in arguments to method. # we want to maintain the type information in arguments to method.
@ -123,7 +123,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process: for obj in objects_to_process:
logger.debug( log.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}", f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
) )
@ -160,7 +160,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
try: try:
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id") resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
if resolved_provider_id == "__disabled__": if resolved_provider_id == "__disabled__":
logger.debug( log.debug(
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}" f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
) )
# Create a copy with resolved provider_id but original config # Create a copy with resolved provider_id but original config
@ -315,7 +315,7 @@ async def construct_stack(
TEST_RECORDING_CONTEXT = setup_inference_recording() TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT: if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__() TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") log.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
policy = run_config.server.auth.access_policy if run_config.server.auth else [] policy = run_config.server.auth.access_policy if run_config.server.auth else []
@ -337,12 +337,12 @@ async def construct_stack(
import traceback import traceback
if task.cancelled(): if task.cancelled():
logger.error("Model refresh task cancelled") log.error("Model refresh task cancelled")
elif task.exception(): elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}") log.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception()) traceback.print_exception(task.exception())
else: else:
logger.debug("Model refresh task completed") log.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb) REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls return impls
@ -351,23 +351,23 @@ async def construct_stack(
async def shutdown_stack(impls: dict[Api, Any]): async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values(): for impl in impls.values():
impl_name = impl.__class__.__name__ impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}") log.info(f"Shutting down {impl_name}")
try: try:
if hasattr(impl, "shutdown"): if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5) await asyncio.wait_for(impl.shutdown(), timeout=5)
else: else:
logger.warning(f"No shutdown method for {impl_name}") log.warning(f"No shutdown method for {impl_name}")
except TimeoutError: except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}") log.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e: except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}") log.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT: if TEST_RECORDING_CONTEXT:
try: try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None) TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e: except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}") log.error(f"Error during inference recording cleanup: {e}")
global REGISTRY_REFRESH_TASK global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK: if REGISTRY_REFRESH_TASK:
@ -375,14 +375,14 @@ async def shutdown_stack(impls: dict[Api, Any]):
async def refresh_registry_once(impls: dict[Api, Any]): async def refresh_registry_once(impls: dict[Api, Any]):
logger.debug("refreshing registry") log.debug("refreshing registry")
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
for routing_table in routing_tables: for routing_table in routing_tables:
await routing_table.refresh() await routing_table.refresh()
async def refresh_registry_task(impls: dict[Api, Any]): async def refresh_registry_task(impls: dict[Api, Any]):
logger.info("starting registry refresh task") log.info("starting registry refresh task")
while True: while True:
await refresh_registry_once(impls) await refresh_registry_once(impls)
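
shutdown_stack walks every implementation, calls shutdown with a five-second cap, and logs (rather than propagates) timeouts and failures. A runnable sketch of that loop with fake implementations and a short timeout so the timeout branch actually fires:

import asyncio
import logging  # allow-direct-logging: standalone sketch

log = logging.getLogger("shutdown_sketch")


class FastImpl:
    async def shutdown(self) -> None:
        await asyncio.sleep(0)


class SlowImpl:
    async def shutdown(self) -> None:
        await asyncio.sleep(10)  # deliberately exceeds the timeout below


async def shutdown_all(impls: list[object], timeout: float = 5.0) -> None:
    for impl in impls:
        name = impl.__class__.__name__
        log.info("Shutting down %s", name)
        try:
            if hasattr(impl, "shutdown"):
                # On Python 3.11+ asyncio.wait_for raises the builtin TimeoutError.
                await asyncio.wait_for(impl.shutdown(), timeout=timeout)
            else:
                log.warning("No shutdown method for %s", name)
        except TimeoutError:
            log.exception("Shutdown timeout for %s", name)
        except (Exception, asyncio.CancelledError) as e:
            log.exception("Failed to shutdown %s: %s", name, e)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(shutdown_all([FastImpl(), SlowImpl()], timeout=0.1))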


@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution") log = get_logger(name=__name__, category="config_resolution")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
@ -42,25 +42,25 @@ def resolve_config_or_distro(
# Strategy 1: Try as file path first # Strategy 1: Try as file path first
config_path = Path(config_or_distro) config_path = Path(config_or_distro)
if config_path.exists() and config_path.is_file(): if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}") log.info(f"Using file path: {config_path}")
return config_path.resolve() return config_path.resolve()
# Strategy 2: Try as distribution name (if no .yaml extension) # Strategy 2: Try as distribution name (if no .yaml extension)
if not config_or_distro.endswith(".yaml"): if not config_or_distro.endswith(".yaml"):
distro_config = _get_distro_config_path(config_or_distro, mode) distro_config = _get_distro_config_path(config_or_distro, mode)
if distro_config.exists(): if distro_config.exists():
logger.info(f"Using distribution: {distro_config}") log.info(f"Using distribution: {distro_config}")
return distro_config return distro_config
# Strategy 3: Try as built distribution name # Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml" distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists(): if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}") log.info(f"Using built distribution: {distrib_config}")
return distrib_config return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml" distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists(): if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}") log.info(f"Using built distribution: {distrib_config}")
return distrib_config return distrib_config
# Strategy 4: Failed - provide helpful error # Strategy 4: Failed - provide helpful error
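
resolve_config_or_distro tries a file path, then a distribution name, then previously built distributions, before failing. A standalone sketch of that strategy chain; the built-distribution paths in Strategy 3 mirror the hunk, while the in-tree template path in Strategy 2 and the directory constants are assumptions for illustration.

from pathlib import Path

# Hypothetical locations; the real resolver derives these from the package layout.
DISTRO_DIR = Path("llama_stack/distributions")
DISTRIBS_BASE_DIR = Path.home() / ".llama" / "distributions"


def resolve_config_or_distro(config_or_distro: str, mode: str = "run") -> Path:
    # Strategy 1: an explicit file path wins.
    config_path = Path(config_or_distro)
    if config_path.is_file():
        return config_path.resolve()

    # Strategy 2: a bare distribution name maps to an in-tree template (path shape assumed).
    if not config_or_distro.endswith(".yaml"):
        distro_config = DISTRO_DIR / config_or_distro / f"{mode}.yaml"
        if distro_config.exists():
            return distro_config

    # Strategy 3: fall back to previously built distributions on disk.
    for candidate in (
        DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml",
        DISTRIBS_BASE_DIR / config_or_distro / f"{config_or_distro}-{mode}.yaml",
    ):
        if candidate.exists():
            return candidate

    # Strategy 4: nothing matched; surface a helpful error.
    raise FileNotFoundError(f"Could not resolve '{config_or_distro}' as a file, distribution, or build")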


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging import importlib
import os import os
import signal import signal
import subprocess import subprocess
@ -12,9 +12,9 @@ import sys
from termcolor import cprint from termcolor import cprint
log = logging.getLogger(__name__) from llama_stack.log import get_logger
import importlib log = get_logger(name=__name__, category="core")
def formulate_run_args(image_type: str, image_name: str) -> list: def formulate_run_args(image_type: str, image_name: str) -> list:


@ -6,7 +6,6 @@
import inspect import inspect
import json import json
import logging
from enum import Enum from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin from typing import Annotated, Any, Literal, Union, get_args, get_origin
@ -14,7 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType from pydantic_core import PydanticUndefinedType
log = logging.getLogger(__name__) from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
def is_list_of_primitives(field_type): def is_list_of_primitives(field_type):


@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging # allow-direct-logging
import os import os
import re import re
import sys import sys
from logging.config import dictConfig from logging.config import dictConfig # allow-direct-logging
from rich.console import Console from rich.console import Console
from rich.errors import MarkupError from rich.errors import MarkupError
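
log.py is the one module expected to touch the stdlib directly, hence the # allow-direct-logging markers on its imports above. As background for readers unfamiliar with the pattern, here is a deliberately tiny sketch of what a category-aware get_logger wrapper can look like; it is illustrative only, not this module's actual implementation.

import logging  # allow-direct-logging: this is the one place stdlib access is expected
import os

_DEFAULT_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()


def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
    """Return a logger namespaced by category, so per-category levels can be applied."""
    logger = logging.getLogger(f"sketch.{category}.{name}")
    logger.setLevel(_DEFAULT_LEVEL)
    return logging.LoggerAdapter(logger, extra={"category": category})


if __name__ == "__main__":
    logging.basicConfig(format="%(levelname)s %(name)s: %(message)s")
    log = get_logger(name=__name__, category="core")
    log.info("hello from the sketch")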


@ -13,7 +13,7 @@
# Copyright (c) Meta Platforms, Inc. and its affiliates. # Copyright (c) Meta Platforms, Inc. and its affiliates.
import math import math
from logging import getLogger from logging import getLogger # allow-direct-logging
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F


@ -13,7 +13,7 @@
import math import math
from collections import defaultdict from collections import defaultdict
from logging import getLogger from logging import getLogger # allow-direct-logging
from typing import Any from typing import Any
import torch import torch


@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import math import math
from collections.abc import Callable from collections.abc import Callable
from functools import partial from functools import partial
@ -22,6 +21,8 @@ from PIL import Image as PIL_Image
from torch import Tensor, nn from torch import Tensor, nn
from torch.distributed import _functional_collectives as funcol from torch.distributed import _functional_collectives as funcol
from llama_stack.log import get_logger
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
from .encoder_utils import ( from .encoder_utils import (
build_encoder_attention_mask, build_encoder_attention_mask,
@ -34,9 +35,10 @@ from .encoder_utils import (
from .image_transform import VariableSizeImageTransform from .image_transform import VariableSizeImageTransform
from .utils import get_negative_inf_value, to_2tuple from .utils import get_negative_inf_value, to_2tuple
logger = logging.getLogger(__name__)
MP_SCALE = 8 MP_SCALE = 8
log = get_logger(name=__name__, category="core")
def reduce_from_tensor_model_parallel_region(input_): def reduce_from_tensor_model_parallel_region(input_):
"""All-reduce the input tensor across model parallel group.""" """All-reduce the input tensor across model parallel group."""
@ -415,7 +417,7 @@ class VisionEncoder(nn.Module):
) )
state_dict[prefix + "gated_positional_embedding"] = global_pos_embed state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype) state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype)
logger.info(f"Initialized global positional embedding with size {global_pos_embed.size()}") log.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
else: else:
global_pos_embed = resize_global_position_embedding( global_pos_embed = resize_global_position_embedding(
state_dict[prefix + "gated_positional_embedding"], state_dict[prefix + "gated_positional_embedding"],
@ -423,7 +425,7 @@ class VisionEncoder(nn.Module):
self.max_num_tiles, self.max_num_tiles,
self.max_num_tiles, self.max_num_tiles,
) )
logger.info( log.info(
f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}" f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}"
) )
state_dict[prefix + "gated_positional_embedding"] = global_pos_embed state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
@ -771,7 +773,7 @@ class TilePositionEmbedding(nn.Module):
if embed is not None: if embed is not None:
# reshape the weights to the correct shape # reshape the weights to the correct shape
nt_old, nt_old, _, w = embed.shape nt_old, nt_old, _, w = embed.shape
logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") log.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
# assign the weights to the module # assign the weights to the module
state_dict[prefix + "embedding"] = embed_new state_dict[prefix + "embedding"] = embed_new


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger from logging import getLogger # allow-direct-logging
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Literal, Literal,


@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)' BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})") CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@ -215,7 +215,7 @@ class ToolUtils:
# FIXME: Enable multiple tool calls # FIXME: Enable multiple tool calls
return function_calls[0] return function_calls[0]
else: else:
logger.debug(f"Did not parse tool call from message body: {message_body}") log.debug(f"Did not parse tool call from message body: {message_body}")
return None return None
@staticmethod @staticmethod
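
The custom tool-call format handled in this file, <function=NAME>{...json args...}, is easy to exercise on its own. A standalone parser built around the same CUSTOM_TOOL_CALL_PATTERN regex shown in the diff:

import json
import re

CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")


def parse_custom_tool_call(message_body: str) -> tuple[str, dict] | None:
    match = CUSTOM_TOOL_CALL_PATTERN.search(message_body)
    if match is None:
        return None
    try:
        args = json.loads(match.group("args"))
    except json.JSONDecodeError:
        return None
    return match.group("function_name"), args


if __name__ == "__main__":
    body = 'Sure, calling it now: <function=get_weather>{"city": "Lisbon"}'
    print(parse_custom_tool_call(body))  # ('get_weather', {'city': 'Lisbon'})

Note that the lazy {.*?} stops at the first closing brace, so nested argument objects need more careful handling than this sketch provides.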


@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import os import os
from collections.abc import Callable from collections.abc import Callable
@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn import functional as F from torch.nn import functional as F
from llama_stack.log import get_logger
from ...datatypes import QuantizationMode from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock from ..model import Transformer, TransformerBlock
from ..moe import MoE from ..moe import MoE
log = logging.getLogger(__name__) logger = get_logger(__name__, category="core")
def swiglu_wrapper_no_reduce( def swiglu_wrapper_no_reduce(
@ -186,7 +187,7 @@ def logging_callbacks(
if use_rich_progress: if use_rich_progress:
console.print(message) console.print(message)
elif rank == 0: # Only log from rank 0 for non-rich logging elif rank == 0: # Only log from rank 0 for non-rich logging
log.info(message) logger.info(message)
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block)) total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
progress = None progress = None
@ -220,6 +221,6 @@ def logging_callbacks(
if completed is not None: if completed is not None:
progress.update(task_id, completed=completed) progress.update(task_id, completed=completed)
elif rank == 0 and completed and completed % 10 == 0: elif rank == 0 and completed and completed % 10 == 0:
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed") logger.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
return progress, log_status, update_status return progress, log_status, update_status


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger from logging import getLogger # allow-direct-logging
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Literal, Literal,


@ -6,16 +6,17 @@
# type: ignore # type: ignore
import collections import collections
import logging
log = logging.getLogger(__name__) from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
try: try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401 import fbgemm_gpu.experimental.gen_ai # noqa: F401
log.info("Using efficient FP8 or INT4 operators in FBGEMM.") logger.info("Using efficient FP8 or INT4 operators in FBGEMM.")
except ImportError: except ImportError:
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.") logger.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
raise raise
import torch import torch
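
The fbgemm hunk renames the logger around an optional-dependency probe: try the import, log success, or log and re-raise on ImportError. The same pattern as a runnable standalone helper; the module names in the demo are stand-ins, not real dependencies of this file.

import importlib
import logging  # allow-direct-logging: standalone sketch

logger = logging.getLogger("optional_dep_sketch")


def require_optional(module_name: str) -> None:
    """Import a hard optional dependency, logging before re-raising on failure."""
    try:
        importlib.import_module(module_name)
        logger.info("Using efficient kernels from %s.", module_name)
    except ImportError:
        logger.error("%s is not installed; efficient kernels are unavailable.", module_name)
        raise


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    require_optional("json")  # stdlib module: succeeds
    try:
        require_optional("definitely_not_installed_pkg")  # fails, logs, re-raises
    except ImportError:
        pass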


@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search" WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag" RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents") log = get_logger(name=__name__, category="agents")
class ChatAgent(ShieldRunnerMixin): class ChatAgent(ShieldRunnerMixin):
@ -612,7 +612,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
if n_iter >= self.agent_config.max_infer_iters: if n_iter >= self.agent_config.max_infer_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.") log.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn # NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point # Do not continue the tool call loop after this point
message.stop_reason = StopReason.end_of_turn message.stop_reason = StopReason.end_of_turn
@ -620,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin):
break break
if stop_reason == StopReason.out_of_tokens: if stop_reason == StopReason.out_of_tokens:
logger.info("out of token budget, exiting.") log.info("out of token budget, exiting.")
yield message yield message
break break
@ -634,7 +634,7 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments message.content = [message.content] + output_attachments
yield message yield message
else: else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") log.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
input_messages = input_messages + [message] input_messages = input_messages + [message]
@ -889,7 +889,7 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
tool_name_str = tool_name tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") log.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
result = await self.tool_runtime_api.invoke_tool( result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str, tool_name=tool_name_str,
kwargs={ kwargs={
@ -899,7 +899,7 @@ class ChatAgent(ShieldRunnerMixin):
**self.tool_name_to_args.get(tool_name_str, {}), **self.tool_name_to_args.get(tool_name_str, {}),
}, },
) )
logger.debug(f"tool call {tool_name_str} completed with result: {result}") log.debug(f"tool call {tool_name_str} completed with result: {result}")
return result return result
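
The agent hunks above pass category="agents", while other files in this commit use "inference", "safety", "tools", and so on. A small sketch of how the category argument separates subsystems, assuming categories only affect how records are grouped and filtered:

from llama_stack.log import get_logger

agent_log = get_logger(name=__name__, category="agents")
safety_log = get_logger(name=__name__, category="safety")

agent_log.info("executing tool call: web_search")  # grouped under the agents category
safety_log.warning("shield violation detected")    # grouped under the safety category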

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import uuid import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import UTC, datetime from datetime import UTC, datetime
@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo from .persistence import AgentInfo
logger = logging.getLogger() log = get_logger(name=__name__, category="agents")
class MetaReferenceAgentsImpl(Agents): class MetaReferenceAgentsImpl(Agents):
@ -268,7 +268,7 @@ class MetaReferenceAgentsImpl(Agents):
# Get the agent info using the key # Get the agent info using the key
agent_info_json = await self.persistence_store.get(agent_key) agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json: if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}") log.error(f"Could not find agent info for key {agent_key}")
continue continue
try: try:
@ -281,7 +281,7 @@ class MetaReferenceAgentsImpl(Agents):
) )
) )
except Exception as e: except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}") log.error(f"Error parsing agent info for {agent_id}: {e}")
continue continue
# Convert Agent objects to dictionaries # Convert Agent objects to dictionaries

View file

@ -75,7 +75,7 @@ from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefiniti
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore from llama_stack.providers.utils.responses.responses_store import ResponsesStore
logger = get_logger(name=__name__, category="openai_responses") log = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:" OPENAI_RESPONSES_PREFIX = "openai_responses:"
@ -544,12 +544,12 @@ class OpenAIResponsesImpl:
break break
if function_tool_calls: if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call") log.info("Exiting inference loop since there is a function (client-side) tool call")
break break
n_iter += 1 n_iter += 1
if n_iter >= max_infer_iters: if n_iter >= max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}") log.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
break break
messages = next_turn_messages messages = next_turn_messages
@ -698,7 +698,7 @@ class OpenAIResponsesImpl:
) )
return search_response.data return search_response.data
except Exception as e: except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}") log.warning(f"Failed to search vector store {vector_store_id}: {e}")
return [] return []
# Run all searches in parallel using gather # Run all searches in parallel using gather

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.datatypes import User from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="agents")
class AgentSessionInfo(Session): class AgentSessionInfo(Session):

View file

@ -5,13 +5,13 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import logging
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="agents")
class SafetyException(Exception): # noqa: N818 class SafetyException(Exception): # noqa: N818

View file

@ -73,11 +73,12 @@ from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
log = get_logger(__name__, category="inference")
# there's a single model parallel process running, serving the model. for now, # there's a single model parallel process running, serving the model. for now,
# we don't support multiple concurrent requests to this process. # we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
logger = get_logger(__name__, category="inference")
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model) return LlamaGenerator(config, model_id, llama_model)
@ -144,7 +145,7 @@ class MetaReferenceInferenceImpl(
return model return model
async def load_model(self, model_id, llama_model) -> None: async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`") logger.info(f"Loading model `{model_id}`")
builder_params = [self.config, model_id, llama_model] builder_params = [self.config, model_id, llama_model]
@ -166,7 +167,7 @@ class MetaReferenceInferenceImpl(
self.model_id = model_id self.model_id = model_id
self.llama_model = llama_model self.llama_model = llama_model
log.info("Warming up...") logger.info("Warming up...")
await self.completion( await self.completion(
model_id=model_id, model_id=model_id,
content="Hello, world!", content="Hello, world!",
@ -177,7 +178,7 @@ class MetaReferenceInferenceImpl(
messages=[UserMessage(content="Hi how are you?")], messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20), sampling_params=SamplingParams(max_tokens=20),
) )
log.info("Warmed up!") logger.info("Warmed up!")
def check_model(self, request) -> None: def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None: if self.model_id is None or self.llama_model is None:

View file

@ -12,7 +12,6 @@
import copy import copy
import json import json
import logging
import multiprocessing import multiprocessing
import os import os
import tempfile import tempfile
@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent, ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent, CompletionRequestWithRawContent,
) )
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="inference")
class ProcessingMessageName(str, Enum): class ProcessingMessageName(str, Enum):
@ -236,7 +236,7 @@ def worker_process_entrypoint(
except StopIteration: except StopIteration:
break break
log.info("[debug] worker process done") log.info("[debug] worker process done")
def launch_dist_group( def launch_dist_group(

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -32,8 +31,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
from .config import SentenceTransformersInferenceConfig from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl( class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin, OpenAIChatCompletionToLlamaStackMixin,

View file

@ -6,7 +6,6 @@
import gc import gc
import json import json
import logging
import multiprocessing import multiprocessing
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig, LoraFinetuningConfig,
TrainingConfig, TrainingConfig,
) )
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig from ..config import HuggingFacePostTrainingConfig
@ -44,7 +44,7 @@ from ..utils import (
split_dataset, split_dataset,
) )
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="core")
class HFFinetuningSingleDevice: class HFFinetuningSingleDevice:
@ -69,14 +69,14 @@ class HFFinetuningSingleDevice:
try: try:
messages = json.loads(row["chat_completion_input"]) messages = json.loads(row["chat_completion_input"])
if not isinstance(messages, list) or len(messages) != 1: if not isinstance(messages, list) or len(messages) != 1:
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}") log.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
return None, None return None, None
if "content" not in messages[0]: if "content" not in messages[0]:
logger.warning(f"Message missing content: {messages[0]}") log.warning(f"Message missing content: {messages[0]}")
return None, None return None, None
return messages[0]["content"], row["expected_answer"] return messages[0]["content"], row["expected_answer"]
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}") log.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
return None, None return None, None
return None, None return None, None
@ -86,13 +86,13 @@ class HFFinetuningSingleDevice:
try: try:
dialog = json.loads(row["dialog"]) dialog = json.loads(row["dialog"])
if not isinstance(dialog, list) or len(dialog) < 2: if not isinstance(dialog, list) or len(dialog) < 2:
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}") log.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
return None, None return None, None
if dialog[0].get("role") != "user": if dialog[0].get("role") != "user":
logger.warning(f"First message must be from user: {dialog[0]}") log.warning(f"First message must be from user: {dialog[0]}")
return None, None return None, None
if not any(msg.get("role") == "assistant" for msg in dialog): if not any(msg.get("role") == "assistant" for msg in dialog):
logger.warning("Dialog must have at least one assistant message") log.warning("Dialog must have at least one assistant message")
return None, None return None, None
# Convert to human/gpt format # Convert to human/gpt format
@ -100,14 +100,14 @@ class HFFinetuningSingleDevice:
conversations = [] conversations = []
for msg in dialog: for msg in dialog:
if "role" not in msg or "content" not in msg: if "role" not in msg or "content" not in msg:
logger.warning(f"Message missing role or content: {msg}") log.warning(f"Message missing role or content: {msg}")
continue continue
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]}) conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
# Format as a single conversation # Format as a single conversation
return conversations[0]["value"], conversations[1]["value"] return conversations[0]["value"], conversations[1]["value"]
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Failed to parse dialog: {row['dialog']}") log.warning(f"Failed to parse dialog: {row['dialog']}")
return None, None return None, None
return None, None return None, None
@ -198,7 +198,7 @@ class HFFinetuningSingleDevice:
""" """
import asyncio import asyncio
logger.info("Starting training process with async wrapper") log.info("Starting training process with async wrapper")
asyncio.run( asyncio.run(
self._run_training( self._run_training(
model=model, model=model,
@ -228,14 +228,14 @@ class HFFinetuningSingleDevice:
raise ValueError("DataConfig is required for training") raise ValueError("DataConfig is required for training")
# Load dataset # Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}") log.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id) rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
if not self.validate_dataset_format(rows): if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset") log.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer # Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}") log.info(f"Initializing tokenizer for model: {model}")
try: try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config) tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
@ -257,16 +257,16 @@ class HFFinetuningSingleDevice:
# This ensures consistent sequence lengths across the training process # This ensures consistent sequence lengths across the training process
tokenizer.model_max_length = provider_config.max_seq_length tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully") log.info("Tokenizer initialized successfully")
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset # Create and preprocess dataset
logger.info("Creating and preprocessing dataset") log.info("Creating and preprocessing dataset")
try: try:
ds = self._create_dataset(rows, config, provider_config) ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config) ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples") log.info(f"Dataset created with {len(ds)} examples")
except Exception as e: except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e raise ValueError(f"Failed to create dataset: {str(e)}") from e
@ -293,11 +293,11 @@ class HFFinetuningSingleDevice:
Returns: Returns:
Configured SFTConfig object Configured SFTConfig object
""" """
logger.info("Configuring training arguments") log.info("Configuring training arguments")
lr = 2e-5 lr = 2e-5
if config.optimizer_config: if config.optimizer_config:
lr = config.optimizer_config.lr lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}") log.info(f"Using custom learning rate: {lr}")
# Validate data config # Validate data config
if not config.data_config: if not config.data_config:
@ -350,17 +350,17 @@ class HFFinetuningSingleDevice:
peft_config: Optional LoRA configuration peft_config: Optional LoRA configuration
output_dir_path: Path to save the model output_dir_path: Path to save the model
""" """
logger.info("Saving final model") log.info("Saving final model")
model_obj.config.use_cache = True model_obj.config.use_cache = True
if peft_config: if peft_config:
logger.info("Merging LoRA weights with base model") log.info("Merging LoRA weights with base model")
model_obj = trainer.model.merge_and_unload() model_obj = trainer.model.merge_and_unload()
else: else:
model_obj = trainer.model model_obj = trainer.model
save_path = output_dir_path / "merged_model" save_path = output_dir_path / "merged_model"
logger.info(f"Saving model to {save_path}") log.info(f"Saving model to {save_path}")
model_obj.save_pretrained(save_path) model_obj.save_pretrained(save_path)
async def _run_training( async def _run_training(
@ -380,13 +380,13 @@ class HFFinetuningSingleDevice:
setup_signal_handlers() setup_signal_handlers()
# Convert config dicts back to objects # Convert config dicts back to objects
logger.info("Initializing configuration objects") log.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config) provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config) config_obj = TrainingConfig(**config)
# Initialize and validate device # Initialize and validate device
device = setup_torch_device(provider_config_obj.device) device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'") log.info(f"Using device '{device}'")
# Load dataset and tokenizer # Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj) train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
@ -409,7 +409,7 @@ class HFFinetuningSingleDevice:
model_obj = load_model(model, device, provider_config_obj) model_obj = load_model(model, device, provider_config_obj)
# Initialize trainer # Initialize trainer
logger.info("Initializing SFTTrainer") log.info("Initializing SFTTrainer")
trainer = SFTTrainer( trainer = SFTTrainer(
model=model_obj, model=model_obj,
train_dataset=train_dataset, train_dataset=train_dataset,
@ -420,9 +420,9 @@ class HFFinetuningSingleDevice:
try: try:
# Train # Train
logger.info("Starting training") log.info("Starting training")
trainer.train() trainer.train()
logger.info("Training completed successfully") log.info("Training completed successfully")
# Save final model if output directory is provided # Save final model if output directory is provided
if output_dir_path: if output_dir_path:
@ -430,12 +430,12 @@ class HFFinetuningSingleDevice:
finally: finally:
# Clean up resources # Clean up resources
logger.info("Cleaning up resources") log.info("Cleaning up resources")
if hasattr(trainer, "model"): if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type) evacuate_model_from_device(trainer.model, device.type)
del trainer del trainer
gc.collect() gc.collect()
logger.info("Cleanup completed") log.info("Cleanup completed")
async def train( async def train(
self, self,
@ -449,7 +449,7 @@ class HFFinetuningSingleDevice:
"""Train a model using HuggingFace's SFTTrainer""" """Train a model using HuggingFace's SFTTrainer"""
# Initialize and validate device # Initialize and validate device
device = setup_torch_device(provider_config.device) device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'") log.info(f"Using device '{device}'")
output_dir_path = None output_dir_path = None
if output_dir: if output_dir:
@ -479,7 +479,7 @@ class HFFinetuningSingleDevice:
raise ValueError("DataConfig is required for training") raise ValueError("DataConfig is required for training")
# Train in a separate process # Train in a separate process
logger.info("Starting training in separate process") log.info("Starting training in separate process")
try: try:
# Setup multiprocessing for device # Setup multiprocessing for device
if device.type in ["cuda", "mps"]: if device.type in ["cuda", "mps"]:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import gc import gc
import logging
import multiprocessing import multiprocessing
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
DPOAlignmentConfig, DPOAlignmentConfig,
TrainingConfig, TrainingConfig,
) )
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig from ..config import HuggingFacePostTrainingConfig
@ -40,7 +40,7 @@ from ..utils import (
split_dataset, split_dataset,
) )
logger = logging.getLogger(__name__) logger = get_logger(__name__, category="core")
class HFDPOAlignmentSingleDevice: class HFDPOAlignmentSingleDevice:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import os import os
import signal import signal
import sys import sys
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from llama_stack.log import get_logger
from .config import HuggingFacePostTrainingConfig from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__) logger = get_logger(__name__, category="core")
def setup_environment(): def setup_environment():

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import os import os
import time import time
from datetime import UTC, datetime from datetime import UTC, datetime
@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training from torchtune import modules, training
from torchtune import utils as torchtune_utils from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft from torchtune.data import padded_collate_sft
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import ( from torchtune.modules.peft import (
get_adapter_params, get_adapter_params,
@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
) )
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.core.utils.model_utils import model_local_dir from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common import utils
@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
) )
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="core")
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
class LoraFinetuningSingleDevice: class LoraFinetuningSingleDevice:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import Any from typing import Any
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
@ -15,13 +14,14 @@ from llama_stack.apis.safety import (
ViolationLevel, ViolationLevel,
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
) )
from .config import CodeScannerConfig from .config import CodeScannerConfig
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="safety")
ALLOWED_CODE_SCANNER_MODEL_IDS = [ ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner", "CodeScanner",

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import Any from typing import Any
import torch import torch
@ -19,6 +18,7 @@ from llama_stack.apis.safety import (
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
@ -26,10 +26,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import PromptGuardConfig, PromptGuardType from .config import PromptGuardConfig, PromptGuardType
log = logging.getLogger(__name__)
PROMPT_GUARD_MODEL = "Prompt-Guard-86M" PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
log = get_logger(name=__name__, category="safety")
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: PromptGuardConfig, _deps) -> None: def __init__(self, config: PromptGuardConfig, _deps) -> None:

View file

@ -7,7 +7,6 @@
import collections import collections
import functools import functools
import json import json
import logging
import random import random
import re import re
import string import string
@ -20,7 +19,9 @@ import nltk
from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
from pythainlp.tokenize import word_tokenize as word_tokenize_thai from pythainlp.tokenize import word_tokenize as word_tokenize_thai
logger = logging.getLogger() from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
WORD_LIST = [ WORD_LIST = [
"western", "western",
@ -1726,7 +1727,7 @@ def get_langid(text: str, lid_path: str | None = None) -> str:
try: try:
line_langs.append(langdetect.detect(line)) line_langs.append(langdetect.detect(line))
except langdetect.LangDetectException as e: except langdetect.LangDetectException as e:
logger.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037 log.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037
if len(line_langs) == 0: if len(line_langs) == 0:
return "en" return "en"
@ -1885,7 +1886,7 @@ class ResponseLanguageChecker(Instruction):
return langdetect.detect(value) == self._language return langdetect.detect(value) == self._language
except langdetect.LangDetectException as e: except langdetect.LangDetectException as e:
# Count as instruction is followed. # Count as instruction is followed.
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
return True return True
@ -3110,7 +3111,7 @@ class CapitalLettersEnglishChecker(Instruction):
return value.isupper() and langdetect.detect(value) == "en" return value.isupper() and langdetect.detect(value) == "en"
except langdetect.LangDetectException as e: except langdetect.LangDetectException as e:
# Count as instruction is followed. # Count as instruction is followed.
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
return True return True
@ -3139,7 +3140,7 @@ class LowercaseLettersEnglishChecker(Instruction):
return value.islower() and langdetect.detect(value) == "en" return value.islower() and langdetect.detect(value) == "en"
except langdetect.LangDetectException as e: except langdetect.LangDetectException as e:
# Count as instruction is followed. # Count as instruction is followed.
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
return True return True
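
Most call sites converted in this commit use f-strings, but the language-detection checks above keep printf-style arguments, which lets the logging machinery defer formatting until a record is actually emitted. A sketch of the difference, assuming get_logger returns a standard logging.Logger:

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")
text = "sample text"

log.info(f"Unable to detect language for text {text}")   # formatted eagerly, even if INFO is filtered out
log.info("Unable to detect language for text %s", text)  # formatted lazily by the logger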

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import logging
import secrets import secrets
import string import string
from typing import Any from typing import Any
@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import RagToolRuntimeConfig from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query from .context_retriever import generate_rag_query
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="tools")
def make_random_string(length: int = 8): def make_random_string(length: int = 8):

View file

@ -8,7 +8,6 @@ import asyncio
import base64 import base64
import io import io
import json import json
import logging
from typing import Any from typing import Any
import faiss import faiss
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse, QueryChunksResponse,
VectorIO, VectorIO,
) )
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
HealthResponse, HealthResponse,
HealthStatus, HealthStatus,
@ -39,7 +39,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import FaissVectorIOConfig from .config import FaissVectorIOConfig
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="core")
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
@ -83,7 +83,7 @@ class FaissIndex(EmbeddingIndex):
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False)) self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()] self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
except Exception as e: except Exception as e:
logger.debug(e, exc_info=True) log.debug(e, exc_info=True)
raise ValueError( raise ValueError(
"Error deserializing Faiss index from storage. If you recently upgraded your Llama Stack, Faiss, " "Error deserializing Faiss index from storage. If you recently upgraded your Llama Stack, Faiss, "
"or NumPy versions, you may need to delete the index and re-create it again or downgrade versions.\n" "or NumPy versions, you may need to delete the index and re-create it again or downgrade versions.\n"
@ -262,7 +262,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
assert self.kvstore is not None assert self.kvstore is not None
if vector_db_id not in self.cache: if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found") log.warning(f"Vector DB {vector_db_id} not found")
return return
await self.cache[vector_db_id].index.delete() await self.cache[vector_db_id].index.delete()
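
The FAISS hunk above records the original exception with a full traceback at debug level, then raises a concise, user-facing error. A sketch of that log-then-raise pattern with a placeholder failure:

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")

def load_index(buffer: bytes):
    try:
        raise RuntimeError("simulated deserialization failure")  # placeholder for the real deserialize call
    except Exception as e:
        log.debug(e, exc_info=True)  # keep the full stack trace at debug verbosity
        raise ValueError("Error deserializing index from storage; delete and re-create it.") from e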

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import logging
import re import re
import sqlite3 import sqlite3
import struct import struct
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse, QueryChunksResponse,
VectorIO, VectorIO,
) )
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
@ -35,7 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
VectorDBWithIndex, VectorDBWithIndex,
) )
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="core")
# Specifying search mode is dependent on the VectorIO provider. # Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector" VECTOR_SEARCH = "vector"
@ -257,7 +257,7 @@ class SQLiteVecIndex(EmbeddingIndex):
except sqlite3.Error as e: except sqlite3.Error as e:
connection.rollback() connection.rollback()
logger.error(f"Error inserting into {self.vector_table}: {e}") log.error(f"Error inserting into {self.vector_table}: {e}")
raise raise
finally: finally:
@ -306,7 +306,7 @@ class SQLiteVecIndex(EmbeddingIndex):
try: try:
chunk = Chunk.model_validate_json(chunk_json) chunk = Chunk.model_validate_json(chunk_json)
except Exception as e: except Exception as e:
logger.error(f"Error parsing chunk JSON for id {_id}: {e}") log.error(f"Error parsing chunk JSON for id {_id}: {e}")
continue continue
chunks.append(chunk) chunks.append(chunk)
scores.append(score) scores.append(score)
@ -352,7 +352,7 @@ class SQLiteVecIndex(EmbeddingIndex):
try: try:
chunk = Chunk.model_validate_json(chunk_json) chunk = Chunk.model_validate_json(chunk_json)
except Exception as e: except Exception as e:
logger.error(f"Error parsing chunk JSON for id {_id}: {e}") log.error(f"Error parsing chunk JSON for id {_id}: {e}")
continue continue
chunks.append(chunk) chunks.append(chunk)
scores.append(score) scores.append(score)
@ -447,7 +447,7 @@ class SQLiteVecIndex(EmbeddingIndex):
connection.commit() connection.commit()
except Exception as e: except Exception as e:
connection.rollback() connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}") log.error(f"Error deleting chunk {chunk_id}: {e}")
raise raise
finally: finally:
cur.close() cur.close()
@ -530,7 +530,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache: if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found") log.warning(f"Vector DB {vector_db_id} not found")
return return
await self.cache[vector_db_id].index.delete() await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_db_id]

View file

@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig from .config import FireworksImplConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
@ -256,7 +256,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
"stream": bool(request.stream), "stream": bool(request.stream),
**self._build_options(request.sampling_params, request.response_format, request.logprobs), **self._build_options(request.sampling_params, request.response_format, request.logprobs),
} }
logger.debug(f"params to fireworks: {params}") log.debug(f"params to fireworks: {params}")
return params return params

View file

@ -3,7 +3,6 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@ -11,8 +10,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
""" """

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import warnings import warnings
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
) )
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
@ -54,7 +54,7 @@ from .openai_utils import (
) )
from .utils import _is_nvidia_hosted from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="inference")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@ -75,7 +75,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
# TODO(mf): filter by available models # TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...") log.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
if _is_nvidia_hosted(config): if _is_nvidia_hosted(config):
if not config.api_key: if not config.api_key:

View file

@ -4,13 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import httpx import httpx
from llama_stack.log import get_logger
from . import NVIDIAConfig from . import NVIDIAConfig
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="inference")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
@ -44,7 +45,7 @@ async def check_health(config: NVIDIAConfig) -> None:
RuntimeError: If the server is not running or ready RuntimeError: If the server is not running or ready
""" """
if not _is_nvidia_hosted(config): if not _is_nvidia_hosted(config):
logger.info("Checking NVIDIA NIM health...") log.info("Checking NVIDIA NIM health...")
try: try:
is_live, is_ready = await _get_health(config.url) is_live, is_ready = await _get_health(config.url)
if not is_live: if not is_live:

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter( class OllamaInferenceAdapter(
@ -117,10 +117,10 @@ class OllamaInferenceAdapter(
return self._openai_client return self._openai_client
async def initialize(self) -> None: async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...") log.info(f"checking connectivity to Ollama at `{self.config.url}`...")
health_response = await self.health() health_response = await self.health()
if health_response["status"] == HealthStatus.ERROR: if health_response["status"] == HealthStatus.ERROR:
logger.warning( log.warning(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal" "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
) )
@ -339,7 +339,7 @@ class OllamaInferenceAdapter(
"options": sampling_options, "options": sampling_options,
"stream": request.stream, "stream": request.stream,
} }
logger.debug(f"params to ollama: {params}") log.debug(f"params to ollama: {params}")
return params return params
@ -437,7 +437,7 @@ class OllamaInferenceAdapter(
if provider_resource_id not in available_models: if provider_resource_id not in available_models:
available_models_latest = [m.model.split(":latest")[0] for m in response.models] available_models_latest = [m.model.split(":latest")[0] for m in response.models]
if provider_resource_id in available_models_latest: if provider_resource_id in available_models_latest:
logger.warning( log.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'" f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
) )
return model return model

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -12,8 +11,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig from .config import OpenAIConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
# #
# This OpenAI adapter implements Inference methods using two mixins - # This OpenAI adapter implements Inference methods using two mixins -

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi from huggingface_hub import AsyncInferenceClient, HfApi
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = logging.getLogger(__name__) logger = get_logger(__name__, category="core")
def build_hf_repo_model_entries(): def build_hf_repo_model_entries():
@ -307,7 +307,7 @@ class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None: async def initialize(self, config: TGIImplConfig) -> None:
if not config.url: if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.") raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}") logger.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient( self.client = AsyncInferenceClient(
model=config.url, model=config.url,
) )

View file

@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig from .config import TogetherImplConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
@ -232,7 +232,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
"stream": request.stream, "stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format), **self._build_options(request.sampling_params, request.logprobs, request.response_format),
} }
logger.debug(f"params to together: {params}") log.debug(f"params to together: {params}")
return params return params
async def embeddings( async def embeddings(

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import warnings import warnings
from typing import Any from typing import Any
@ -15,8 +14,6 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
from .config import NvidiaPostTrainingConfig from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None: def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys() keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
from typing import Any from typing import Any
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
ViolationLevel, ViolationLevel,
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.bedrock.client import create_bedrock_client from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="safety")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
@ -76,13 +76,13 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
""" """
shield_params = shield.params shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}") log.debug(f"run_shield::{shield_params}::messages={messages}")
# - convert the messages into format Bedrock expects # - convert the messages into format Bedrock expects
content_messages = [] content_messages = []
for message in messages: for message in messages:
content_messages.append({"text": {"text": message.content}}) content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = self.bedrock_runtime_client.apply_guardrail( response = self.bedrock_runtime_client.apply_guardrail(
guardrailIdentifier=shield.provider_resource_id, guardrailIdentifier=shield.provider_resource_id,

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import Any from typing import Any
import requests import requests
@ -17,8 +16,6 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import NVIDIASafetyConfig from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None: def __init__(self, config: NVIDIASafetyConfig) -> None:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
from typing import Any from typing import Any
import litellm import litellm
@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig from .config import SambaNovaSafetyConfig
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="safety")
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
@ -66,7 +66,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
"guard" not in shield.provider_resource_id.lower() "guard" not in shield.provider_resource_id.lower()
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
): ):
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") log.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
async def unregister_shield(self, identifier: str) -> None: async def unregister_shield(self, identifier: str) -> None:
pass pass
@ -79,9 +79,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
raise ValueError(f"Shield {shield_id} not found") raise ValueError(f"Shield {shield_id} not found")
shield_params = shield.params shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}") log.debug(f"run_shield::{shield_params}::messages={messages}")
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages] content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = litellm.completion( response = litellm.completion(
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key() model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
View file
@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import json import json
import logging
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse, QueryChunksResponse,
VectorIO, VectorIO,
) )
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
@ -32,8 +32,6 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = logging.getLogger(__name__)
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
VERSION = "v3" VERSION = "v3"
@ -43,6 +41,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::" OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::"
logger = get_logger(__name__, category="core")
# this is a helper to allow us to use async and non-async chroma clients interchangeably # this is a helper to allow us to use async and non-async chroma clients interchangeably
async def maybe_await(result): async def maybe_await(result):
@ -92,7 +92,7 @@ class ChromaIndex(EmbeddingIndex):
doc = json.loads(doc) doc = json.loads(doc)
chunk = Chunk(**doc) chunk = Chunk(**doc)
except Exception: except Exception:
log.exception(f"Failed to parse document: {doc}") logger.exception(f"Failed to parse document: {doc}")
continue continue
score = 1.0 / float(dist) if dist != 0 else float("inf") score = 1.0 / float(dist) if dist != 0 else float("inf")
@ -137,7 +137,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
inference_api: Api.inference, inference_api: Api.inference,
files_api: Files | None, files_api: Files | None,
) -> None: ) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}") logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config self.config = config
self.inference_api = inference_api self.inference_api = inference_api
self.client = None self.client = None
@ -150,7 +150,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.vector_db_store = self.kvstore self.vector_db_store = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig): if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}") logger.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/") url = self.config.url.rstrip("/")
parsed = urlparse(url) parsed = urlparse(url)
@ -159,7 +159,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port) self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
else: else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}") logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path) self.client = chromadb.PersistentClient(path=self.config.db_path)
self.openai_vector_stores = await self._load_openai_vector_stores() self.openai_vector_stores = await self._load_openai_vector_stores()
@ -182,7 +182,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache: if vector_db_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found") logger.warning(f"Vector DB {vector_db_id} not found")
return return
await self.cache[vector_db_id].index.delete() await self.cache[vector_db_id].index.delete()
View file
@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import logging
import os import os
from typing import Any from typing import Any
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse, QueryChunksResponse,
VectorIO, VectorIO,
) )
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="core")
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
@ -68,7 +68,7 @@ class MilvusIndex(EmbeddingIndex):
) )
if not await asyncio.to_thread(self.client.has_collection, self.collection_name): if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") log.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search # Create schema for vector search
schema = self.client.create_schema() schema = self.client.create_schema()
schema.add_field( schema.add_field(
@ -147,7 +147,7 @@ class MilvusIndex(EmbeddingIndex):
data=data, data=data,
) )
except Exception as e: except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") log.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
@ -203,7 +203,7 @@ class MilvusIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
except Exception as e: except Exception as e:
logger.error(f"Error performing BM25 search: {e}") log.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search # Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold) return await self._fallback_keyword_search(query_string, k, score_threshold)
@ -247,7 +247,7 @@ class MilvusIndex(EmbeddingIndex):
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"' self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
) )
except Exception as e: except Exception as e:
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}") log.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
raise raise
@ -288,10 +288,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) )
self.cache[vector_db.identifier] = index self.cache[vector_db.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig): if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}") log.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else: else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") log.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path) uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri) self.client = MilvusClient(uri=uri)
View file
@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from typing import Any from typing import Any
import psycopg2 import psycopg2
@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse, QueryChunksResponse,
VectorIO, VectorIO,
) )
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
@ -33,8 +33,6 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorVectorIOConfig from .config import PGVectorVectorIOConfig
log = logging.getLogger(__name__)
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::" OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
logger = get_logger(__name__, category="core")
def check_extension_version(cur): def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'") cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@ -187,7 +187,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.metadatadata_collection_name = "openai_vector_stores_metadata" self.metadatadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None: async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}") logger.info(f"Initializing PGVector memory adapter with config: {self.config}")
self.kvstore = await kvstore_impl(self.config.kvstore) self.kvstore = await kvstore_impl(self.config.kvstore)
await self.initialize_openai_vector_stores() await self.initialize_openai_vector_stores()
@ -203,7 +203,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
version = check_extension_version(cur) version = check_extension_version(cur)
if version: if version:
log.info(f"Vector extension version: {version}") logger.info(f"Vector extension version: {version}")
else: else:
raise RuntimeError("Vector extension is not installed.") raise RuntimeError("Vector extension is not installed.")
@ -216,13 +216,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
""" """
) )
except Exception as e: except Exception as e:
log.exception("Could not connect to PGVector database server") logger.exception("Could not connect to PGVector database server")
raise RuntimeError("Could not connect to PGVector database server") from e raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self.conn is not None: if self.conn is not None:
self.conn.close() self.conn.close()
log.info("Connection to PGVector database server closed") logger.info("Connection to PGVector database server closed")
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_db(self, vector_db: VectorDB) -> None:
# Persist vector DB metadata in the KV store # Persist vector DB metadata in the KV store
View file
@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import logging
import uuid import uuid
from typing import Any from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategy, VectorStoreChunkingStrategy,
VectorStoreFileObject, VectorStoreFileObject,
) )
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
@ -35,13 +35,14 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id" CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases # KV store prefixes for vector databases
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
logger = get_logger(__name__, category="core")
def convert_id(_id: str) -> str: def convert_id(_id: str) -> str:
""" """
@ -96,7 +97,7 @@ class QdrantIndex(EmbeddingIndex):
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]), points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
) )
except Exception as e: except Exception as e:
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}") logger.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
raise raise
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
@ -118,7 +119,7 @@ class QdrantIndex(EmbeddingIndex):
try: try:
chunk = Chunk(**point.payload["chunk_content"]) chunk = Chunk(**point.payload["chunk_content"])
except Exception: except Exception:
log.exception("Failed to parse chunk") logger.exception("Failed to parse chunk")
continue continue
chunks.append(chunk) chunks.append(chunk)
View file
@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import logging
from typing import Any from typing import Any
import weaviate import weaviate
@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
@ -33,8 +33,6 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig from .config import WeaviateVectorIOConfig
log = logging.getLogger(__name__)
VERSION = "v3" VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::" VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::" OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
logger = get_logger(__name__, category="core")
class WeaviateIndex(EmbeddingIndex): class WeaviateIndex(EmbeddingIndex):
def __init__( def __init__(
@ -102,7 +102,7 @@ class WeaviateIndex(EmbeddingIndex):
chunk_dict = json.loads(chunk_json) chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict) chunk = Chunk(**chunk_dict)
except Exception: except Exception:
log.exception(f"Failed to parse document: {chunk_json}") logger.exception(f"Failed to parse document: {chunk_json}")
continue continue
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf") score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
@ -171,7 +171,7 @@ class WeaviateVectorIOAdapter(
def _get_client(self) -> weaviate.Client: def _get_client(self) -> weaviate.Client:
if "localhost" in self.config.weaviate_cluster_url: if "localhost" in self.config.weaviate_cluster_url:
log.info("using Weaviate locally in container") logger.info("using Weaviate locally in container")
host, port = self.config.weaviate_cluster_url.split(":") host, port = self.config.weaviate_cluster_url.split(":")
key = "local_test" key = "local_test"
client = weaviate.connect_to_local( client = weaviate.connect_to_local(
@ -179,7 +179,7 @@ class WeaviateVectorIOAdapter(
port=port, port=port,
) )
else: else:
log.info("Using Weaviate remote cluster with URL") logger.info("Using Weaviate remote cluster with URL")
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}" key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
if key in self.client_cache: if key in self.client_cache:
return self.client_cache[key] return self.client_cache[key]
@ -197,7 +197,7 @@ class WeaviateVectorIOAdapter(
self.kvstore = await kvstore_impl(self.config.kvstore) self.kvstore = await kvstore_impl(self.config.kvstore)
else: else:
self.kvstore = None self.kvstore = None
log.info("No kvstore configured, registry will not persist across restarts") logger.info("No kvstore configured, registry will not persist across restarts")
# Load existing vector DB definitions # Load existing vector DB definitions
if self.kvstore is not None: if self.kvstore is not None:
@ -254,7 +254,7 @@ class WeaviateVectorIOAdapter(
client = self._get_client() client = self._get_client()
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False: if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
log.warning(f"Vector DB {sanitized_collection_name} not found") logger.warning(f"Vector DB {sanitized_collection_name} not found")
return return
client.collections.delete(sanitized_collection_name) client.collections.delete(sanitized_collection_name)
await self.cache[sanitized_collection_name].index.delete() await self.cache[sanitized_collection_name].index.delete()
View file
@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
import logging
import struct import struct
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -27,7 +26,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
EMBEDDING_MODELS = {} EMBEDDING_MODELS = {}
log = logging.getLogger(__name__) from llama_stack.log import get_logger
log = get_logger(name=__name__, category="inference")
class SentenceTransformerEmbeddingMixin: class SentenceTransformerEmbeddingMixin:
View file
@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
) )
logger = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin( class LiteLLMOpenAIMixin(
@ -157,7 +157,7 @@ class LiteLLMOpenAIMixin(
params = await self._get_params(request) params = await self._get_params(request)
params["model"] = self.get_litellm_model_name(params["model"]) params["model"] = self.get_litellm_model_name(params["model"])
logger.debug(f"params to litellm (openai compat): {params}") log.debug(f"params to litellm (openai compat): {params}")
# see https://docs.litellm.ai/docs/completion/stream#async-completion # see https://docs.litellm.ai/docs/completion/stream#async-completion
response = await litellm.acompletion(**params) response = await litellm.acompletion(**params)
if stream: if stream:
@ -460,7 +460,7 @@ class LiteLLMOpenAIMixin(
:return: True if the model is available dynamically, False otherwise. :return: True if the model is available dynamically, False otherwise.
""" """
if self.litellm_provider_name not in litellm.models_by_provider: if self.litellm_provider_name not in litellm.models_by_provider:
logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.") log.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
return False return False
return model in litellm.models_by_provider[self.litellm_provider_name] return model in litellm.models_by_provider[self.litellm_provider_name]
View file
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
) )
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class RemoteInferenceProviderConfig(BaseModel): class RemoteInferenceProviderConfig(BaseModel):
@ -135,7 +135,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
:param model: The model identifier to check. :param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise. :return: True if the model is available dynamically, False otherwise.
""" """
logger.info( log.info(
f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default." f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
) )
return False return False
View file
@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
import json import json
import logging
import struct import struct
import time import time
import uuid import uuid
@ -116,6 +115,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIChoice as OpenAIChatCompletionChoice, OpenAIChoice as OpenAIChatCompletionChoice,
) )
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
StopReason, StopReason,
@ -128,7 +128,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
decode_assistant_message, decode_assistant_message,
) )
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="inference")
class OpenAICompatCompletionChoiceDelta(BaseModel): class OpenAICompatCompletionChoiceDelta(BaseModel):
@ -316,7 +316,7 @@ def process_chat_completion_response(
if t.tool_name in request_tools: if t.tool_name in request_tools:
new_tool_calls.append(t) new_tool_calls.append(t)
else: else:
logger.warning(f"Tool {t.tool_name} not found in request tools") log.warning(f"Tool {t.tool_name} not found in request tools")
if len(new_tool_calls) < len(raw_message.tool_calls): if len(new_tool_calls) < len(raw_message.tool_calls):
raw_message.tool_calls = new_tool_calls raw_message.tool_calls = new_tool_calls
@ -477,7 +477,7 @@ async def process_chat_completion_stream_response(
) )
) )
else: else:
logger.warning(f"Tool {tool_call.tool_name} not found in request tools") log.warning(f"Tool {tool_call.tool_name} not found in request tools")
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress, event_type=ChatCompletionResponseEventType.progress,
@ -1198,7 +1198,7 @@ async def convert_openai_chat_completion_stream(
) )
for idx, buffer in tool_call_idx_to_buffer.items(): for idx, buffer in tool_call_idx_to_buffer.items():
logger.debug(f"toolcall_buffer[{idx}]: {buffer}") log.debug(f"toolcall_buffer[{idx}]: {buffer}")
if buffer["name"]: if buffer["name"]:
delta = ")" delta = ")"
buffer["content"] += delta buffer["content"] += delta
View file
@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core") log = get_logger(name=__name__, category="core")
class OpenAIMixin(ABC): class OpenAIMixin(ABC):
@ -125,9 +125,9 @@ class OpenAIMixin(ABC):
Direct OpenAI completion API call. Direct OpenAI completion API call.
""" """
if guided_choice is not None: if guided_choice is not None:
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.") log.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None: if prompt_logprobs is not None:
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") log.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# TODO: fix openai_completion to return type compatible with OpenAI's API response # TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return] return await self.client.completions.create( # type: ignore[no-any-return]
@ -267,6 +267,6 @@ class OpenAIMixin(ABC):
pass pass
except Exception as e: except Exception as e:
# All other errors (auth, rate limit, network, etc.) # All other errors (auth, rate limit, network, etc.)
logger.warning(f"Failed to check model availability for {model}: {e}") log.warning(f"Failed to check model availability for {model}: {e}")
return False return False
View file
@ -4,16 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from datetime import datetime from datetime import datetime
from pymongo import AsyncMongoClient from pymongo import AsyncMongoClient
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig from ..config import MongoDBKVStoreConfig
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="core")
class MongoDBKVStoreImpl(KVStore): class MongoDBKVStoreImpl(KVStore):
View file
@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from datetime import datetime from datetime import datetime
import psycopg2 import psycopg2
from psycopg2.extras import DictCursor from psycopg2.extras import DictCursor
from llama_stack.log import get_logger
from ..api import KVStore from ..api import KVStore
from ..config import PostgresKVStoreConfig from ..config import PostgresKVStoreConfig
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="core")
class PostgresKVStoreImpl(KVStore): class PostgresKVStoreImpl(KVStore):
View file
@ -6,7 +6,6 @@
import asyncio import asyncio
import json import json
import logging
import mimetypes import mimetypes
import time import time
import uuid import uuid
@ -37,10 +36,11 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse, VectorStoreSearchResponse,
VectorStoreSearchResponsePage, VectorStoreSearchResponsePage,
) )
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="core")
# Constants for OpenAI vector stores # Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5 CHUNK_MULTIPLIER = 5
@ -378,7 +378,7 @@ class OpenAIVectorStoreMixin(ABC):
try: try:
await self.unregister_vector_db(vector_store_id) await self.unregister_vector_db(vector_store_id)
except Exception as e: except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") log.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse( return VectorStoreDeleteResponse(
id=vector_store_id, id=vector_store_id,
@ -460,7 +460,7 @@ class OpenAIVectorStoreMixin(ABC):
) )
except Exception as e: except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}") log.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error # Return empty results on error
return VectorStoreSearchResponsePage( return VectorStoreSearchResponsePage(
search_query=search_query, search_query=search_query,
@ -614,7 +614,7 @@ class OpenAIVectorStoreMixin(ABC):
) )
vector_store_file_object.status = "completed" vector_store_file_object.status = "completed"
except Exception as e: except Exception as e:
logger.error(f"Error attaching file to vector store: {e}") log.error(f"Error attaching file to vector store: {e}")
vector_store_file_object.status = "failed" vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError( vector_store_file_object.last_error = VectorStoreFileLastError(
code="server_error", code="server_error",
View file
@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64 import base64
import io import io
import logging
import re import re
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -25,6 +24,7 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.tools import RAGDocument from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
@ -32,12 +32,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
) )
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = logging.getLogger(__name__)
# Constants for reranker types # Constants for reranker types
RERANKER_TYPE_RRF = "rrf" RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted" RERANKER_TYPE_WEIGHTED = "weighted"
log = get_logger(name=__name__, category="memory")
def parse_pdf(data: bytes) -> str: def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string # For PDF and DOC/DOCX files, we can't reliably convert to string
View file
@ -17,7 +17,7 @@ from pydantic import BaseModel
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler") log = get_logger(name=__name__, category="scheduler")
# TODO: revisit the list of possible statuses when defining a more coherent # TODO: revisit the list of possible statuses when defining a more coherent
@ -186,7 +186,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
except Exception as e: except Exception as e:
on_log_message_cb(str(e)) on_log_message_cb(str(e))
job.status = JobStatus.failed job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.") log.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop) asyncio.run_coroutine_threadsafe(do(), self._loop)
@ -222,7 +222,7 @@ class Scheduler:
msg = (datetime.now(UTC), message) msg = (datetime.now(UTC), message)
# At least for the time being, until there's a better way to expose # At least for the time being, until there's a better way to expose
# logs to users, log messages on console # logs to users, log messages on console
logger.info(f"Job {job.id}: {message}") log.info(f"Job {job.id}: {message}")
job.append_log(msg) job.append_log(msg)
self._backend.on_log_message_cb(job, msg) self._backend.on_log_message_cb(job, msg)
View file
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="authorized_sqlstore") log = get_logger(name=__name__, category="authorized_sqlstore")
# Hardcoded copy of the default policy that our SQL filtering implements # Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly # WARNING: If default_policy() changes, this constant must be updated accordingly
@ -81,7 +81,7 @@ class AuthorizedSqlStore:
actual_default = default_policy() actual_default = default_policy()
if SQL_OPTIMIZED_POLICY != actual_default: if SQL_OPTIMIZED_POLICY != actual_default:
logger.warning( log.warning(
f"SQL_OPTIMIZED_POLICY does not match default_policy(). " f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
f"SQL filtering will use conservative mode. " f"SQL filtering will use conservative mode. "
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}", f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
View file
@ -29,7 +29,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig from .sqlstore import SqlAlchemySqlStoreConfig
logger = get_logger(name=__name__, category="sqlstore") log = get_logger(name=__name__, category="sqlstore")
TYPE_MAPPING: dict[ColumnType, Any] = { TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer, ColumnType.INTEGER: Integer,
@ -280,5 +280,5 @@ class SqlAlchemySqlStoreImpl(SqlStore):
except Exception as e: except Exception as e:
# If any error occurs during migration, log it but don't fail # If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column # The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}") log.error(f"Error adding column {column_name} to table {table}: {e}")
pass pass
View file
@ -6,7 +6,7 @@
import asyncio import asyncio
import contextvars import contextvars
import logging import logging # allow-direct-logging
import queue import queue
import random import random
import threading import threading
View file
@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import sys import sys
import time import time
import uuid import uuid
@ -19,10 +18,9 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig, LoraFinetuningConfig,
TrainingConfig, TrainingConfig,
) )
from llama_stack.log import get_logger
# Configure logging log = get_logger(name=__name__, category="core")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
logger = logging.getLogger(__name__)
skip_because_resource_intensive = pytest.mark.skip( skip_because_resource_intensive = pytest.mark.skip(
@ -71,14 +69,14 @@ class TestPostTraining:
) )
@pytest.mark.timeout(360) # 6 minutes timeout @pytest.mark.timeout(360) # 6 minutes timeout
def test_supervised_fine_tune(self, llama_stack_client, purpose, source): def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
logger.info("Starting supervised fine-tuning test") log.info("Starting supervised fine-tuning test")
# register dataset to train # register dataset to train
dataset = llama_stack_client.datasets.register( dataset = llama_stack_client.datasets.register(
purpose=purpose, purpose=purpose,
source=source, source=source,
) )
logger.info(f"Registered dataset with ID: {dataset.identifier}") log.info(f"Registered dataset with ID: {dataset.identifier}")
algorithm_config = LoraFinetuningConfig( algorithm_config = LoraFinetuningConfig(
type="LoRA", type="LoRA",
@ -105,7 +103,7 @@ class TestPostTraining:
) )
job_uuid = f"test-job{uuid.uuid4()}" job_uuid = f"test-job{uuid.uuid4()}"
logger.info(f"Starting training job with UUID: {job_uuid}") log.info(f"Starting training job with UUID: {job_uuid}")
# train with HF trl SFTTrainer as the default # train with HF trl SFTTrainer as the default
_ = llama_stack_client.post_training.supervised_fine_tune( _ = llama_stack_client.post_training.supervised_fine_tune(
@ -121,21 +119,21 @@ class TestPostTraining:
while True: while True:
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid) status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
if not status: if not status:
logger.error("Job not found") log.error("Job not found")
break break
logger.info(f"Current status: {status}") log.info(f"Current status: {status}")
assert status.status in ["scheduled", "in_progress", "completed"] assert status.status in ["scheduled", "in_progress", "completed"]
if status.status == "completed": if status.status == "completed":
break break
logger.info("Waiting for job to complete...") log.info("Waiting for job to complete...")
time.sleep(10) # Increased sleep time to reduce polling frequency time.sleep(10) # Increased sleep time to reduce polling frequency
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid) artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
logger.info(f"Job artifacts: {artifacts}") log.info(f"Job artifacts: {artifacts}")
logger.info(f"Registered dataset with ID: {dataset.identifier}") log.info(f"Registered dataset with ID: {dataset.identifier}")
# TODO: Fix these tests to properly represent the Jobs API in training # TODO: Fix these tests to properly represent the Jobs API in training
# #
@ -181,17 +179,21 @@ class TestPostTraining:
) )
@pytest.mark.timeout(360) @pytest.mark.timeout(360)
def test_preference_optimize(self, llama_stack_client, purpose, source): def test_preference_optimize(self, llama_stack_client, purpose, source):
logger.info("Starting DPO preference optimization test") log.info("Starting DPO preference optimization test")
# register preference dataset to train # register preference dataset to train
dataset = llama_stack_client.datasets.register( dataset = llama_stack_client.datasets.register(
purpose=purpose, purpose=purpose,
source=source, source=source,
) )
logger.info(f"Registered preference dataset with ID: {dataset.identifier}") log.info(f"Registered preference dataset with ID: {dataset.identifier}")
# DPO algorithm configuration # DPO algorithm configuration
algorithm_config = DPOAlignmentConfig( algorithm_config = DPOAlignmentConfig(
reward_scale=1.0,
reward_clip=10.0,
epsilon=1e-8,
gamma=0.99,
beta=0.1, beta=0.1,
loss_type=DPOLossType.sigmoid, # Default loss type loss_type=DPOLossType.sigmoid, # Default loss type
) )
@ -211,7 +213,7 @@ class TestPostTraining:
) )
job_uuid = f"test-dpo-job-{uuid.uuid4()}" job_uuid = f"test-dpo-job-{uuid.uuid4()}"
logger.info(f"Starting DPO training job with UUID: {job_uuid}") log.info(f"Starting DPO training job with UUID: {job_uuid}")
# train with HuggingFace DPO implementation # train with HuggingFace DPO implementation
_ = llama_stack_client.post_training.preference_optimize( _ = llama_stack_client.post_training.preference_optimize(
@ -226,15 +228,15 @@ class TestPostTraining:
while True: while True:
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid) status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
if not status: if not status:
logger.error("DPO job not found") log.error("DPO job not found")
break break
logger.info(f"Current DPO status: {status}") log.info(f"Current DPO status: {status}")
if status.status == "completed": if status.status == "completed":
break break
logger.info("Waiting for DPO job to complete...") log.info("Waiting for DPO job to complete...")
time.sleep(10) # Increased sleep time to reduce polling frequency time.sleep(10) # Increased sleep time to reduce polling frequency
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid) artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
logger.info(f"DPO job artifacts: {artifacts}") log.info(f"DPO job artifacts: {artifacts}")
View file
@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
import time import time
from io import BytesIO from io import BytesIO
@ -13,8 +12,10 @@ from llama_stack_client import BadRequestError, LlamaStackClient
from openai import BadRequestError as OpenAIBadRequestError from openai import BadRequestError as OpenAIBadRequestError
from llama_stack.apis.vector_io import Chunk from llama_stack.apis.vector_io import Chunk
from llama_stack.core.library_client import LlamaStackAsLibraryClient
from llama_stack.log import get_logger
logger = logging.getLogger(__name__) log = get_logger(name=__name__, category="vector-io")
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
@ -99,7 +100,7 @@ def compat_client_with_empty_stores(compat_client):
compat_client.vector_stores.delete(vector_store_id=store.id) compat_client.vector_stores.delete(vector_store_id=store.id)
except Exception: except Exception:
# If the API is not available or fails, just continue # If the API is not available or fails, just continue
logger.warning("Failed to clear vector stores") log.warning("Failed to clear vector stores")
pass pass
def clear_files(): def clear_files():
@ -109,7 +110,7 @@ def compat_client_with_empty_stores(compat_client):
compat_client.files.delete(file_id=file.id) compat_client.files.delete(file_id=file.id)
except Exception: except Exception:
# If the API is not available or fails, just continue # If the API is not available or fails, just continue
logger.warning("Failed to clear files") log.warning("Failed to clear files")
pass pass
clear_vector_stores() clear_vector_stores()
View file
@ -6,7 +6,7 @@
import asyncio import asyncio
import json import json
import logging import logging # allow-direct-logging
import threading import threading
import time import time
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer