feat(logging): implement category-based logging (#1362)

# What does this PR do?

This commit introduces a new logging system that allows loggers to be
assigned a category while retaining the logger name based on the file
name. The log format includes both the logger name and the category,
producing output like:

```
INFO     2025-03-03 21:44:11,323 llama_stack.distribution.stack:103 [core]: Tool_groups: builtin::websearch served by
         tavily-search
```

Key features include:

- Category-based logging: loggers can be assigned a category (e.g.,
  "core", "server") when they are created. A logger is obtained like
  this: `logger = get_logger(name=__name__, category="server")`
- Environment variable control: log levels can be configured per
  category using the `LLAMA_STACK_LOGGING` environment variable. For
  example, `LLAMA_STACK_LOGGING="server=DEBUG;core=debug"` enables
  DEBUG level for the "server" and "core" categories.
- `LLAMA_STACK_LOGGING="all=debug"` sets DEBUG level globally for all
  categories and third-party libraries.

This provides fine-grained control over logging levels while
maintaining a clean and informative log format.
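
For illustration, a minimal sketch of how a module obtains a
category-scoped logger (the module path and messages below are made up
for the example):

```python
# hypothetical module, e.g. llama_stack/distribution/example.py
from llama_stack.log import get_logger

# name keeps the module path; category groups related modules for filtering
logger = get_logger(name=__name__, category="core")

logger.info("Tool_groups: builtin::websearch served by tavily-search")
logger.debug("only emitted when LLAMA_STACK_LOGGING sets core=debug")
```

Running with `LLAMA_STACK_LOGGING="core=debug"` would then also surface
the `logger.debug(...)` call from this module.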

The formatter uses the rich library, which provides nice colors and
better stack traces, like so:

```
ERROR    2025-03-03 21:49:37,124 asyncio:1758 [uncategorized]: unhandled exception during asyncio.run() shutdown
         task: <Task finished name='Task-16' coro=<handle_signal.<locals>.shutdown() done, defined at
         /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:146>
         exception=UnboundLocalError("local variable 'loop' referenced before assignment")>
         ╭────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────╮
         │ /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:178 in shutdown                │
         │                                                                                                                │
         │   175 │   │   except asyncio.CancelledError:                                                                   │
         │   176 │   │   │   pass                                                                                         │
         │   177 │   │   finally:                                                                                         │
         │ ❱ 178 │   │   │   loop.stop()                                                                                  │
         │   179 │                                                                                                        │
         │   180 │   loop = asyncio.get_running_loop()                                                                    │
         │   181 │   loop.create_task(shutdown())                                                                         │
         ╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
         UnboundLocalError: local variable 'loop' referenced before assignment
```
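
Under the hood this is a `RichHandler` configured via `dictConfig` in
the new `llama_stack/log.py` (included below); a rough standalone
sketch of the same wiring, with the options abbreviated:

```python
import logging

from rich.console import Console
from rich.logging import RichHandler

# Roughly what llama_stack/log.py sets up via dictConfig: a RichHandler
# on a fixed-width console with rich tracebacks enabled.
handler = RichHandler(
    console=Console(width=120),
    rich_tracebacks=True,
    show_time=False,
    show_path=False,
)
logging.basicConfig(level=logging.INFO, handlers=[handler], format="%(name)s:%(lineno)d %(message)s")

log = logging.getLogger("llama_stack.demo")
log.info("hello from the rich handler")
try:
    1 / 0
except ZeroDivisionError:
    log.exception("rendered as a rich traceback")
```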

Co-authored-by: Ashwin Bharambe <@ashwinb>
Signed-off-by: Sébastien Han <seb@redhat.com>


## Test Plan

```
python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml
INFO     2025-03-03 21:55:35,918 __main__:365 [server]: Using config file: llama_stack/templates/ollama/run.yaml           
INFO     2025-03-03 21:55:35,925 __main__:378 [server]: Run configuration:                                                 
INFO     2025-03-03 21:55:35,928 __main__:380 [server]: apis:                                                              
         - agents                                                     
``` 

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
16 changed files with 296 additions and 431 deletions


@ -5,15 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
import argparse import argparse
import logging
import os import os
from pathlib import Path from pathlib import Path
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = logging.getLogger(__name__) logger = get_logger(name=__name__, category="server")
class StackRun(Subcommand): class StackRun(Subcommand):


@ -7,7 +7,6 @@ import importlib
import inspect import inspect
from typing import Any, Dict, List, Set, Tuple from typing import Any, Dict, List, Set, Tuple
from llama_stack import logcat
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
@ -35,6 +34,7 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
Api, Api,
BenchmarksProtocolPrivate, BenchmarksProtocolPrivate,
@ -50,6 +50,8 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate, VectorDBsProtocolPrivate,
) )
logger = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception): class InvalidProviderError(Exception):
pass pass
@ -184,7 +186,7 @@ def validate_and_prepare_providers(
specs = {} specs = {}
for provider in providers: for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__": if not provider.provider_id or provider.provider_id == "__disabled__":
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled") logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue continue
validate_provider(provider, api, provider_registry) validate_provider(provider, api, provider_registry)
@ -206,11 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
p = provider_registry[api][provider.provider_type] p = provider_registry[api][provider.provider_type]
if p.deprecation_error: if p.deprecation_error:
logcat.error("core", p.deprecation_error) logger.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error) raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning: elif p.deprecation_warning:
logcat.warning( logger.warning(
"core",
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
) )
@ -244,9 +245,10 @@ def sort_providers_by_deps(
) )
) )
logcat.debug("core", f"Resolved {len(sorted_providers)} providers") logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
logcat.debug("core", f" {api_str} => {provider.provider_id}") logger.debug(f" {api_str} => {provider.provider_id}")
logger.debug("")
return sorted_providers return sorted_providers
@ -387,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters) obj_params = set(obj_sig.parameters)
obj_params.discard("self") obj_params.discard("self")
if not (proto_params <= obj_params): if not (proto_params <= obj_params):
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch")) missing_methods.append((name, "signature_mismatch"))
else: else:
# Check if the method is actually implemented in the class # Check if the method is actually implemented in the class


@ -6,7 +6,6 @@
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack import logcat
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL, URL,
InterleavedContent, InterleavedContent,
@ -52,8 +51,11 @@ from llama_stack.apis.tools import (
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO): class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier""" """Routes to an provider based on the vector db identifier"""
@ -62,15 +64,15 @@ class VectorIORouter(VectorIO):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing VectorIORouter") logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "VectorIORouter.initialize") logger.debug("VectorIORouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "VectorIORouter.shutdown") logger.debug("VectorIORouter.shutdown")
pass pass
async def register_vector_db( async def register_vector_db(
@ -81,10 +83,7 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None,
) -> None: ) -> None:
logcat.debug( logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
"core",
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
)
await self.routing_table.register_vector_db( await self.routing_table.register_vector_db(
vector_db_id, vector_db_id,
embedding_model, embedding_model,
@ -99,8 +98,7 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk], chunks: List[Chunk],
ttl_seconds: Optional[int] = None, ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
logcat.debug( logger.debug(
"core",
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
) )
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
@ -111,7 +109,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent, query: InterleavedContent,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}") logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -122,15 +120,15 @@ class InferenceRouter(Inference):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing InferenceRouter") logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "InferenceRouter.initialize") logger.debug("InferenceRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "InferenceRouter.shutdown") logger.debug("InferenceRouter.shutdown")
pass pass
async def register_model( async def register_model(
@ -141,8 +139,7 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
) -> None: ) -> None:
logcat.debug( logger.debug(
"core",
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
) )
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
@ -160,8 +157,7 @@ class InferenceRouter(Inference):
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
logcat.debug( logger.debug(
"core",
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
) )
if sampling_params is None: if sampling_params is None:
@ -226,8 +222,7 @@ class InferenceRouter(Inference):
) -> AsyncGenerator: ) -> AsyncGenerator:
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
logcat.debug( logger.debug(
"core",
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}", f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
) )
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
@ -257,7 +252,7 @@ class InferenceRouter(Inference):
output_dimension: Optional[int] = None, output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None, task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}") logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
@ -277,15 +272,15 @@ class SafetyRouter(Safety):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing SafetyRouter") logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "SafetyRouter.initialize") logger.debug("SafetyRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "SafetyRouter.shutdown") logger.debug("SafetyRouter.shutdown")
pass pass
async def register_shield( async def register_shield(
@ -295,7 +290,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> Shield: ) -> Shield:
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}") logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield( async def run_shield(
@ -304,7 +299,7 @@ class SafetyRouter(Safety):
messages: List[Message], messages: List[Message],
params: Dict[str, Any] = None, params: Dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}") logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield( return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id, shield_id=shield_id,
messages=messages, messages=messages,
@ -317,15 +312,15 @@ class DatasetIORouter(DatasetIO):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing DatasetIORouter") logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "DatasetIORouter.initialize") logger.debug("DatasetIORouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "DatasetIORouter.shutdown") logger.debug("DatasetIORouter.shutdown")
pass pass
async def get_rows_paginated( async def get_rows_paginated(
@ -335,8 +330,7 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None, page_token: Optional[str] = None,
filter_condition: Optional[str] = None, filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ) -> PaginatedRowsResult:
logcat.debug( logger.debug(
"core",
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
) )
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
@ -347,7 +341,7 @@ class DatasetIORouter(DatasetIO):
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows( return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows=rows, rows=rows,
@ -359,15 +353,15 @@ class ScoringRouter(Scoring):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing ScoringRouter") logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "ScoringRouter.initialize") logger.debug("ScoringRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "ScoringRouter.shutdown") logger.debug("ScoringRouter.shutdown")
pass pass
async def score_batch( async def score_batch(
@ -376,7 +370,7 @@ class ScoringRouter(Scoring):
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}") logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {} res = {}
for fn_identifier in scoring_functions.keys(): for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@ -397,10 +391,7 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse: ) -> ScoreResponse:
logcat.debug( logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
"core",
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
)
res = {} res = {}
# look up and map each scoring function to its provider impl # look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys(): for fn_identifier in scoring_functions.keys():
@ -418,15 +409,15 @@ class EvalRouter(Eval):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing EvalRouter") logger.debug("Initializing EvalRouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "EvalRouter.initialize") logger.debug("EvalRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "EvalRouter.shutdown") logger.debug("EvalRouter.shutdown")
pass pass
async def run_eval( async def run_eval(
@ -434,7 +425,7 @@ class EvalRouter(Eval):
benchmark_id: str, benchmark_id: str,
benchmark_config: BenchmarkConfig, benchmark_config: BenchmarkConfig,
) -> Job: ) -> Job:
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}") logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval( return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id, benchmark_id=benchmark_id,
benchmark_config=benchmark_config, benchmark_config=benchmark_config,
@ -447,7 +438,7 @@ class EvalRouter(Eval):
scoring_functions: List[str], scoring_functions: List[str],
benchmark_config: BenchmarkConfig, benchmark_config: BenchmarkConfig,
) -> EvaluateResponse: ) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id, benchmark_id=benchmark_id,
input_rows=input_rows, input_rows=input_rows,
@ -460,7 +451,7 @@ class EvalRouter(Eval):
benchmark_id: str, benchmark_id: str,
job_id: str, job_id: str,
) -> Optional[JobStatus]: ) -> Optional[JobStatus]:
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}") logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel( async def job_cancel(
@ -468,7 +459,7 @@ class EvalRouter(Eval):
benchmark_id: str, benchmark_id: str,
job_id: str, job_id: str,
) -> None: ) -> None:
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel( await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id, benchmark_id,
job_id, job_id,
@ -479,7 +470,7 @@ class EvalRouter(Eval):
benchmark_id: str, benchmark_id: str,
job_id: str, job_id: str,
) -> EvaluateResponse: ) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}") logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result( return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id, benchmark_id,
job_id, job_id,
@ -492,7 +483,7 @@ class ToolRuntimeRouter(ToolRuntime):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl") logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table self.routing_table = routing_table
async def query( async def query(
@ -501,7 +492,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str], vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None, query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query( return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config content, vector_db_ids, query_config
) )
@ -512,9 +503,8 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str, vector_db_id: str,
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
) -> None: ) -> None:
logcat.debug( logger.debug(
"core", f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
) )
return await self.routing_table.get_provider_impl("insert_into_memory").insert( return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens documents, vector_db_id, chunk_size_in_tokens
@ -524,7 +514,7 @@ class ToolRuntimeRouter(ToolRuntime):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter") logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()" # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -533,15 +523,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.initialize") logger.debug("ToolRuntimeRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.shutdown") logger.debug("ToolRuntimeRouter.shutdown")
pass pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}") logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool( return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name, tool_name=tool_name,
kwargs=kwargs, kwargs=kwargs,
@ -550,5 +540,5 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> List[ToolDef]:
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)


@ -9,7 +9,6 @@ import asyncio
import functools import functools
import inspect import inspect
import json import json
import logging
import os import os
import signal import signal
import sys import sys
@ -28,7 +27,6 @@ from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack import logcat
from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
@ -39,6 +37,7 @@ from llama_stack.distribution.stack import (
replace_env_vars, replace_env_vars,
validate_env_pair, validate_env_pair,
) )
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
@ -54,8 +53,7 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s") logger = get_logger(name=__name__, category="server")
logcat.init()
def warn_with_traceback(message, category, filename, lineno, file=None, line=None): def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -142,23 +140,23 @@ def handle_signal(app, signum, _) -> None:
not block the current execution. not block the current execution.
""" """
signame = signal.Signals(signum).name signame = signal.Signals(signum).name
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...") logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown(): async def shutdown():
try: try:
# Gracefully shut down implementations # Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values(): for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__ impl_name = impl.__class__.__name__
logcat.info("server", f"Shutting down {impl_name}") logger.info("Shutting down %s", impl_name)
try: try:
if hasattr(impl, "shutdown"): if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5) await asyncio.wait_for(impl.shutdown(), timeout=5)
else: else:
logcat.warning("server", f"No shutdown method for {impl_name}") logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logcat.exception("server", f"Shutdown timeout for {impl_name}") logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except Exception as e: except Exception as e:
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}") logger.exception("Failed to shutdown %s: %s", impl_name, {e})
# Gather all running tasks # Gather all running tasks
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@ -172,7 +170,7 @@ def handle_signal(app, signum, _) -> None:
try: try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10) await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logcat.exception("server", "Timeout while waiting for tasks to finish") logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
finally: finally:
@ -184,9 +182,9 @@ def handle_signal(app, signum, _) -> None:
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logcat.info("server", "Starting up") logger.info("Starting up")
yield yield
logcat.info("server", "Shutting down") logger.info("Shutting down")
for impl in app.__llama_stack_impls__.values(): for impl in app.__llama_stack_impls__.values():
await impl.shutdown() await impl.shutdown()
@ -209,11 +207,11 @@ async def sse_generator(event_gen):
yield create_sse_event(item) yield create_sse_event(item)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
except asyncio.CancelledError: except asyncio.CancelledError:
logcat.info("server", "Generator cancelled") logger.info("Generator cancelled")
await event_gen.aclose() await event_gen.aclose()
except Exception as e: except Exception as e:
logcat.exception("server", f"Error in sse_generator: {e}") logger.exception(f"Error in sse_generator: {e}")
logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}") logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
yield create_sse_event( yield create_sse_event(
{ {
"error": { "error": {
@ -235,7 +233,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
value = func(**kwargs) value = func(**kwargs)
return await maybe_await(value) return await maybe_await(value)
except Exception as e: except Exception as e:
logcat.exception("server", f"Error in {func.__name__}") traceback.print_exception(e)
raise translate_exception(e) from e raise translate_exception(e) from e
sig = inspect.signature(func) sig = inspect.signature(func)
@ -314,8 +312,6 @@ class ClientVersionMiddleware:
def main(): def main():
logcat.init()
"""Start the LlamaStack server.""" """Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument( parser.add_argument(
@ -355,10 +351,10 @@ def main():
for env_pair in args.env: for env_pair in args.env:
try: try:
key, value = validate_env_pair(env_pair) key, value = validate_env_pair(env_pair)
logcat.info("server", f"Setting CLI environment variable {key} => {value}") logger.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value os.environ[key] = value
except ValueError as e: except ValueError as e:
logcat.error("server", f"Error: {str(e)}") logger.error(f"Error: {str(e)}")
sys.exit(1) sys.exit(1)
if args.yaml_config: if args.yaml_config:
@ -366,12 +362,12 @@ def main():
config_file = Path(args.yaml_config) config_file = Path(args.yaml_config)
if not config_file.exists(): if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist") raise ValueError(f"Config file {config_file} does not exist")
logcat.info("server", f"Using config file: {config_file}") logger.info(f"Using config file: {config_file}")
elif args.template: elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml" config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists(): if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist") raise ValueError(f"Template {args.template} does not exist")
logcat.info("server", f"Using template {args.template} config file: {config_file}") logger.info(f"Using template {args.template} config file: {config_file}")
else: else:
raise ValueError("Either --yaml-config or --template must be provided") raise ValueError("Either --yaml-config or --template must be provided")
@ -379,10 +375,9 @@ def main():
config = replace_env_vars(yaml.safe_load(fp)) config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config) config = StackRunConfig(**config)
logcat.info("server", "Run configuration:") logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump()) safe_config = redact_sensitive_fields(config.model_dump())
for log_line in yaml.dump(safe_config, indent=2).split("\n"): logger.info(yaml.dump(safe_config, indent=2))
logcat.info("server", log_line)
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware) app.add_middleware(TracingMiddleware)
@ -392,7 +387,7 @@ def main():
try: try:
impls = asyncio.run(construct_stack(config)) impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e: except InvalidProviderError as e:
logcat.error("server", f"Error: {str(e)}") logger.error(f"Error: {str(e)}")
sys.exit(1) sys.exit(1)
if Api.telemetry in impls: if Api.telemetry in impls:
@ -437,7 +432,7 @@ def main():
) )
) )
logcat.debug("server", f"serving APIs: {apis_to_serve}") logger.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
@ -464,10 +459,10 @@ def main():
"ssl_keyfile": keyfile, "ssl_keyfile": keyfile,
"ssl_certfile": certfile, "ssl_certfile": certfile,
} }
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0" listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
logcat.info("server", f"Listening on {listen_host}:{port}") logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = { uvicorn_config = {
"app": app, "app": app,


@ -11,9 +11,7 @@ import tempfile
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import yaml import yaml
from termcolor import colored
from llama_stack import logcat
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
@ -39,8 +37,11 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
logger = get_logger(name=__name__, category="core")
class LlamaStack( class LlamaStack(
VectorDBs, VectorDBs,
@ -101,9 +102,8 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process: for obj in objects_to_process:
logcat.debug( logger.debug(
"core", f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
) )


@ -100,12 +100,15 @@ esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x set -x
$PYTHON_BINARY -m llama_stack.distribution.server.server \ $PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \ --yaml-config "$yaml_config" \
--port "$port" \ --port "$port" \
$env_vars \ $env_vars \
$other_args $other_args
elif [[ "$env_type" == "container" ]]; then elif [[ "$env_type" == "container" ]]; then
set -x
# Check if container command is available # Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2 printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
@ -141,8 +144,6 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version') version_tag=$(curl -s $URL | jq -r '.info.version')
fi fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \ $CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \ -p $port:$port \
$env_vars \ $env_vars \

llama_stack/log.py (new file, 169 lines)

@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from logging.config import dictConfig
from typing import Dict
from rich.console import Console
from rich.logging import RichHandler
# Default log level
DEFAULT_LOG_LEVEL = logging.INFO
# Predefined categories
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
# Initialize category levels with default level
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
def parse_environment_config(env_config: str) -> Dict[str, int]:
"""
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
Parameters:
env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.
Returns:
Dict[str, int]: A dictionary mapping categories to their log levels.
"""
category_levels = {}
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
level_value = logging._nameToLevel.get(level)
if level_value is None:
logging.warning(
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
)
continue
if category == "all":
# Apply the log level to all categories and the root logger
for cat in CATEGORIES:
category_levels[cat] = level_value
# Set the root logger's level to the specified level
category_levels["root"] = level_value
elif category in CATEGORIES:
category_levels[category] = level_value
logging.info(f"Setting '{category}' category to level '{level}'.")
else:
logging.warning(f"Unknown logging category: {category}. No changes made.")
except ValueError:
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
return category_levels
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=120)
super().__init__(*args, **kwargs)
def setup_logging(category_levels: Dict[str, int]) -> None:
"""
Configure logging based on the provided category log levels.
Parameters:
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
"""
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
class CategoryFilter(logging.Filter):
"""Ensure category is always present in log records."""
def filter(self, record):
if not hasattr(record, "category"):
record.category = "uncategorized" # Default to 'uncategorized' if no category found
return True
# Determine the root logger's level (default to WARNING if not specified)
root_level = category_levels.get("root", logging.WARNING)
logging_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"rich": {
"()": logging.Formatter,
"format": log_format,
}
},
"handlers": {
"console": {
"()": CustomRichHandler, # Use our custom handler class
"formatter": "rich",
"rich_tracebacks": True,
"show_time": False,
"show_path": False,
"markup": True,
"filters": ["category_filter"],
}
},
"filters": {
"category_filter": {
"()": CategoryFilter,
}
},
"loggers": {
category: {
"handlers": ["console"],
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
"propagate": False, # Disable propagation to root logger
}
for category in CATEGORIES
},
"root": {
"handlers": ["console"],
"level": root_level, # Set root logger's level dynamically
},
}
dictConfig(logging_config)
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
"""
Returns a logger with the specified name and category.
If no category is provided, defaults to 'uncategorized'.
Parameters:
name (str): The name of the logger (e.g., module or filename).
category (str): The category of the logger (default 'uncategorized').
Returns:
logging.LoggerAdapter: Configured logger with category support.
"""
logger = logging.getLogger(name)
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
return logging.LoggerAdapter(logger, {"category": category})
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
_category_levels.update(parse_environment_config(env_config))
setup_logging(_category_levels)
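
As a quick illustration of the parsing and lookup logic above (a
sketch, not part of the change; the assertions follow directly from
`parse_environment_config`):

```python
import logging

from llama_stack.log import get_logger, parse_environment_config

# "all" fans out to every known category plus the root logger; a later
# pair can still override a single category.
levels = parse_environment_config("all=warning;server=DEBUG")
assert levels["server"] == logging.DEBUG
assert levels["core"] == logging.WARNING
assert levels["root"] == logging.WARNING

# In practice the levels are read once from LLAMA_STACK_LOGGING at import
# time; get_logger wraps a stdlib logger in a LoggerAdapter that injects
# the category into every record so the formatter can print it.
logger = get_logger(name=__name__, category="server")
logger.info("category-scoped message")
```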


@ -1,204 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Category-based logging utility for llama-stack.
This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.
Usage:
from llama_stack import logcat
logcat.info("server", "Starting up...")
logcat.debug("inference", "Processing request...")
Environment variable:
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
Example: "server=debug;inference=warning"
"""
import datetime
import logging
import os
from typing import Dict
# ANSI color codes for terminal output
COLORS = {
"RESET": "\033[0m",
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
"DIM": "\033[2m", # Dimmed text
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
}
# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
_logger = logging.getLogger("llama_stack")
_logger.propagate = False
_default_level = logging.INFO
# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}
class TerminalStreamHandler(logging.StreamHandler):
def __init__(self, stream=None):
super().__init__(stream)
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
def format(self, record):
record.is_tty = self.is_tty
return super().format(record)
class ColoredFormatter(logging.Formatter):
"""Custom formatter with colors and fixed-width level names"""
def format(self, record):
levelname = record.levelname
# Use only time with milliseconds, not date
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
file_info = f"{record.filename}:{record.lineno}"
# Get category from extra if available
category = getattr(record, "category", None)
msg = record.getMessage()
if getattr(record, "is_tty", False):
color = COLORS.get(levelname, COLORS["RESET"])
if category:
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
f"{file_info:<20} {category_formatted}{msg}"
)
else:
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
f"{file_info:<20} {msg}"
)
else:
if category:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
else:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
return formatted_msg
def init(default_level: int = logging.INFO) -> None:
global _default_level, _category_levels, _logger
_default_level = default_level
_logger.setLevel(logging.DEBUG)
_logger.handlers = [] # Clear existing handlers
# Add our custom handler with the colored formatter
handler = TerminalStreamHandler()
formatter = ColoredFormatter()
handler.setFormatter(formatter)
_logger.addHandler(handler)
for category in CATEGORIES:
_category_levels[category] = default_level
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().lower()
level_value = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"warn": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}.get(level)
if level_value is None:
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
continue
if category == "all":
for cat in CATEGORIES:
_category_levels[cat] = level_value
else:
if category in CATEGORIES:
_category_levels[category] = level_value
else:
_logger.warning(f"Unknown logging category: {category}")
except ValueError:
_logger.warning(f"Invalid logging configuration: {pair}")
def _should_log(level: int, category: str) -> bool:
category = category.lower()
if category not in _category_levels:
return False
category_level = _category_levels[category]
return level >= category_level
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
if _should_log(level, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
def debug(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
def info(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.INFO, "info", category, msg, *args, **kwargs)
def warning(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
def warn(category: str, msg: str, *args, **kwargs) -> None:
warning(category, msg, *args, **kwargs)
def error(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
def critical(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
def exception(category: str, msg: str, *args, **kwargs) -> None:
if _should_log(logging.ERROR, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
_logger.exception(msg, *args, stacklevel=2, **kwargs)


@ -17,7 +17,6 @@ from urllib.parse import urlparse
import httpx import httpx
from llama_stack import logcat
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
AgentConfig, AgentConfig,
AgentToolGroup, AgentToolGroup,
@ -67,6 +66,7 @@ from llama_stack.apis.tools import (
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
ToolCall, ToolCall,
@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search" WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag" RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
class ChatAgent(ShieldRunnerMixin): class ChatAgent(ShieldRunnerMixin):
def __init__( def __init__(
@ -609,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
if n_iter >= self.agent_config.max_infer_iters: if n_iter >= self.agent_config.max_infer_iters:
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.") logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn # NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point # Do not continue the tool call loop after this point
message.stop_reason = StopReason.end_of_turn message.stop_reason = StopReason.end_of_turn
@ -617,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
break break
if stop_reason == StopReason.out_of_tokens: if stop_reason == StopReason.out_of_tokens:
logcat.info("agents", "out of token budget, exiting.") logger.info("out of token budget, exiting.")
yield message yield message
break break
@ -631,16 +633,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments message.content = [message.content] + output_attachments
yield message yield message
else: else:
logcat.debug( logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
"agents",
f"completion message with EOM (iter: {n_iter}): {str(message)}",
)
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
logcat.debug( logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
"agents",
f"completion message (iter: {n_iter}) from the model: {str(message)}",
)
# 1. Start the tool execution step and progress # 1. Start the tool execution step and progress
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -983,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path path = urlparse(uri).path
basename = os.path.basename(path) basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}" filepath = f"{tempdir}/{make_random_string() + basename}"
logcat.info("agents", f"Downloading {url} -> {filepath}") logger.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(uri) r = await client.get(uri)
@ -1023,7 +1019,7 @@ async def execute_tool_call_maybe(
else: else:
name = name.value name = name.value
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}") logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
result = await tool_runtime_api.invoke_tool( result = await tool_runtime_api.invoke_tool(
tool_name=name, tool_name=name,
kwargs={ kwargs={
@ -1033,7 +1029,7 @@ async def execute_tool_call_maybe(
**toolgroup_args.get(group_name, {}), **toolgroup_args.get(group_name, {}),
}, },
) )
logcat.debug("agents", f"tool call {name} completed with result: {result}") logger.info(f"tool call {name} completed with result: {result}")
return result return result


@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
-from llama_stack import logcat
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
@@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
    ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
@@ -55,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
+logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
    def __init__(self, config: FireworksImplConfig) -> None:
@@ -237,7 +239,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
-logcat.debug("inference", f"params to fireworks: {params}")
+logger.debug(f"params to fireworks: {params}")
return params
async def embeddings(


@@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
from ollama import AsyncClient
-from llama_stack import logcat
from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
@@ -35,6 +34,7 @@ from llama_stack.apis.inference import (
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
+from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
@@ -59,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import model_entries
-log = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
@@ -72,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncClient(host=self.url)
async def initialize(self) -> None:
-log.info(f"checking connectivity to Ollama at `{self.url}`...")
+logger.info(f"checking connectivity to Ollama at `{self.url}`...")
try:
await self.client.ps()
except httpx.ConnectError as e:
@@ -214,7 +214,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
"options": sampling_options,
"stream": request.stream,
}
-logcat.debug("inference", f"params to ollama: {params}")
+logger.debug(f"params to ollama: {params}")
return params
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
@@ -290,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.model_type == ModelType.embedding:
-log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
+logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
response = await self.client.list()
else:


@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
-from llama_stack import logcat
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
@@ -32,6 +31,7 @@ from llama_stack.apis.inference import (
    ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
@@ -54,6 +54,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
+logger = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
    def __init__(self, config: TogetherImplConfig) -> None:
@@ -224,8 +226,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
-logcat.debug("inference", f"params to together: {params}")
+logger.debug(f"params to together: {params}")
return params
async def embeddings(
self,


@@ -8,7 +8,6 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
import litellm
-from llama_stack import logcat
from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
@@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
@@ -47,6 +47,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)
+logger = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
    ModelRegistryHelper,
@@ -109,8 +111,7 @@ class LiteLLMOpenAIMixin(
)
params = await self._get_params(request)
-logcat.debug("inference", f"params to litellm (openai compat): {params}")
+logger.debug(f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
response = litellm.completion(**params)


@@ -8,14 +8,12 @@ import asyncio
import base64
import io
import json
-import logging
import re
from typing import List, Optional, Tuple, Union
import httpx
from PIL import Image as PIL_Image
-from llama_stack import logcat
from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
@@ -34,6 +32,7 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    UserMessage,
)
+from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
    ModelFamily,
    RawContent,
@@ -58,7 +57,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
@@ -464,7 +463,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
llama_model = resolve_model(model)
if llama_model is None:
-logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
+log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
return ToolPromptFormat.json
if llama_model.model_family == ModelFamily.llama3_1 or (


@@ -151,7 +151,6 @@ exclude = [
    "llama_stack/distribution",
    "llama_stack/apis",
    "llama_stack/cli",
-    "llama_stack/logcat.py",
    "llama_stack/models",
    "llama_stack/strong_typing",
    "llama_stack/templates",
@@ -163,5 +162,5 @@ module = ["yaml", "fire"]
ignore_missing_imports = true
[[tool.mypy.overrides]]
-module = "llama_stack.distribution.resolver"
+module = ["llama_stack.distribution.resolver", "llama_stack.log"]
follow_imports = "normal" # This will force type checking on this module


@@ -1,88 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import io
import logging
import os
import unittest

from llama_stack import logcat


class TestLogcat(unittest.TestCase):
    def setUp(self):
        self.original_env = os.environ.get("LLAMA_STACK_LOGGING")
        self.log_output = io.StringIO()
        self._init_logcat()

    def tearDown(self):
        if self.original_env is not None:
            os.environ["LLAMA_STACK_LOGGING"] = self.original_env
        else:
            os.environ.pop("LLAMA_STACK_LOGGING", None)

    def _init_logcat(self):
        logcat.init(default_level=logging.DEBUG)
        self.handler = logging.StreamHandler(self.log_output)
        self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
        logcat._logger.handlers.clear()
        logcat._logger.addHandler(self.handler)

    def test_basic_logging(self):
        logcat.info("server", "Info message")
        logcat.warning("server", "Warning message")
        logcat.error("server", "Error message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Info message", output)
        self.assertIn("[server] Warning message", output)
        self.assertIn("[server] Error message", output)

    def test_different_categories(self):
        # Log messages with different categories
        logcat.info("server", "Server message")
        logcat.info("inference", "Inference message")
        logcat.info("router", "Router message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Server message", output)
        self.assertIn("[inference] Inference message", output)
        self.assertIn("[router] Router message", output)

    def test_env_var_control(self):
        os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
        self._init_logcat()

        # These should be visible based on the environment settings
        logcat.debug("server", "Server debug message")
        logcat.info("server", "Server info message")
        logcat.warning("inference", "Inference warning message")
        logcat.error("inference", "Inference error message")

        # These should be filtered out based on the environment settings
        logcat.debug("inference", "Inference debug message")
        logcat.info("inference", "Inference info message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Server debug message", output)
        self.assertIn("[server] Server info message", output)
        self.assertIn("[inference] Inference warning message", output)
        self.assertIn("[inference] Inference error message", output)
        self.assertNotIn("[inference] Inference debug message", output)
        self.assertNotIn("[inference] Inference info message", output)

    def test_invalid_category(self):
        logcat.info("nonexistent", "This message should not be logged")

        # Check that the message was not logged
        output = self.log_output.getvalue()
        self.assertNotIn("[nonexistent] This message should not be logged", output)


if __name__ == "__main__":
    unittest.main()
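The deleted test_logcat.py is not replaced in this PR. For reference, a minimal sketch of an equivalent unit test against the new `get_logger` API could look like the following; it assumes the returned object behaves like a standard `logging.Logger` (so a handler can be attached and a level set directly), which is an assumption about `llama_stack.log` rather than something shown in this diff:

```python
import io
import logging
import unittest

from llama_stack.log import get_logger


class TestCategoryLogger(unittest.TestCase):
    def test_basic_logging(self):
        # Assumption: get_logger returns a logging.Logger-compatible object,
        # so a StreamHandler can be attached and a level set directly.
        logger = get_logger(name=__name__, category="server")
        output = io.StringIO()
        handler = logging.StreamHandler(output)
        handler.setFormatter(logging.Formatter("%(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)

        logger.info("Info message")
        self.assertIn("Info message", output.getvalue())


if __name__ == "__main__":
    unittest.main()
```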