feat: add auto-generated CI documentation pre-commit hook (#2890)

Our CI is entirely undocumented; this commit adds a README.md file with
a table of the current CI jobs and what each one does.
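For reference, a hook like this typically shells out to a small generator script and fails the commit if the regenerated file differs from the one checked in. The Python sketch below is illustrative only: the workflow directory, output path, and table columns are assumptions, not details taken from this commit.

# Hypothetical generator that a pre-commit hook could invoke.
# The paths and table layout here are assumed, not from the actual commit.
from pathlib import Path

import yaml  # PyYAML

WORKFLOWS = Path(".github/workflows")
OUTPUT = WORKFLOWS / "README.md"

def generate_table() -> str:
    rows = ["| Workflow | File | Triggers |", "| --- | --- | --- |"]
    for wf_path in sorted(WORKFLOWS.glob("*.yml")):
        wf = yaml.safe_load(wf_path.read_text())
        # PyYAML (YAML 1.1) loads a bare `on:` key as the boolean True.
        triggers = wf.get("on", wf.get(True, {}))
        if isinstance(triggers, (dict, list)):
            names = ", ".join(triggers)
        else:
            names = str(triggers)
        rows.append(f"| {wf.get('name', wf_path.stem)} | {wf_path.name} | {names} |")
    return "\n".join(rows) + "\n"

if __name__ == "__main__":
    OUTPUT.write_text(generate_table())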

---------

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
Nathan Weinberg 2025-07-25 11:57:01 -04:00 committed by Mustafa Elbehery
parent 7f834339ba
commit b381ed6d64
93 changed files with 495 additions and 477 deletions
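Most of the diff below is a mechanical migration off the stdlib logging module onto the project's get_logger helper. Reconstructed from the hunks, the recurring before/after pattern is:

# Before: stdlib logging, module-level logger
import logging

logger = logging.getLogger(__name__)

# After: the llama_stack wrapper, with an explicit category
from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")

A few call sites (for example the server's isEnabledFor(logging.DEBUG) check) still need the stdlib module and keep an explicit import logging  # allow-direct-logging.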

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import importlib.resources
import logging
import sys
from pydantic import BaseModel
@ -17,10 +16,9 @@ from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
@ -33,6 +31,8 @@ SERVER_DEPENDENCIES = [
"opentelemetry-exporter-otlp-proto-http",
]
log = get_logger(name=__name__, category="core")
class ApiInput(BaseModel):
api: Api

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import textwrap
from typing import Any
@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
@ -49,7 +49,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
is_nux = len(config.providers) == 0
if is_nux:
logger.info(
log.info(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. For each API served by the Stack,
@ -75,12 +75,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
log.info(f"Re-configuring existing providers for API `{api_str}`...")
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
log.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(configure_single_provider(provider_registry[api], p))
logger.info("")
log.info("")
else:
# we are newly configuring this API
plist = build_spec.providers.get(api_str, [])
@ -89,17 +89,17 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...")
log.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider in enumerate(plist):
if i >= 1:
others = ", ".join(p.provider_type for p in plist[i:])
logger.info(
log.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
)
break
logger.info(f"> Configuring provider `({provider.provider_type})`")
log.info(f"> Configuring provider `({provider.provider_type})`")
pid = provider.provider_type.split("::")[-1]
updated_providers.append(
configure_single_provider(
@ -111,7 +111,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
),
)
)
logger.info("")
log.info("")
config.providers[api_str] = updated_providers
@ -169,7 +169,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
log.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION

View file

@ -23,7 +23,7 @@ from llama_stack.providers.datatypes import (
remote_provider_spec,
)
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
def stack_apis() -> list[Api]:
@ -141,18 +141,18 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
log.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
log.warning(f"Failed to import module {name}: {e}")
# Refresh providable APIs with external APIs if any
external_apis = load_external_apis(config)
for api, api_spec in external_apis.items():
name = api_spec.name.lower()
logger.info(f"Importing external API {name} module {api_spec.module}")
log.info(f"Importing external API {name} module {api_spec.module}")
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
@ -161,7 +161,7 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
# This assumes that the in-tree provider(s) are not available for this API, which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
log.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
@ -183,13 +183,13 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
def get_external_providers_from_dir(
registry: dict[Api, dict[str, ProviderSpec]], config
) -> dict[Api, dict[str, ProviderSpec]]:
logger.warning(
log.warning(
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
log.info(f"Loading external providers from {external_providers_dir}")
for api in providable_apis():
api_name = api.name.lower()
@ -198,13 +198,13 @@ def get_external_providers_from_dir(
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
log.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
log.info(f"Loading {provider_type} provider spec from {spec_path}")
try:
with open(spec_path) as f:
@ -217,16 +217,16 @@ def get_external_providers_from_dir(
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
log.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
log.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
log.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
log.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
log.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
@ -241,7 +241,7 @@ def get_external_providers_from_module(
else:
provider_list = config.providers.items()
if provider_list is None:
logger.warning("Could not get list of providers from config")
log.warning("Could not get list of providers from config")
return registry
for provider_api, providers in provider_list:
for provider in providers:
@ -272,6 +272,6 @@ def get_external_providers_from_module(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
) from exc
except Exception as e:
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
log.error(f"Failed to load provider spec from module {provider.module}: {e}")
raise e
return registry

View file

@ -11,7 +11,7 @@ from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
@ -28,10 +28,10 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
log.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
log.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
@ -42,13 +42,13 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
log.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
log.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
log.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis

View file

@ -7,7 +7,6 @@
import asyncio
import inspect
import json
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
@ -48,6 +47,7 @@ from llama_stack.core.stack import (
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
@ -55,7 +55,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
start_trace,
)
logger = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
T = TypeVar("T")
@ -84,7 +84,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
try:
return [convert_to_pydantic(item_type, item) for item in value]
except Exception:
logger.error(f"Error converting list {value} into {item_type}")
log.error(f"Error converting list {value} into {item_type}")
return value
elif origin is dict:
@ -92,7 +92,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
try:
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
except Exception:
logger.error(f"Error converting dict {value} into {val_type}")
log.error(f"Error converting dict {value} into {val_type}")
return value
try:
@ -108,7 +108,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
return convert_to_pydantic(union_type, value)
except Exception:
continue
logger.warning(
log.warning(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
)
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
@ -171,13 +171,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
import logging # allow-direct-logging
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
log.info(f"Removed handler {handler.__class__.__name__} from root log")
def request(self, *args, **kwargs):
loop = self.loop

View file

@ -16,7 +16,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ProviderImplConfig(BaseModel):
@ -38,7 +38,7 @@ class ProviderImpl(Providers):
pass
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
log.debug("ProviderImpl.shutdown")
pass
async def list_providers(self) -> ListProvidersResponse:

View file

@ -6,19 +6,19 @@
import contextvars
import json
import logging
from contextlib import AbstractContextManager
from typing import Any
from llama_stack.core.datatypes import User
from llama_stack.log import get_logger
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
log = get_logger(name=__name__, category="core")
class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""

View file

@ -54,7 +54,7 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception):
@ -101,7 +101,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
protocols[api] = api_class
except (ImportError, AttributeError):
logger.exception(f"Failed to load external API {api_spec.name}")
log.exception(f"Failed to load external API {api_spec.name}")
return protocols
@ -223,7 +223,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
log.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
@ -245,10 +245,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logger.error(p.deprecation_error)
log.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logger.warning(
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
@ -261,9 +261,9 @@ def sort_providers_by_deps(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
logger.debug(f"Resolved {len(sorted_providers)} providers")
log.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")
log.debug(f" {api_str} => {provider.provider_id}")
return sorted_providers
@ -348,7 +348,7 @@ async def instantiate_provider(
if not hasattr(provider_spec, "module") or provider_spec.module is None:
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
log.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
@ -418,7 +418,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method has a concrete implementation (not just a protocol stub)

View file

@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class DatasetIORouter(DatasetIO):
@ -20,15 +20,15 @@ class DatasetIORouter(DatasetIO):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
log.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
log.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
log.debug("DatasetIORouter.shutdown")
pass
async def register_dataset(
@ -38,7 +38,7 @@ class DatasetIORouter(DatasetIO):
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
log.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
@ -54,7 +54,7 @@ class DatasetIORouter(DatasetIO):
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
log.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
provider = await self.routing_table.get_provider_impl(dataset_id)
@ -65,7 +65,7 @@ class DatasetIORouter(DatasetIO):
)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
log.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.append_rows(
dataset_id=dataset_id,

View file

@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ScoringRouter(Scoring):
@ -24,15 +24,15 @@ class ScoringRouter(Scoring):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
log.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
log.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
log.debug("ScoringRouter.shutdown")
pass
async def score_batch(
@ -41,7 +41,7 @@ class ScoringRouter(Scoring):
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
log.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
provider = await self.routing_table.get_provider_impl(fn_identifier)
@ -63,7 +63,7 @@ class ScoringRouter(Scoring):
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
log.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
@ -82,15 +82,15 @@ class EvalRouter(Eval):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
log.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
log.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
log.debug("EvalRouter.shutdown")
pass
async def run_eval(
@ -98,7 +98,7 @@ class EvalRouter(Eval):
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
log.debug(f"EvalRouter.run_eval: {benchmark_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.run_eval(
benchmark_id=benchmark_id,
@ -112,7 +112,7 @@ class EvalRouter(Eval):
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
log.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.evaluate_rows(
benchmark_id=benchmark_id,
@ -126,7 +126,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
log.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_status(benchmark_id, job_id)
@ -135,7 +135,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
log.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
await provider.job_cancel(
benchmark_id,
@ -147,7 +147,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
log.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_result(
benchmark_id,

View file

@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class InferenceRouter(Inference):
@ -70,7 +70,7 @@ class InferenceRouter(Inference):
telemetry: Telemetry | None = None,
store: InferenceStore | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
log.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
self.store = store
@ -79,10 +79,10 @@ class InferenceRouter(Inference):
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
log.debug("InferenceRouter.initialize")
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
log.debug("InferenceRouter.shutdown")
async def register_model(
self,
@ -92,7 +92,7 @@ class InferenceRouter(Inference):
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> None:
logger.debug(
log.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
@ -117,7 +117,7 @@ class InferenceRouter(Inference):
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
log.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
@ -182,7 +182,7 @@ class InferenceRouter(Inference):
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
log.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
@ -288,7 +288,7 @@ class InferenceRouter(Inference):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
logger.debug(
log.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
@ -313,7 +313,7 @@ class InferenceRouter(Inference):
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
logger.debug(
log.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
@ -374,7 +374,7 @@ class InferenceRouter(Inference):
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
logger.debug(
log.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
@ -388,7 +388,7 @@ class InferenceRouter(Inference):
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
log.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ModelNotFoundError(model_id)
@ -426,7 +426,7 @@ class InferenceRouter(Inference):
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
logger.debug(
log.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self.routing_table.get_model(model)
@ -487,7 +487,7 @@ class InferenceRouter(Inference):
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
log.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
model_obj = await self.routing_table.get_model(model)
@ -558,7 +558,7 @@ class InferenceRouter(Inference):
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
logger.debug(
log.debug(
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
)
model_obj = await self.routing_table.get_model(model)

View file

@ -14,7 +14,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class SafetyRouter(Safety):
@ -22,15 +22,15 @@ class SafetyRouter(Safety):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
log.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
log.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
log.debug("SafetyRouter.shutdown")
pass
async def register_shield(
@ -40,7 +40,7 @@ class SafetyRouter(Safety):
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
log.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def unregister_shield(self, identifier: str) -> None:
@ -53,7 +53,7 @@ class SafetyRouter(Safety):
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
log.debug(f"SafetyRouter.run_shield: {shield_id}")
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield(
shield_id=shield_id,

View file

@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ToolRuntimeRouter(ToolRuntime):
@ -31,7 +31,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
log.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
@ -40,7 +40,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
log.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_db_ids, query_config)
@ -50,7 +50,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
log.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
@ -60,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
log.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -69,15 +69,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
log.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
log.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
log.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
@ -87,5 +87,5 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
log.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)

View file

@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
@ -40,15 +40,15 @@ class VectorIORouter(VectorIO):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
log.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
log.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
log.debug("VectorIORouter.shutdown")
pass
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
@ -70,10 +70,10 @@ class VectorIORouter(VectorIO):
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
return embedding_models[0].identifier, dimension
else:
logger.warning("No embedding models found in the routing table")
log.warning("No embedding models found in the routing table")
return None
except Exception as e:
logger.error(f"Error getting embedding models: {e}")
log.error(f"Error getting embedding models: {e}")
return None
async def register_vector_db(
@ -85,7 +85,7 @@ class VectorIORouter(VectorIO):
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
log.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
@ -101,7 +101,7 @@ class VectorIORouter(VectorIO):
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
log.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
provider = await self.routing_table.get_provider_impl(vector_db_id)
@ -113,7 +113,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
log.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
@ -129,7 +129,7 @@ class VectorIORouter(VectorIO):
embedding_dimension: int | None = None,
provider_id: str | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
log.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one
if embedding_model is None:
@ -137,7 +137,7 @@ class VectorIORouter(VectorIO):
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
log.info(f"No embedding model specified, using first available: {embedding_model}")
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
@ -168,7 +168,7 @@ class VectorIORouter(VectorIO):
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
log.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
@ -179,7 +179,7 @@ class VectorIORouter(VectorIO):
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
log.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
continue
# Sort by created_at
@ -215,7 +215,7 @@ class VectorIORouter(VectorIO):
self,
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
@ -225,7 +225,7 @@ class VectorIORouter(VectorIO):
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
return await self.routing_table.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
@ -237,7 +237,7 @@ class VectorIORouter(VectorIO):
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
return await self.routing_table.openai_delete_vector_store(vector_store_id)
async def openai_search_vector_store(
@ -250,7 +250,7 @@ class VectorIORouter(VectorIO):
rewrite_query: bool | None = False,
search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
return await self.routing_table.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
@ -268,7 +268,7 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
return await self.routing_table.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
@ -285,7 +285,7 @@ class VectorIORouter(VectorIO):
before: str | None = None,
filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
log.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store(
vector_store_id=vector_store_id,
limit=limit,
@ -300,7 +300,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
@ -311,7 +311,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileContentsResponse:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
return await self.routing_table.openai_retrieve_vector_store_file_contents(
vector_store_id=vector_store_id,
file_id=file_id,
@ -323,7 +323,7 @@ class VectorIORouter(VectorIO):
file_id: str,
attributes: dict[str, Any],
) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_update_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,
@ -335,7 +335,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
file_id: str,
) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
log.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
return await self.routing_table.openai_delete_vector_store_file(
vector_store_id=vector_store_id,
file_id=file_id,

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

View file

@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
def get_impl_api(p: Any) -> Api:
@ -177,7 +177,7 @@ class CommonRoutingTableImpl(RoutingTable):
# Check if user has permission to access this object
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
logger.debug(f"Access denied to {type} '{identifier}'")
log.debug(f"Access denied to {type} '{identifier}'")
return None
return obj
@ -205,7 +205,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise AccessDeniedError("create", obj, creator)
if creator:
obj.owner = creator
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
log.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
@ -250,7 +250,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
if model is not None:
return model
logger.warning(
log.warning(
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
"searching in all providers. This is only for backwards compatibility and will stop working "
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."

View file

@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
log.exception(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
@ -132,7 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
model_ids[model.provider_resource_id] = model.identifier
continue
logger.debug(f"unregistering model {model.identifier}")
log.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
@ -143,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model.identifier == model.provider_resource_id:
model.identifier = f"{provider_id}/{model.provider_resource_id}"
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
log.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
ModelWithOwner(
identifier=model.identifier,

View file

@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
@ -57,7 +57,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
log.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:

View file

@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
log = get_logger(name=__name__, category="auth")
class AuthenticationMiddleware:
@ -105,13 +105,13 @@ class AuthenticationMiddleware:
try:
validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
log.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except ValueError as e:
logger.exception("Error during authentication")
log.exception("Error during authentication")
return await self._send_auth_error(send, str(e))
except Exception:
logger.exception("Error during authentication")
log.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
@ -122,7 +122,7 @@ class AuthenticationMiddleware:
scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug(
log.debug(
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)

View file

@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
log = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
@ -163,7 +163,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Token introspection failed with status code: {response.status_code}")
log.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}")
fields = response.json()
@ -176,13 +176,13 @@ class OAuth2TokenAuthProvider(AuthProvider):
attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
log.exception("Token introspection request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during token introspection")
log.exception("Error during token introspection")
raise ValueError("Token introspection error") from e
async def close(self):
@ -273,7 +273,7 @@ class CustomAuthProvider(AuthProvider):
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}")
log.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}")
# Parse and validate the auth response
@ -282,17 +282,17 @@ class CustomAuthProvider(AuthProvider):
auth_response = AuthResponse(**response_data)
return User(principal=auth_response.principal, attributes=auth_response.attributes)
except Exception as e:
logger.exception("Error parsing authentication response")
log.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
log.exception("Authentication request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during authentication")
log.exception("Error during authentication")
raise ValueError("Authentication service error") from e
async def close(self):
@ -329,7 +329,7 @@ class GitHubTokenAuthProvider(AuthProvider):
try:
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
except httpx.HTTPStatusError as e:
logger.warning(f"GitHub token validation failed: {e}")
log.warning(f"GitHub token validation failed: {e}")
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
principal = user_info["user"]["login"]

View file

@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
log = get_logger(name=__name__, category="quota")
class QuotaMiddleware:
@ -46,7 +46,7 @@ class QuotaMiddleware:
self.window_seconds = window_seconds
if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning(
log.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
@ -84,11 +84,11 @@ class QuotaMiddleware:
else:
await kv.set(key, str(count))
except Exception:
logger.exception("Failed to access KV store for quota")
log.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error")
if count > limit:
logger.warning(
log.warning(
"Quota exceeded for client %s: %d/%d",
key_id,
count,

View file

@ -9,7 +9,7 @@ import asyncio
import functools
import inspect
import json
import logging
import logging # allow-direct-logging
import os
import ssl
import sys
@ -80,7 +80,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
log = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -157,9 +157,9 @@ async def shutdown(app):
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up")
log.info("Starting up")
yield
logger.info("Shutting down")
log.info("Shutting down")
await shutdown(app)
@ -182,11 +182,11 @@ async def sse_generator(event_gen_coroutine):
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("Generator cancelled")
log.info("Generator cancelled")
if event_gen:
await event_gen.aclose()
except Exception as e:
logger.exception("Error in sse_generator")
log.exception("Error in sse_generator")
yield create_sse_event(
{
"error": {
@ -206,11 +206,11 @@ async def log_request_pre_validation(request: Request):
log_output = rich.pretty.pretty_repr(parsed_body)
except (json.JSONDecodeError, UnicodeDecodeError):
log_output = repr(body_bytes)
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
log.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
else:
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
log.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
except Exception as e:
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
log.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@ -238,10 +238,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
result.url = route
return result
except Exception as e:
if logger.isEnabledFor(logging.DEBUG):
logger.exception(f"Error executing endpoint {route=} {method=}")
if log.isEnabledFor(logging.DEBUG):
log.exception(f"Error executing endpoint {route=} {method=}")
else:
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
log.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
raise translate_exception(e) from e
sig = inspect.signature(func)
@ -291,7 +291,7 @@ class TracingMiddleware:
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
log.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
@ -303,7 +303,7 @@ class TracingMiddleware:
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
log.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path}
@ -404,15 +404,15 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
log = get_logger(name=__name__, category="server", config=logger_config)
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}")
log.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
log.error(f"Error: {str(e)}")
sys.exit(1)
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
@ -438,16 +438,16 @@ def main(args: argparse.Namespace | None = None):
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
log.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
log.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else:
if config.server.quota:
quota = config.server.quota
logger.warning(
log.warning(
"Configured authenticated_max_requests (%d) but no auth is enabled; "
"falling back to anonymous_max_requests (%d) for all the requests",
quota.authenticated_max_requests,
@ -455,7 +455,7 @@ def main(args: argparse.Namespace | None = None):
)
if config.server.quota:
logger.info("Enabling quota middleware for authenticated and anonymous clients")
log.info("Enabling quota middleware for authenticated and anonymous clients")
quota = config.server.quota
anonymous_max_requests = quota.anonymous_max_requests
@ -516,7 +516,7 @@ def main(args: argparse.Namespace | None = None):
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
log.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
@ -528,7 +528,7 @@ def main(args: argparse.Namespace | None = None):
)
)
logger.debug(f"serving APIs: {apis_to_serve}")
log.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
@@ -553,21 +553,21 @@ def main(args: argparse.Namespace | None = None):
if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
log.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
log.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}")
log.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_level": log.getEffectiveLevel(),
"log_config": logger_config,
}
if ssl_config:
@@ -586,19 +586,19 @@ def main(args: argparse.Namespace | None = None):
try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
log.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
log.debug("Closing event loop")
loop.close()
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
log.info("Run configuration:")
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
clean_config = remove_disabled_providers(safe_config)
logger.info(yaml.dump(clean_config, indent=2))
log.info(yaml.dump(clean_config, indent=2))
def extract_path_params(route: str) -> list[str]:
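
A subtlety in the main() hunk above: the module-level log is re-created once the run config is parsed, so a user-supplied logging_config block takes effect for everything logged afterwards. A minimal sketch of that flow, assuming only the get_logger(name=..., category=..., config=...) signature shown in the hunk; the LoggingConfig import path and the example YAML contents are assumptions:

import yaml

from llama_stack.log import get_logger
from llama_stack.log import LoggingConfig  # import path assumed for this sketch

log = get_logger(name=__name__, category="server")  # module-level default

config_contents = yaml.safe_load("logging_config:\n  category_levels:\n    server: DEBUG\n")
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
    # Rebind the name so later log.* calls honor the user-supplied config,
    # mirroring the re-assignment at the top of main() above
    log = get_logger(name=__name__, category="server", config=LoggingConfig(**cfg))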

View file

@@ -45,7 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
logger = get_logger(name=__name__, category="core")
log = get_logger(name=__name__, category="core")
class LlamaStack(
@@ -105,11 +105,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
log.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
# Do not register models on disabled providers
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
log.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
continue
# we want to maintain the type information in arguments to method.
@@ -123,7 +123,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
logger.debug(
log.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
@@ -160,7 +160,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
try:
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
if resolved_provider_id == "__disabled__":
logger.debug(
log.debug(
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
)
# Create a copy with resolved provider_id but original config
@@ -315,7 +315,7 @@ async def construct_stack(
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
log.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
policy = run_config.server.auth.access_policy if run_config.server.auth else []
@@ -337,12 +337,12 @@ async def construct_stack(
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
log.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
log.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
log.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls
@@ -351,23 +351,23 @@ async def construct_stack(
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
log.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
log.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
log.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
log.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
log.error(f"Error during inference recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
@@ -375,14 +375,14 @@ async def shutdown_stack(impls: dict[Api, Any]):
async def refresh_registry_once(impls: dict[Api, Any]):
logger.debug("refreshing registry")
log.debug("refreshing registry")
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
for routing_table in routing_tables:
await routing_table.refresh()
async def refresh_registry_task(impls: dict[Api, Any]):
logger.info("starting registry refresh task")
log.info("starting registry refresh task")
while True:
await refresh_registry_once(impls)
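
The shutdown_stack() hunk above bounds each provider's shutdown with a five-second timeout so one hung implementation cannot stall server exit. A self-contained sketch of that pattern, with an invented DummyImpl standing in for a real provider implementation:

import asyncio

class DummyImpl:
    async def shutdown(self) -> None:
        await asyncio.sleep(0.1)  # stands in for real provider cleanup

async def shutdown_all(impls: list[object]) -> None:
    for impl in impls:
        name = impl.__class__.__name__
        try:
            if hasattr(impl, "shutdown"):
                # Bound each shutdown; asyncio.wait_for raises TimeoutError
                # (an alias of asyncio.TimeoutError on Python 3.11+)
                await asyncio.wait_for(impl.shutdown(), timeout=5)
            else:
                print(f"No shutdown method for {name}")
        except TimeoutError:
            print(f"Shutdown timeout for {name}")

asyncio.run(shutdown_all([DummyImpl()]))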

View file

@@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
log = get_logger(name=__name__, category="config_resolution")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
@@ -42,25 +42,25 @@ def resolve_config_or_distro(
# Strategy 1: Try as file path first
config_path = Path(config_or_distro)
if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}")
log.info(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as distribution name (if no .yaml extension)
if not config_or_distro.endswith(".yaml"):
distro_config = _get_distro_config_path(config_or_distro, mode)
if distro_config.exists():
logger.info(f"Using distribution: {distro_config}")
log.info(f"Using distribution: {distro_config}")
return distro_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
log.info(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
log.info(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import importlib
import os
import signal
import subprocess
@@ -12,9 +12,9 @@ import sys
from termcolor import cprint
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
import importlib
log = get_logger(name=__name__, category="core")
def formulate_run_args(image_type: str, image_name: str) -> list:
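
The exec.py hunk above is the whole migration in miniature: drop the stdlib logging.getLogger(__name__) module logger and take one from the project-wide helper with an explicit category. A before/after sketch, assuming only the get_logger signature these hunks already use; the example call site is invented:

# Before: stdlib logging
# import logging
# log = logging.getLogger(__name__)

# After: project-wide helper with an explicit category
from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")

# Call sites are unchanged; %-style arguments still work, as the
# quota warning in the server hunk above relies on
log.info("configured provider %s", "example-provider")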

View file

@@ -6,7 +6,6 @@
import inspect
import json
import logging
from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin
@@ -14,7 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
def is_list_of_primitives(field_type):