feat(logging): implement category-based logging

This commit introduces a new logging system that allows loggers to be assigned
a category while retaining a logger name derived from the module (`__name__`).
The log format includes both the logger name and the category, producing output
like:

```
INFO     2025-03-03 21:44:11,323 llama_stack.distribution.stack:103 [core]: Tool_groups: builtin::websearch served by
         tavily-search
```

Key features include:

- Category-based logging: Loggers can be assigned a category (e.g.,
  "core", "server") when instantiated. A logger is obtained like
  this: `logger = get_logger(name=__name__, category="server")`
  (see the sketch after this list)
- Environment variable control: Log levels can be configured per category using the
  `LLAMA_STACK_LOGGING` environment variable. For example:
  `LLAMA_STACK_LOGGING="server=DEBUG;core=debug"` enables DEBUG level for the "server"
  and "core" categories (level names are matched case-insensitively).
- `LLAMA_STACK_LOGGING="all=debug"` sets DEBUG level globally for all categories and
  third-party libraries.
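
As a minimal sketch of the intended usage (the module and messages below are
illustrative, not part of this commit):

```python
from llama_stack.log import get_logger

# The logger keeps its module-based name; the category controls its level.
logger = get_logger(name=__name__, category="server")

logger.info("Starting server")   # visible at the default INFO level
logger.debug("Request payload")  # visible only when the "server" category is
                                 # set to DEBUG, e.g. by running with:
                                 #   LLAMA_STACK_LOGGING="server=debug"
```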

This provides fine-grained control over logging levels while maintaining a clean and
informative log format.

The log handler uses the rich library, which provides colored output and
better stack traces, like so:

```
ERROR    2025-03-03 21:49:37,124 asyncio:1758 [uncategorized]: unhandled exception during asyncio.run() shutdown
         task: <Task finished name='Task-16' coro=<handle_signal.<locals>.shutdown() done, defined at
         /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:146>
         exception=UnboundLocalError("local variable 'loop' referenced before assignment")>
         ╭────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────╮
         │ /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:178 in shutdown                │
         │                                                                                                                │
         │   175 │   │   except asyncio.CancelledError:                                                                   │
         │   176 │   │   │   pass                                                                                         │
         │   177 │   │   finally:                                                                                         │
         │ ❱ 178 │   │   │   loop.stop()                                                                                  │
         │   179 │                                                                                                        │
         │   180 │   loop = asyncio.get_running_loop()                                                                    │
         │   181 │   loop.create_task(shutdown())                                                                         │
         ╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
         UnboundLocalError: local variable 'loop' referenced before assignment
```

Co-authored-by: Ashwin Bharambe <@ashwinb>
Signed-off-by: Sébastien Han <seb@redhat.com>

```diff
@@ -5,15 +5,15 @@
 # the root directory of this source tree.
 import argparse
-import logging
 import os
 from pathlib import Path

 from llama_stack.cli.subcommand import Subcommand
+from llama_stack.log import get_logger

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="server")

 class StackRun(Subcommand):
```


```diff
@@ -6,8 +6,6 @@
 import importlib
 import inspect
 from typing import Any, Dict, List, Set, Tuple
-import logging
-from typing import Any, Dict, List, Set

 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
@@ -36,6 +34,7 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     Api,
     BenchmarksProtocolPrivate,
@@ -51,7 +50,7 @@ from llama_stack.providers.datatypes import (
     VectorDBsProtocolPrivate,
 )

-log = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="core")

 class InvalidProviderError(Exception):
@@ -187,7 +186,7 @@ def validate_and_prepare_providers(
     specs = {}
     for provider in providers:
         if not provider.provider_id or provider.provider_id == "__disabled__":
-            log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+            logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
             continue

         validate_provider(provider, api, provider_registry)
@@ -209,10 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
     p = provider_registry[api][provider.provider_type]
     if p.deprecation_error:
-        log.error(p.deprecation_error)
+        logger.error(p.deprecation_error)
         raise InvalidProviderError(p.deprecation_error)

     elif p.deprecation_warning:
-        log.warning(
+        logger.warning(
             f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
         )
@@ -246,10 +245,11 @@
         )
     )

-    log.info(f"Resolved {len(sorted_providers)} providers")
+    logger.debug(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
-        log.debug(f" {api_str} => {provider.provider_id}")
-    log.debug("")
+        logger.debug(f" {api_str} => {provider.provider_id}")
+    logger.debug("")
+    return sorted_providers

 async def instantiate_providers(
@@ -389,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
     obj_params = set(obj_sig.parameters)
     obj_params.discard("self")
     if not (proto_params <= obj_params):
-        log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+        logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
         missing_methods.append((name, "signature_mismatch"))
     else:
         # Check if the method is actually implemented in the class
```


```diff
@@ -54,6 +54,8 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

+logger = get_logger(name=__name__, category="core")

 class VectorIORouter(VectorIO):
     """Routes to an provider based on the vector db identifier"""
@@ -62,12 +64,15 @@
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logger.debug("Initializing VectorIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logger.debug("VectorIORouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logger.debug("VectorIORouter.shutdown")
         pass

     async def register_vector_db(
@@ -78,9 +83,7 @@
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(
-            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
-        )
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -106,6 +109,7 @@
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
+        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
         return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

@@ -116,12 +120,15 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logger.debug("InferenceRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logger.debug("InferenceRouter.shutdown")
         pass

     async def register_model(
@@ -132,6 +139,9 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
+        logger.debug(
+            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
+        )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

     async def chat_completion(
@@ -242,6 +252,7 @@ class InferenceRouter(Inference):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
+        logger.debug(f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -261,12 +272,15 @@ class SafetyRouter(Safety):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logger.debug("Initializing SafetyRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logger.debug("SafetyRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logger.debug("SafetyRouter.shutdown")
         pass

     async def register_shield(
@@ -276,6 +290,7 @@ class SafetyRouter(Safety):
         provider_id: Optional[str] = None,
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
+        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
@@ -284,6 +299,7 @@ class SafetyRouter(Safety):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
+        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
             shield_id=shield_id,
             messages=messages,
@@ -296,12 +312,15 @@ class DatasetIORouter(DatasetIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logger.debug("Initializing DatasetIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logger.debug("DatasetIORouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logger.debug("DatasetIORouter.shutdown")
         pass

     async def get_rows_paginated(
@@ -322,6 +341,7 @@ class DatasetIORouter(DatasetIO):
         )

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
             rows=rows,
@@ -333,12 +353,15 @@ class ScoringRouter(Scoring):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logger.debug("Initializing ScoringRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logger.debug("ScoringRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logger.debug("ScoringRouter.shutdown")
         pass

     async def score_batch(
@@ -347,6 +370,7 @@ class ScoringRouter(Scoring):
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
+        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@@ -367,9 +391,7 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(
-            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
-        )
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
@@ -387,12 +409,15 @@ class EvalRouter(Eval):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logger.debug("Initializing EvalRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logger.debug("EvalRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logger.debug("EvalRouter.shutdown")
         pass

     async def run_eval(
@@ -400,6 +425,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
     ) -> Job:
+        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
@@ -412,6 +438,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -424,6 +451,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> Optional[JobStatus]:
+        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
@@ -431,6 +459,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> None:
+        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
         await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
             benchmark_id,
             job_id,
@@ -441,6 +470,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> EvaluateResponse:
+        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_result(
             benchmark_id,
             job_id,
@@ -453,6 +483,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
         self.routing_table = routing_table

     async def query(
@@ -461,6 +492,7 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_ids: List[str],
         query_config: Optional[RAGQueryConfig] = None,
     ) -> RAGQueryResult:
+        logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
         return await self.routing_table.get_provider_impl("knowledge_search").query(
             content, vector_db_ids, query_config
         )
@@ -471,6 +503,9 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
+        logger.debug(
+            f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
+        )
         return await self.routing_table.get_provider_impl("insert_into_memory").insert(
             documents, vector_db_id, chunk_size_in_tokens
         )
@@ -479,6 +514,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logger.debug("Initializing ToolRuntimeRouter")
         self.routing_table = routing_table

         # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -487,12 +523,15 @@ class ToolRuntimeRouter(ToolRuntime):
             setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

     async def initialize(self) -> None:
+        logger.debug("ToolRuntimeRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logger.debug("ToolRuntimeRouter.shutdown")
         pass

     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
+        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
             kwargs=kwargs,
@@ -501,4 +540,5 @@ class ToolRuntimeRouter(ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
+        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
```


```diff
@@ -9,7 +9,6 @@ import asyncio
 import functools
 import inspect
 import json
-import logging
 import os
 import signal
 import sys
@@ -26,7 +25,6 @@ from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
-from termcolor import cprint
 from typing_extensions import Annotated

 from llama_stack.distribution.datatypes import StackRunConfig
@@ -39,6 +37,7 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
 from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
@@ -54,8 +53,7 @@ from .endpoints import get_all_api_endpoints

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="server")

 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@@ -209,7 +207,7 @@ async def sse_generator(event_gen):
         yield create_sse_event(item)
         await asyncio.sleep(0.01)
     except asyncio.CancelledError:
-        print("Generator cancelled")
+        logger.info("Generator cancelled")
         await event_gen.aclose()
     except Exception as e:
         logger.exception(f"Error in sse_generator: {e}")
```


```diff
@@ -5,14 +5,12 @@
 # the root directory of this source tree.
 import importlib.resources
-import logging
 import os
 import re
 import tempfile
 from typing import Any, Dict, Optional

 import yaml
-from termcolor import colored

 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batch_inference import BatchInference
@@ -39,9 +37,10 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api

-log = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="core")

 class LlamaStack(
@@ -103,12 +102,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
     objects_to_process = response.data if hasattr(response, "data") else response

     for obj in objects_to_process:
-        log.info(
-            f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
-        )
-    log.info("")
+        logger.debug(
+            f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
+        )

 class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):
```

llama_stack/log.py (new file, 153 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import os
from logging.config import dictConfig
from typing import Dict

# Default log level
DEFAULT_LOG_LEVEL = logging.INFO

# Predefined categories
CATEGORIES = ["core", "server", "router", "inference", "agents", "safety", "eval", "tools", "client"]

# Initialize category levels with default level
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}


def parse_environment_config(env_config: str) -> Dict[str, int]:
    """
    Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.

    Parameters:
        env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.

    Returns:
        Dict[str, int]: A dictionary mapping categories to their log levels.
    """
    category_levels = {}
    for pair in env_config.split(";"):
        if not pair.strip():
            continue

        try:
            category, level = pair.split("=", 1)
            category = category.strip().lower()
            level = level.strip().upper()  # Convert to uppercase for logging._nameToLevel
            level_value = logging._nameToLevel.get(level)
            if level_value is None:
                logging.warning(
                    f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
                )
                continue

            if category == "all":
                # Apply the log level to all categories and the root logger
                for cat in CATEGORIES:
                    category_levels[cat] = level_value
                # Set the root logger's level to the specified level
                category_levels["root"] = level_value
            elif category in CATEGORIES:
                category_levels[category] = level_value
                logging.info(f"Setting '{category}' category to level '{level}'.")
            else:
                logging.warning(f"Unknown logging category: {category}. No changes made.")

        except ValueError:
            logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")

    return category_levels


def setup_logging(category_levels: Dict[str, int]) -> None:
    """
    Configure logging based on the provided category log levels.

    Parameters:
        category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
    """
    log_format = "%(asctime)s %(name)s:%(lineno)d [%(category)s]: %(message)s"

    class CategoryFilter(logging.Filter):
        """Ensure category is always present in log records."""

        def filter(self, record):
            if not hasattr(record, "category"):
                record.category = "uncategorized"  # Default to 'uncategorized' if no category found
            return True

    # Determine the root logger's level (default to WARNING if not specified)
    root_level = category_levels.get("root", logging.WARNING)

    logging_config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "rich": {
                "()": logging.Formatter,  # Standard formatter (RichHandler handles colors)
                "format": log_format,
            }
        },
        "handlers": {
            "console": {
                "class": "rich.logging.RichHandler",
                "formatter": "rich",
                "rich_tracebacks": True,
                "show_time": False,  # We handle timestamps ourselves in the log_format
                "show_path": False,
                "filters": ["category_filter"],  # Ensures category is included
            }
        },
        "filters": {
            "category_filter": {
                "()": CategoryFilter,
            }
        },
        "loggers": {
            category: {
                "handlers": ["console"],
                "level": category_levels.get(category, DEFAULT_LOG_LEVEL),
                "propagate": False,  # Disable propagation to root logger
            }
            for category in CATEGORIES
        },
        "root": {
            "handlers": ["console"],
            "level": root_level,  # Set root logger's level dynamically
        },
    }
    dictConfig(logging_config)


def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
    """
    Returns a logger with the specified name and category.
    If no category is provided, defaults to 'uncategorized'.

    Parameters:
        name (str): The name of the logger (e.g., module or filename).
        category (str): The category of the logger (default 'uncategorized').

    Returns:
        logging.LoggerAdapter: Configured logger with category support.
    """
    # Use the name as the logger's name
    logger = logging.getLogger(name)
    # Apply the category's log level to the logger
    logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
    # Attach the category as extra context
    return logging.LoggerAdapter(logger, {"category": category})


# Parse environment variable and configure logging
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
    print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
    _category_levels.update(parse_environment_config(env_config))

setup_logging(_category_levels)
```
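
As a quick sanity check of how `parse_environment_config` resolves levels (a
sketch; it relies only on the parser above and the stdlib level constants):

```python
import logging

from llama_stack.log import parse_environment_config

# Per-category settings; level names are matched case-insensitively.
levels = parse_environment_config("server=DEBUG;core=debug")
assert levels == {"server": logging.DEBUG, "core": logging.DEBUG}

# "all" fans out to every predefined category plus the root logger.
levels = parse_environment_config("all=warning")
assert levels["root"] == logging.WARNING
assert levels["inference"] == logging.WARNING
```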


```diff
@@ -17,7 +17,6 @@ from urllib.parse import urlparse
 import httpx

-from llama_stack import logcat
 from llama_stack.apis.agents import (
     AgentConfig,
     AgentToolGroup,
@@ -67,6 +66,7 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     ToolCall,
@@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
 RAG_TOOL_GROUP = "builtin::rag"

+logger = get_logger(name=__name__, category="agents")

 class ChatAgent(ShieldRunnerMixin):
     def __init__(
@@ -609,7 +611,7 @@
             )
             if n_iter >= self.agent_config.max_infer_iters:
-                logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
+                logger.info(f"done with MAX iterations ({n_iter}), exiting.")
                 # NOTE: mark end_of_turn to indicate to client that we are done with the turn
                 # Do not continue the tool call loop after this point
                 message.stop_reason = StopReason.end_of_turn
@@ -617,7 +619,7 @@
                 break

             if stop_reason == StopReason.out_of_tokens:
-                logcat.info("agents", "out of token budget, exiting.")
+                logger.info("out of token budget, exiting.")
                 yield message
                 break
@@ -631,16 +633,10 @@
                     message.content = [message.content] + output_attachments
                 yield message
             else:
-                logcat.debug(
-                    "agents",
-                    f"completion message with EOM (iter: {n_iter}): {str(message)}",
-                )
+                logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
                 input_messages = input_messages + [message]
         else:
-            logcat.debug(
-                "agents",
-                f"completion message (iter: {n_iter}) from the model: {str(message)}",
-            )
+            logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
             # 1. Start the tool execution step and progress
             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
@@ -983,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
         path = urlparse(uri).path
         basename = os.path.basename(path)
         filepath = f"{tempdir}/{make_random_string() + basename}"
-        logcat.info("agents", f"Downloading {url} -> {filepath}")
+        logger.info(f"Downloading {url} -> {filepath}")

         async with httpx.AsyncClient() as client:
             r = await client.get(uri)
@@ -1023,7 +1019,7 @@ async def execute_tool_call_maybe(
     else:
         name = name.value

-    logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
+    logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
     result = await tool_runtime_api.invoke_tool(
         tool_name=name,
         kwargs={
@@ -1033,7 +1029,7 @@
             **toolgroup_args.get(group_name, {}),
         },
     )
-    logcat.debug("agents", f"tool call {name} completed with result: {result}")
+    logger.info(f"tool call {name} completed with result: {result}")
     return result
```


```diff
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -54,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import FireworksImplConfig
 from .models import MODEL_ENTRIES

+logger = get_logger(name=__name__, category="inference")

 class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
@@ -230,12 +233,15 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

-        return {
+        params = {
             "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
+        logger.debug(f"params to fireworks: {params}")
+        return params

     async def embeddings(
         self,
```


```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from typing import AsyncGenerator, List, Optional, Union

 import httpx
@@ -34,6 +34,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
@@ -58,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .models import model_entries

-log = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")

 class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
@@ -71,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return AsyncClient(host=self.url)

     async def initialize(self) -> None:
-        log.info(f"checking connectivity to Ollama at `{self.url}`...")
+        logger.info(f"checking connectivity to Ollama at `{self.url}`...")
         try:
             await self.client.ps()
         except httpx.ConnectError as e:
@@ -207,12 +208,15 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")

-        return {
+        params = {
             "model": request.model,
             **input_dict,
             "options": sampling_options,
             "stream": request.stream,
         }
+        logger.debug(f"params to ollama: {params}")
+        return params

     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
@@ -287,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         if model.model_type == ModelType.embedding:
-            log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
+            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
             await self.client.pull(model.provider_resource_id)
             response = await self.client.list()
         else:
```


```diff
@@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -53,6 +54,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import TogetherImplConfig
 from .models import MODEL_ENTRIES

+logger = get_logger(name=__name__, category="inference")

 class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: TogetherImplConfig) -> None:
@@ -217,12 +220,13 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

-        return {
+        params = {
             "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }
+        logger.debug(f"params to together: {params}")

     async def embeddings(
         self,
```


```diff
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -46,6 +47,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

+logger = get_logger(name=__name__, category="inference")

 class LiteLLMOpenAIMixin(
     ModelRegistryHelper,
@@ -108,6 +111,7 @@ class LiteLLMOpenAIMixin(
         )

         params = await self._get_params(request)
+        logger.debug(f"params to litellm (openai compat): {params}")
         # unfortunately, we need to use synchronous litellm.completion here because litellm
         # caches various httpx.client objects in a non-eventloop aware manner
         response = litellm.completion(**params)
```


```diff
@@ -8,14 +8,12 @@ import asyncio
 import base64
 import io
 import json
-import logging
 import re
 from typing import List, Optional, Tuple, Union

 import httpx
 from PIL import Image as PIL_Image

-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
@@ -34,6 +32,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     UserMessage,
 )
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     ModelFamily,
     RawContent,
@@ -58,7 +57,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")

 class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
@@ -464,7 +463,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
+        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
         return ToolPromptFormat.json
     if llama_model.model_family == ModelFamily.llama3_1 or (
```


```diff
@@ -162,5 +162,5 @@ module = ["yaml", "fire"]
 ignore_missing_imports = true

 [[tool.mypy.overrides]]
-module = "llama_stack.distribution.resolver"
+module = ["llama_stack.distribution.resolver", "llama_stack.log"]
 follow_imports = "normal"  # This will force type checking on this module
```