feat: add a configurable category-based logger (#1352)

A self-respecting server needs good observability, which starts with
configurable logging. Llama Stack has had little of it until now. This PR
adds a `logcat` facility toward that goal. Call sites look like:

```python
logcat.debug("inference", f"params to ollama: {params}")
```

- the first parameter is a category; there is a static list of
categories in `llama_stack/logcat.py`
- each category can be associated with a log level, which can be
configured via the `LLAMA_STACK_LOGGING` env var
- a value like `LLAMA_STACK_LOGGING="inference=debug;server=info"` does the
obvious thing. there is a special key called `all` which is an alias for
all categories (see the sketch below)
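
To make that concrete, here is a minimal sketch of how the env var and call sites interact, using only the `logcat` API added in this PR (output formatting itself comes from `ColoredFormatter` in `llama_stack/logcat.py`):

```python
import os

# LLAMA_STACK_LOGGING must be set before logcat.init() runs,
# since init() reads the variable once at startup.
os.environ["LLAMA_STACK_LOGGING"] = "inference=debug;server=info"

from llama_stack import logcat

logcat.init()  # seeds every category at INFO, then applies the env var overrides
logcat.debug("inference", "emitted: inference is at DEBUG")
logcat.debug("server", "suppressed: server stays at INFO")
logcat.info("server", "emitted: INFO >= INFO")
```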

## Test Plan

Ran with `LLAMA_STACK_LOGGING="all=debug" llama stack run fireworks` and
saw the following:


![image](https://github.com/user-attachments/assets/d24b95ab-3941-426c-9ea0-a4c62542e6f0)

Hit it with a client-sdk test case and saw this:


![image](https://github.com/user-attachments/assets/3fee8c6c-986e-4125-a09c-f5dc019682e2)
Ashwin Bharambe · 2025-03-02 18:51:14 -08:00 · commit 754feba61f · parent a9a7b11326
12 changed files with 409 additions and 47 deletions


@@ -5,9 +5,9 @@
 # the root directory of this source tree.
 import importlib
 import inspect
-import logging
 from typing import Any, Dict, List, Set

+from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
@@ -50,8 +50,6 @@ from llama_stack.providers.datatypes import (
     VectorDBsProtocolPrivate,
 )

-log = logging.getLogger(__name__)
-

 class InvalidProviderError(Exception):
     pass
@@ -128,7 +126,7 @@ async def resolve_impls(
     specs = {}
     for provider in providers:
         if not provider.provider_id or provider.provider_id == "__disabled__":
-            log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+            logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
             continue

         if provider.provider_type not in provider_registry[api]:
@@ -136,11 +134,12 @@ async def resolve_impls(
         p = provider_registry[api][provider.provider_type]

         if p.deprecation_error:
-            log.error(p.deprecation_error, "red", attrs=["bold"])
+            logcat.error("core", p.deprecation_error)
             raise InvalidProviderError(p.deprecation_error)
         elif p.deprecation_warning:
-            log.warning(
+            logcat.warning(
+                "core",
                 f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
             )

         p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
@@ -214,10 +213,10 @@ async def resolve_impls(
             )
         )

-    log.info(f"Resolved {len(sorted_providers)} providers")
+    logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
-        log.info(f" {api_str} => {provider.provider_id}")
-    log.info("")
+        logcat.debug("core", f" {api_str} => {provider.provider_id}")
+    logcat.debug("core", "")

     impls = {}
     inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
@@ -354,7 +353,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
             obj_params = set(obj_sig.parameters)
             obj_params.discard("self")
             if not (proto_params <= obj_params):
-                log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+                logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
                 missing_methods.append((name, "signature_mismatch"))
             else:
                 # Check if the method is actually implemented in the class


@@ -7,6 +7,7 @@
 import copy
 from typing import Any, AsyncGenerator, Dict, List, Optional

+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -63,12 +64,15 @@ class VectorIORouter(VectorIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing VectorIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "VectorIORouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "VectorIORouter.shutdown")
         pass

     async def register_vector_db(
@@ -79,6 +83,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
+        logcat.debug("core", f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -93,6 +98,10 @@ class VectorIORouter(VectorIO):
         chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
+        )
         return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

     async def query_chunks(
@@ -101,6 +110,7 @@ class VectorIORouter(VectorIO):
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
+        logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
         return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@@ -111,12 +121,15 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "InferenceRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "InferenceRouter.shutdown")
         pass

     async def register_model(
@@ -127,6 +140,10 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
+        )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

     async def chat_completion(
@@ -142,6 +159,10 @@ class InferenceRouter(Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
+        )
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -203,6 +224,10 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
+        )
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -230,6 +255,7 @@ class InferenceRouter(Inference):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
+        logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -249,12 +275,15 @@ class SafetyRouter(Safety):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing SafetyRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "SafetyRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "SafetyRouter.shutdown")
         pass

     async def register_shield(
@@ -264,6 +293,7 @@ class SafetyRouter(Safety):
         provider_id: Optional[str] = None,
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
+        logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
@@ -272,6 +302,7 @@ class SafetyRouter(Safety):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
+        logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
             shield_id=shield_id,
             messages=messages,
@@ -284,12 +315,15 @@ class DatasetIORouter(DatasetIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing DatasetIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "DatasetIORouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "DatasetIORouter.shutdown")
         pass

     async def get_rows_paginated(
@@ -299,6 +333,7 @@ class DatasetIORouter(DatasetIO):
         page_token: Optional[str] = None,
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
+        logcat.debug("core", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}")
         return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=rows_in_page,
@@ -307,6 +342,7 @@ class DatasetIORouter(DatasetIO):
         )

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
             rows=rows,
@@ -318,12 +354,15 @@ class ScoringRouter(Scoring):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing ScoringRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "ScoringRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "ScoringRouter.shutdown")
         pass

     async def score_batch(
@@ -332,6 +371,7 @@ class ScoringRouter(Scoring):
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
+        logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@@ -352,6 +392,7 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
+        logcat.debug("core", f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
@@ -369,12 +410,15 @@ class EvalRouter(Eval):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing EvalRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "EvalRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "EvalRouter.shutdown")
         pass

     async def run_eval(
@@ -382,6 +426,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         task_config: BenchmarkConfig,
     ) -> Job:
+        logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
             benchmark_id=benchmark_id,
             task_config=task_config,
@@ -394,6 +439,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         task_config: BenchmarkConfig,
     ) -> EvaluateResponse:
+        logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -406,6 +452,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> Optional[JobStatus]:
+        logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
@@ -413,6 +460,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> None:
+        logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
         await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
             benchmark_id,
             job_id,
@@ -423,6 +471,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> EvaluateResponse:
+        logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_result(
             benchmark_id,
             job_id,
@@ -435,6 +484,7 @@ class ToolRuntimeRouter(ToolRuntime):
             self,
             routing_table: RoutingTable,
         ) -> None:
+            logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
             self.routing_table = routing_table

         async def query(
@@ -443,6 +493,7 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_ids: List[str],
             query_config: Optional[RAGQueryConfig] = None,
         ) -> RAGQueryResult:
+            logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
             return await self.routing_table.get_provider_impl("knowledge_search").query(
                 content, vector_db_ids, query_config
             )
@@ -453,6 +504,10 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_id: str,
             chunk_size_in_tokens: int = 512,
         ) -> None:
+            logcat.debug(
+                "core",
+                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
+            )
             return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                 documents, vector_db_id, chunk_size_in_tokens
             )
@@ -461,6 +516,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing ToolRuntimeRouter")
         self.routing_table = routing_table

         # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -469,12 +525,15 @@ class ToolRuntimeRouter(ToolRuntime):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

     async def initialize(self) -> None:
+        logcat.debug("core", "ToolRuntimeRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "ToolRuntimeRouter.shutdown")
         pass

     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
+        logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
             kwargs=kwargs,
@@ -483,4 +542,5 @@ class ToolRuntimeRouter(ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
+        logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)


@@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
-from termcolor import cprint
 from typing_extensions import Annotated

+from llama_stack import logcat
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import set_request_provider_data
@@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
 REPO_ROOT = Path(__file__).parent.parent.parent.parent

 logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
-logger = logging.getLogger(__name__)
+logcat.init()


 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
         not block the current execution.
     """
     signame = signal.Signals(signum).name
-    logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
+    logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")

     async def shutdown():
         try:
             # Gracefully shut down implementations
             for impl in app.__llama_stack_impls__.values():
                 impl_name = impl.__class__.__name__
-                logger.info("Shutting down %s", impl_name)
+                logcat.info("server", f"Shutting down {impl_name}")
                 try:
                     if hasattr(impl, "shutdown"):
                         await asyncio.wait_for(impl.shutdown(), timeout=5)
                     else:
-                        logger.warning("No shutdown method for %s", impl_name)
+                        logcat.warning("server", f"No shutdown method for {impl_name}")
                 except asyncio.TimeoutError:
-                    logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
+                    logcat.exception("server", f"Shutdown timeout for {impl_name}")
                 except Exception as e:
-                    logger.exception("Failed to shutdown %s: %s", impl_name, {e})
+                    logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")

             # Gather all running tasks
             loop = asyncio.get_running_loop()
@@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
         try:
             await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
         except asyncio.TimeoutError:
-            logger.exception("Timeout while waiting for tasks to finish")
+            logcat.exception("server", "Timeout while waiting for tasks to finish")
         except asyncio.CancelledError:
             pass
         finally:
@@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logger.info("Starting up")
+    logcat.info("server", "Starting up")
     yield
-    logger.info("Shutting down")
+    logcat.info("server", "Shutting down")
     for impl in app.__llama_stack_impls__.values():
         await impl.shutdown()
@@ -209,10 +209,10 @@ async def sse_generator(event_gen):
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
-        print("Generator cancelled")
+        logcat.info("server", "Generator cancelled")
         await event_gen.aclose()
     except Exception as e:
-        traceback.print_exception(e)
+        logcat.exception("server", "Error in sse_generator")
        yield create_sse_event(
            {
                "error": {
@@ -234,7 +234,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
             value = func(**kwargs)
             return await maybe_await(value)
         except Exception as e:
-            traceback.print_exception(e)
+            logcat.exception("server", f"Error in {func.__name__}")
             raise translate_exception(e) from e

     sig = inspect.signature(func)
@@ -313,6 +313,8 @@ class ClientVersionMiddleware:


 def main():
+    logcat.init()
+
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
     parser.add_argument(
@@ -352,10 +354,10 @@ def main():
         for env_pair in args.env:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                logcat.info("server", f"Setting CLI environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
-                logger.error(f"Error: {str(e)}")
+                logcat.error("server", f"Error: {str(e)}")
                 sys.exit(1)

     if args.yaml_config:
@@ -363,12 +365,12 @@ def main():
         config_file = Path(args.yaml_config)
         if not config_file.exists():
             raise ValueError(f"Config file {config_file} does not exist")
-        logger.info(f"Using config file: {config_file}")
+        logcat.info("server", f"Using config file: {config_file}")
     elif args.template:
         config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
         if not config_file.exists():
             raise ValueError(f"Template {args.template} does not exist")
-        logger.info(f"Using template {args.template} config file: {config_file}")
+        logcat.info("server", f"Using template {args.template} config file: {config_file}")
     else:
         raise ValueError("Either --yaml-config or --template must be provided")
@@ -376,9 +378,10 @@ def main():
         config = replace_env_vars(yaml.safe_load(fp))
         config = StackRunConfig(**config)

-    logger.info("Run configuration:")
+    logcat.info("server", "Run configuration:")
     safe_config = redact_sensitive_fields(config.model_dump())
-    logger.info(yaml.dump(safe_config, indent=2))
+    for log_line in yaml.dump(safe_config, indent=2).split("\n"):
+        logcat.info("server", log_line)

     app = FastAPI(lifespan=lifespan)
     app.add_middleware(TracingMiddleware)
@@ -388,7 +391,7 @@ def main():
     try:
         impls = asyncio.run(construct_stack(config))
     except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
+        logcat.error("server", f"Error: {str(e)}")
         sys.exit(1)

     if Api.telemetry in impls:
@@ -433,11 +436,8 @@ def main():
             )
         )

-        logger.info(f"Serving API {api_str}")
-        for endpoint in endpoints:
-            cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
-    print("")
+        logcat.debug("server", f"Serving API {api_str}")

     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
     signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
@@ -463,10 +463,10 @@ def main():
             "ssl_keyfile": keyfile,
             "ssl_certfile": certfile,
         }
-        logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
+        logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

     listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
-    logger.info(f"Listening on {listen_host}:{port}")
+    logcat.info("server", f"Listening on {listen_host}:{port}")

     uvicorn_config = {
         "app": app,


@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import importlib.resources
-import logging
 import os
 import re
 from typing import Any, Dict, Optional
@@ -13,6 +12,7 @@ from typing import Any, Dict, Optional
 import yaml
 from termcolor import colored

+from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batch_inference import BatchInference
 from llama_stack.apis.benchmarks import Benchmarks
@@ -39,8 +39,6 @@ from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api

-log = logging.getLogger(__name__)
-

 class LlamaStack(
     VectorDBs,
@@ -101,12 +99,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
         objects_to_process = response.data if hasattr(response, "data") else response

         for obj in objects_to_process:
-            log.info(
+            logcat.debug(
+                "core",
                 f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
             )
-    log.info("")


 class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):


@ -98,9 +98,8 @@ case "$env_type" in
*) *)
esac esac
set -x
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
$PYTHON_BINARY -m llama_stack.distribution.server.server \ $PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \ --yaml-config "$yaml_config" \
--port "$port" \ --port "$port" \
@ -142,6 +141,8 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version') version_tag=$(curl -s $URL | jq -r '.info.version')
fi fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \ $CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \ -p $port:$port \
$env_vars \ $env_vars \

llama_stack/logcat.py (new file, 204 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Category-based logging utility for llama-stack.

This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.

Usage:
    from llama_stack import logcat

    logcat.info("server", "Starting up...")
    logcat.debug("inference", "Processing request...")

Environment variable:
    LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
    Example: "server=debug;inference=warning"
"""

import datetime
import logging
import os
from typing import Dict

# ANSI color codes for terminal output
COLORS = {
    "RESET": "\033[0m",
    "DEBUG": "\033[36m",  # Cyan
    "INFO": "\033[32m",  # Green
    "WARNING": "\033[33m",  # Yellow
    "ERROR": "\033[31m",  # Red
    "CRITICAL": "\033[35m",  # Magenta
    "DIM": "\033[2m",  # Dimmed text
    "YELLOW_DIM": "\033[2;33m",  # Dimmed yellow
}

# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
    "core",
    "server",
    "router",
    "inference",
    "agents",
    "safety",
    "eval",
    "tools",
    "client",
]

_logger = logging.getLogger("llama_stack")
_logger.propagate = False

_default_level = logging.INFO

# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}


class TerminalStreamHandler(logging.StreamHandler):
    def __init__(self, stream=None):
        super().__init__(stream)
        self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()

    def format(self, record):
        record.is_tty = self.is_tty
        return super().format(record)


class ColoredFormatter(logging.Formatter):
    """Custom formatter with colors and fixed-width level names"""

    def format(self, record):
        levelname = record.levelname
        # Use only time with milliseconds, not date
        timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]  # HH:MM:SS.mmm format

        file_info = f"{record.filename}:{record.lineno}"

        # Get category from extra if available
        category = getattr(record, "category", None)
        msg = record.getMessage()

        if getattr(record, "is_tty", False):
            color = COLORS.get(levelname, COLORS["RESET"])
            if category:
                category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
                formatted_msg = (
                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
                    f"{file_info:<20} {category_formatted}{msg}"
                )
            else:
                formatted_msg = (
                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
                    f"{file_info:<20} {msg}"
                )
        else:
            if category:
                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
            else:
                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"

        return formatted_msg


def init(default_level: int = logging.INFO) -> None:
    global _default_level, _category_levels, _logger

    _default_level = default_level

    _logger.setLevel(logging.DEBUG)
    _logger.handlers = []  # Clear existing handlers

    # Add our custom handler with the colored formatter
    handler = TerminalStreamHandler()
    formatter = ColoredFormatter()
    handler.setFormatter(formatter)
    _logger.addHandler(handler)

    for category in CATEGORIES:
        _category_levels[category] = default_level

    env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
    if env_config:
        for pair in env_config.split(";"):
            if not pair.strip():
                continue

            try:
                category, level = pair.split("=", 1)
                category = category.strip().lower()
                level = level.strip().lower()

                level_value = {
                    "debug": logging.DEBUG,
                    "info": logging.INFO,
                    "warning": logging.WARNING,
                    "warn": logging.WARNING,
                    "error": logging.ERROR,
                    "critical": logging.CRITICAL,
                }.get(level)
                if level_value is None:
                    _logger.warning(f"Unknown log level '{level}' for category '{category}'")
                    continue

                if category == "all":
                    for cat in CATEGORIES:
                        _category_levels[cat] = level_value
                else:
                    if category in CATEGORIES:
                        _category_levels[category] = level_value
                    else:
                        _logger.warning(f"Unknown logging category: {category}")
            except ValueError:
                _logger.warning(f"Invalid logging configuration: {pair}")


def _should_log(level: int, category: str) -> bool:
    category = category.lower()
    if category not in _category_levels:
        return False
    category_level = _category_levels[category]
    return level >= category_level


def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
    if _should_log(level, category):
        kwargs.setdefault("extra", {})["category"] = category.lower()
        getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)


def debug(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.DEBUG, "debug", category, msg, *args, **kwargs)


def info(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.INFO, "info", category, msg, *args, **kwargs)


def warning(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.WARNING, "warning", category, msg, *args, **kwargs)


def warn(category: str, msg: str, *args, **kwargs) -> None:
    warning(category, msg, *args, **kwargs)


def error(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.ERROR, "error", category, msg, *args, **kwargs)


def critical(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)


def exception(category: str, msg: str, *args, **kwargs) -> None:
    if _should_log(logging.ERROR, category):
        kwargs.setdefault("extra", {})["category"] = category.lower()
        _logger.exception(msg, *args, stacklevel=2, **kwargs)

@@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks

+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -226,12 +227,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

-        return {
+        params = {
             "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
+        logcat.debug("inference", f"params to fireworks: {params}")
+        return params

     async def embeddings(
         self,

@@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
 import httpx
 from ollama import AsyncClient

+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
@@ -203,12 +204,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")

-        return {
+        params = {
             "model": request.model,
             **input_dict,
             "options": sampling_options,
             "stream": request.stream,
         }
+        logcat.debug("inference", f"params to ollama: {params}")
+        return params

     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)

@@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union

 from together import Together

+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -213,12 +214,14 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

-        return {
+        params = {
             "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }
+        logcat.debug("inference", f"params to together: {params}")
+        return params

     async def embeddings(
         self,

@@ -8,6 +8,7 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union

 import litellm

+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -106,6 +107,8 @@ class LiteLLMOpenAIMixin(
         )

         params = await self._get_params(request)
+        logcat.debug("inference", f"params to litellm (openai compat): {params}")
+
         # unfortunately, we need to use synchronous litellm.completion here because litellm
         # caches various httpx.client objects in a non-eventloop aware manner
         response = litellm.completion(**params)

@@ -153,6 +153,7 @@ exclude = [
     "llama_stack/distribution",
     "llama_stack/apis",
     "llama_stack/cli",
+    "llama_stack/logcat.py",
     "llama_stack/models",
     "llama_stack/strong_typing",
     "llama_stack/templates",

tests/test_logcat.py (new file, 88 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import io
import logging
import os
import unittest

from llama_stack import logcat


class TestLogcat(unittest.TestCase):
    def setUp(self):
        self.original_env = os.environ.get("LLAMA_STACK_LOGGING")

        self.log_output = io.StringIO()
        self._init_logcat()

    def tearDown(self):
        if self.original_env is not None:
            os.environ["LLAMA_STACK_LOGGING"] = self.original_env
        else:
            os.environ.pop("LLAMA_STACK_LOGGING", None)

    def _init_logcat(self):
        logcat.init(default_level=logging.DEBUG)
        self.handler = logging.StreamHandler(self.log_output)
        self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
        logcat._logger.handlers.clear()
        logcat._logger.addHandler(self.handler)

    def test_basic_logging(self):
        logcat.info("server", "Info message")
        logcat.warning("server", "Warning message")
        logcat.error("server", "Error message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Info message", output)
        self.assertIn("[server] Warning message", output)
        self.assertIn("[server] Error message", output)

    def test_different_categories(self):
        # Log messages with different categories
        logcat.info("server", "Server message")
        logcat.info("inference", "Inference message")
        logcat.info("router", "Router message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Server message", output)
        self.assertIn("[inference] Inference message", output)
        self.assertIn("[router] Router message", output)

    def test_env_var_control(self):
        os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
        self._init_logcat()

        # These should be visible based on the environment settings
        logcat.debug("server", "Server debug message")
        logcat.info("server", "Server info message")
        logcat.warning("inference", "Inference warning message")
        logcat.error("inference", "Inference error message")

        # These should be filtered out based on the environment settings
        logcat.debug("inference", "Inference debug message")
        logcat.info("inference", "Inference info message")

        output = self.log_output.getvalue()
        self.assertIn("[server] Server debug message", output)
        self.assertIn("[server] Server info message", output)
        self.assertIn("[inference] Inference warning message", output)
        self.assertIn("[inference] Inference error message", output)
        self.assertNotIn("[inference] Inference debug message", output)
        self.assertNotIn("[inference] Inference info message", output)

    def test_invalid_category(self):
        logcat.info("nonexistent", "This message should not be logged")

        # Check that the message was not logged
        output = self.log_output.getvalue()
        self.assertNotIn("[nonexistent] This message should not be logged", output)


if __name__ == "__main__":
    unittest.main()