feat: add a configurable category-based logger (#1352)

A self-respecting server needs good observability which starts with
configurable logging. Llama Stack had little until now. This PR adds a
`logcat` facility towards that. Callsites look like:

```python
logcat.debug("inference", f"params to ollama: {params}")
```

- the first parameter is a category. there is a static list of
categories in `llama_stack/logcat.py`
- each category can be associated with a log-level which can be
configured via the `LLAMA_STACK_LOGGING` env var.
- a value `LLAMA_STACK_LOGGING=inference=debug;server=info"` does the
obvious thing. there is a special key called `all` which is an alias for
all categories

## Test Plan

Ran with `LLAMA_STACK_LOGGING="all=debug" llama stack run fireworks` and
saw the following:


![image](https://github.com/user-attachments/assets/d24b95ab-3941-426c-9ea0-a4c62542e6f0)

Hit it with a client-sdk test case and saw this:


![image](https://github.com/user-attachments/assets/3fee8c6c-986e-4125-a09c-f5dc019682e2)
This commit is contained in:
Ashwin Bharambe 2025-03-02 18:51:14 -08:00 committed by GitHub
parent a9a7b11326
commit 754feba61f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 409 additions and 47 deletions

View file

@ -5,9 +5,9 @@
# the root directory of this source tree.
import importlib
import inspect
import logging
from typing import Any, Dict, List, Set
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
@ -50,8 +50,6 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate,
)
log = logging.getLogger(__name__)
class InvalidProviderError(Exception):
pass
@ -128,7 +126,7 @@ async def resolve_impls(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
if provider.provider_type not in provider_registry[api]:
@ -136,11 +134,12 @@ async def resolve_impls(
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
log.error(p.deprecation_error, "red", attrs=["bold"])
logcat.error("core", p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
log.warning(
logcat.warning(
"core",
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
@ -214,10 +213,10 @@ async def resolve_impls(
)
)
log.info(f"Resolved {len(sorted_providers)} providers")
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
log.info(f" {api_str} => {provider.provider_id}")
log.info("")
logcat.debug("core", f" {api_str} => {provider.provider_id}")
logcat.debug("core", "")
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
@ -354,7 +353,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class

View file

@ -7,6 +7,7 @@
import copy
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
@ -63,12 +64,15 @@ class VectorIORouter(VectorIO):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "VectorIORouter.shutdown")
pass
async def register_vector_db(
@ -79,6 +83,7 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
logcat.debug("core", f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
@ -93,6 +98,10 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
logcat.debug(
"core",
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
@ -101,6 +110,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -111,12 +121,15 @@ class InferenceRouter(Inference):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing InferenceRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "InferenceRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "InferenceRouter.shutdown")
pass
async def register_model(
@ -127,6 +140,10 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
logcat.debug(
"core",
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
async def chat_completion(
@ -142,6 +159,10 @@ class InferenceRouter(Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
logcat.debug(
"core",
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -203,6 +224,10 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
logcat.debug(
"core",
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -230,6 +255,7 @@ class InferenceRouter(Inference):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -249,12 +275,15 @@ class SafetyRouter(Safety):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "SafetyRouter.shutdown")
pass
async def register_shield(
@ -264,6 +293,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
@ -272,6 +302,7 @@ class SafetyRouter(Safety):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
@ -284,12 +315,15 @@ class DatasetIORouter(DatasetIO):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
@ -299,6 +333,7 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
logcat.debug("core", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}")
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
@ -307,6 +342,7 @@ class DatasetIORouter(DatasetIO):
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
@ -318,12 +354,15 @@ class ScoringRouter(Scoring):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "ScoringRouter.shutdown")
pass
async def score_batch(
@ -332,6 +371,7 @@ class ScoringRouter(Scoring):
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@ -352,6 +392,7 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
logcat.debug("core", f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
@ -369,12 +410,15 @@ class EvalRouter(Eval):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "EvalRouter.shutdown")
pass
async def run_eval(
@ -382,6 +426,7 @@ class EvalRouter(Eval):
benchmark_id: str,
task_config: BenchmarkConfig,
) -> Job:
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
task_config=task_config,
@ -394,6 +439,7 @@ class EvalRouter(Eval):
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
@ -406,6 +452,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
@ -413,6 +460,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> None:
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
@ -423,6 +471,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
@ -435,6 +484,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
@ -443,6 +493,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
@ -453,6 +504,10 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logcat.debug(
"core",
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
@ -461,6 +516,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -469,12 +525,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
@ -483,4 +542,5 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack import logcat
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logger = logging.getLogger(__name__)
logcat.init()
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
not block the current execution.
"""
signame = signal.Signals(signum).name
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
logcat.info("server", f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
logcat.warning("server", f"No shutdown method for {impl_name}")
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
logcat.exception("server", f"Shutdown timeout for {impl_name}")
except Exception as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
# Gather all running tasks
loop = asyncio.get_running_loop()
@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish")
logcat.exception("server", "Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up")
logcat.info("server", "Starting up")
yield
logger.info("Shutting down")
logcat.info("server", "Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
@ -209,10 +209,10 @@ async def sse_generator(event_gen):
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
logcat.info("server", "Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
logcat.exception("server", "Error in sse_generator")
yield create_sse_event(
{
"error": {
@ -234,7 +234,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
traceback.print_exception(e)
logcat.exception("server", f"Error in {func.__name__}")
raise translate_exception(e) from e
sig = inspect.signature(func)
@ -313,6 +313,8 @@ class ClientVersionMiddleware:
def main():
logcat.init()
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
@ -352,10 +354,10 @@ def main():
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}")
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
logcat.error("server", f"Error: {str(e)}")
sys.exit(1)
if args.yaml_config:
@ -363,12 +365,12 @@ def main():
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
logger.info(f"Using config file: {config_file}")
logcat.info("server", f"Using config file: {config_file}")
elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
logger.info(f"Using template {args.template} config file: {config_file}")
logcat.info("server", f"Using template {args.template} config file: {config_file}")
else:
raise ValueError("Either --yaml-config or --template must be provided")
@ -376,9 +378,10 @@ def main():
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
logger.info("Run configuration:")
logcat.info("server", "Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
logger.info(yaml.dump(safe_config, indent=2))
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
logcat.info("server", log_line)
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
@ -388,7 +391,7 @@ def main():
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
logcat.error("server", f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
@ -433,11 +436,8 @@ def main():
)
)
logger.info(f"Serving API {api_str}")
for endpoint in endpoints:
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
logcat.debug("server", f"Serving API {api_str}")
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
@ -463,10 +463,10 @@ def main():
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
logger.info(f"Listening on {listen_host}:{port}")
logcat.info("server", f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import importlib.resources
import logging
import os
import re
from typing import Any, Dict, Optional
@ -13,6 +12,7 @@ from typing import Any, Dict, Optional
import yaml
from termcolor import colored
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
@ -39,8 +39,6 @@ from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
class LlamaStack(
VectorDBs,
@ -101,12 +99,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
log.info(
logcat.debug(
"core",
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
log.info("")
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):

View file

@ -98,9 +98,8 @@ case "$env_type" in
*)
esac
set -x
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
@ -142,6 +141,8 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \

204
llama_stack/logcat.py Normal file
View file

@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Category-based logging utility for llama-stack.
This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.
Usage:
from llama_stack import logcat
logcat.info("server", "Starting up...")
logcat.debug("inference", "Processing request...")
Environment variable:
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
Example: "server=debug;inference=warning"
"""
import datetime
import logging
import os
from typing import Dict
# ANSI color codes for terminal output
COLORS = {
"RESET": "\033[0m",
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
"DIM": "\033[2m", # Dimmed text
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
}
# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
_logger = logging.getLogger("llama_stack")
_logger.propagate = False
_default_level = logging.INFO
# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}
class TerminalStreamHandler(logging.StreamHandler):
def __init__(self, stream=None):
super().__init__(stream)
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
def format(self, record):
record.is_tty = self.is_tty
return super().format(record)
class ColoredFormatter(logging.Formatter):
"""Custom formatter with colors and fixed-width level names"""
def format(self, record):
levelname = record.levelname
# Use only time with milliseconds, not date
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
file_info = f"{record.filename}:{record.lineno}"
# Get category from extra if available
category = getattr(record, "category", None)
msg = record.getMessage()
if getattr(record, "is_tty", False):
color = COLORS.get(levelname, COLORS["RESET"])
if category:
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
f"{file_info:<20} {category_formatted}{msg}"
)
else:
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
f"{file_info:<20} {msg}"
)
else:
if category:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
else:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
return formatted_msg
def init(default_level: int = logging.INFO) -> None:
global _default_level, _category_levels, _logger
_default_level = default_level
_logger.setLevel(logging.DEBUG)
_logger.handlers = [] # Clear existing handlers
# Add our custom handler with the colored formatter
handler = TerminalStreamHandler()
formatter = ColoredFormatter()
handler.setFormatter(formatter)
_logger.addHandler(handler)
for category in CATEGORIES:
_category_levels[category] = default_level
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().lower()
level_value = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"warn": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}.get(level)
if level_value is None:
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
continue
if category == "all":
for cat in CATEGORIES:
_category_levels[cat] = level_value
else:
if category in CATEGORIES:
_category_levels[category] = level_value
else:
_logger.warning(f"Unknown logging category: {category}")
except ValueError:
_logger.warning(f"Invalid logging configuration: {pair}")
def _should_log(level: int, category: str) -> bool:
category = category.lower()
if category not in _category_levels:
return False
category_level = _category_levels[category]
return level >= category_level
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
if _should_log(level, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
def debug(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
def info(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.INFO, "info", category, msg, *args, **kwargs)
def warning(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
def warn(category: str, msg: str, *args, **kwargs) -> None:
warning(category, msg, *args, **kwargs)
def error(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
def critical(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
def exception(category: str, msg: str, *args, **kwargs) -> None:
if _should_log(logging.ERROR, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
_logger.exception(msg, *args, stacklevel=2, **kwargs)

View file

@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -226,12 +227,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
return {
params = {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logcat.debug("inference", f"params to fireworks: {params}")
return params
async def embeddings(
self,

View file

@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
import httpx
from ollama import AsyncClient
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
@ -203,12 +204,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
return {
params = {
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
logcat.debug("inference", f"params to ollama: {params}")
return params
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)

View file

@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -213,12 +214,14 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
return {
params = {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logcat.debug("inference", f"params to together: {params}")
return params
async def embeddings(
self,

View file

@ -8,6 +8,7 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
import litellm
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -106,6 +107,8 @@ class LiteLLMOpenAIMixin(
)
params = await self._get_params(request)
logcat.debug("inference", f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
response = litellm.completion(**params)

View file

@ -153,6 +153,7 @@ exclude = [
"llama_stack/distribution",
"llama_stack/apis",
"llama_stack/cli",
"llama_stack/logcat.py",
"llama_stack/models",
"llama_stack/strong_typing",
"llama_stack/templates",

88
tests/test_logcat.py Normal file
View file

@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
import logging
import os
import unittest
from llama_stack import logcat
class TestLogcat(unittest.TestCase):
def setUp(self):
self.original_env = os.environ.get("LLAMA_STACK_LOGGING")
self.log_output = io.StringIO()
self._init_logcat()
def tearDown(self):
if self.original_env is not None:
os.environ["LLAMA_STACK_LOGGING"] = self.original_env
else:
os.environ.pop("LLAMA_STACK_LOGGING", None)
def _init_logcat(self):
logcat.init(default_level=logging.DEBUG)
self.handler = logging.StreamHandler(self.log_output)
self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
logcat._logger.handlers.clear()
logcat._logger.addHandler(self.handler)
def test_basic_logging(self):
logcat.info("server", "Info message")
logcat.warning("server", "Warning message")
logcat.error("server", "Error message")
output = self.log_output.getvalue()
self.assertIn("[server] Info message", output)
self.assertIn("[server] Warning message", output)
self.assertIn("[server] Error message", output)
def test_different_categories(self):
# Log messages with different categories
logcat.info("server", "Server message")
logcat.info("inference", "Inference message")
logcat.info("router", "Router message")
output = self.log_output.getvalue()
self.assertIn("[server] Server message", output)
self.assertIn("[inference] Inference message", output)
self.assertIn("[router] Router message", output)
def test_env_var_control(self):
os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
self._init_logcat()
# These should be visible based on the environment settings
logcat.debug("server", "Server debug message")
logcat.info("server", "Server info message")
logcat.warning("inference", "Inference warning message")
logcat.error("inference", "Inference error message")
# These should be filtered out based on the environment settings
logcat.debug("inference", "Inference debug message")
logcat.info("inference", "Inference info message")
output = self.log_output.getvalue()
self.assertIn("[server] Server debug message", output)
self.assertIn("[server] Server info message", output)
self.assertIn("[inference] Inference warning message", output)
self.assertIn("[inference] Inference error message", output)
self.assertNotIn("[inference] Inference debug message", output)
self.assertNotIn("[inference] Inference info message", output)
def test_invalid_category(self):
logcat.info("nonexistent", "This message should not be logged")
# Check that the message was not logged
output = self.log_output.getvalue()
self.assertNotIn("[nonexistent] This message should not be logged", output)
if __name__ == "__main__":
unittest.main()