Revert "feat: add a configurable category-based logger (#1352)"

This reverts commit 754feba61f.
This commit is contained in:
Sébastien Han 2025-03-03 11:58:40 +01:00 committed by Ashwin Bharambe
parent b8c519ba11
commit efe1772727
12 changed files with 54 additions and 407 deletions

View file

@ -6,8 +6,9 @@
import importlib
import inspect
from typing import Any, Dict, List, Set, Tuple
import logging
from typing import Any, Dict, List, Set
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
@ -50,6 +51,8 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate,
)
log = logging.getLogger(__name__)
class InvalidProviderError(Exception):
pass
@ -184,7 +187,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
@ -206,11 +209,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logcat.error("core", p.deprecation_error)
log.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logcat.warning(
"core",
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
@ -244,10 +246,10 @@ def sort_providers_by_deps(
)
)
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
log.info(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logcat.debug("core", f" {api_str} => {provider.provider_id}")
return sorted_providers
log.debug(f" {api_str} => {provider.provider_id}")
log.debug("")
async def instantiate_providers(
@ -387,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class

View file

@ -6,7 +6,6 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
@ -52,6 +51,7 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@ -62,15 +62,12 @@ class VectorIORouter(VectorIO):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "VectorIORouter.shutdown")
pass
async def register_vector_db(
@ -81,8 +78,7 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
logcat.debug(
"core",
logger.debug(
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
)
await self.routing_table.register_vector_db(
@ -99,8 +95,7 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
logcat.debug(
"core",
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
@ -111,7 +106,6 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -122,15 +116,12 @@ class InferenceRouter(Inference):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing InferenceRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "InferenceRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "InferenceRouter.shutdown")
pass
async def register_model(
@ -141,10 +132,6 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
logcat.debug(
"core",
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
async def chat_completion(
@ -160,8 +147,7 @@ class InferenceRouter(Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
logcat.debug(
"core",
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
@ -226,8 +212,7 @@ class InferenceRouter(Inference):
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
logcat.debug(
"core",
logger.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
@ -257,7 +242,6 @@ class InferenceRouter(Inference):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -277,15 +261,12 @@ class SafetyRouter(Safety):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "SafetyRouter.shutdown")
pass
async def register_shield(
@ -295,7 +276,6 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
@ -304,7 +284,6 @@ class SafetyRouter(Safety):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
@ -317,15 +296,12 @@ class DatasetIORouter(DatasetIO):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
@ -335,8 +311,7 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
logcat.debug(
"core",
logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
)
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
@ -347,7 +322,6 @@ class DatasetIORouter(DatasetIO):
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
@ -359,15 +333,12 @@ class ScoringRouter(Scoring):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "ScoringRouter.shutdown")
pass
async def score_batch(
@ -376,7 +347,6 @@ class ScoringRouter(Scoring):
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@ -397,8 +367,7 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
logcat.debug(
"core",
logger.debug(
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
)
res = {}
@ -418,15 +387,12 @@ class EvalRouter(Eval):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "EvalRouter.shutdown")
pass
async def run_eval(
@ -434,7 +400,6 @@ class EvalRouter(Eval):
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
@ -447,7 +412,6 @@ class EvalRouter(Eval):
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
@ -460,7 +424,6 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
@ -468,7 +431,6 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> None:
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
@ -479,7 +441,6 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
@ -492,7 +453,6 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
@ -501,7 +461,6 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
@ -512,10 +471,6 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logcat.debug(
"core",
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
@ -524,7 +479,6 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -533,15 +487,12 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
@ -550,5 +501,4 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack import logcat
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logcat.init()
logger = logging.getLogger(__name__)
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
not block the current execution.
"""
signame = signal.Signals(signum).name
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logcat.info("server", f"Shutting down {impl_name}")
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logcat.warning("server", f"No shutdown method for {impl_name}")
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logcat.exception("server", f"Shutdown timeout for {impl_name}")
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except Exception as e:
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
# Gather all running tasks
loop = asyncio.get_running_loop()
@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logcat.exception("server", "Timeout while waiting for tasks to finish")
logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
logcat.info("server", "Starting up")
logger.info("Starting up")
yield
logcat.info("server", "Shutting down")
logger.info("Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
@ -209,11 +209,11 @@ async def sse_generator(event_gen):
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logcat.info("server", "Generator cancelled")
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
logcat.exception("server", f"Error in sse_generator: {e}")
logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
logger.exception(f"Error in sse_generator: {e}")
logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
yield create_sse_event(
{
"error": {
@ -235,7 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
logcat.exception("server", f"Error in {func.__name__}")
traceback.print_exception(e)
raise translate_exception(e) from e
sig = inspect.signature(func)
@ -314,8 +314,6 @@ class ClientVersionMiddleware:
def main():
logcat.init()
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
@ -355,10 +353,10 @@ def main():
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
logger.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logcat.error("server", f"Error: {str(e)}")
logger.error(f"Error: {str(e)}")
sys.exit(1)
if args.yaml_config:
@ -366,12 +364,12 @@ def main():
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
logcat.info("server", f"Using config file: {config_file}")
logger.info(f"Using config file: {config_file}")
elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
logcat.info("server", f"Using template {args.template} config file: {config_file}")
logger.info(f"Using template {args.template} config file: {config_file}")
else:
raise ValueError("Either --yaml-config or --template must be provided")
@ -379,10 +377,9 @@ def main():
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
logcat.info("server", "Run configuration:")
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
logcat.info("server", log_line)
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
@ -392,7 +389,7 @@ def main():
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e:
logcat.error("server", f"Error: {str(e)}")
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
@ -437,8 +434,9 @@ def main():
)
)
logcat.debug("server", f"serving APIs: {apis_to_serve}")
logger.debug(f"serving APIs: {apis_to_serve}")
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
@ -464,10 +462,10 @@ def main():
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
logcat.info("server", f"Listening on {listen_host}:{port}")
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import importlib.resources
import logging
import os
import re
import tempfile
@ -13,7 +14,6 @@ from typing import Any, Dict, Optional
import yaml
from termcolor import colored
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
@ -41,6 +41,8 @@ from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
class LlamaStack(
VectorDBs,
@ -101,11 +103,12 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
logcat.debug(
"core",
log.info(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
)
log.info("")
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):

View file

@ -98,8 +98,9 @@ case "$env_type" in
*)
esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
@ -141,8 +142,6 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \

View file

@ -1,204 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Category-based logging utility for llama-stack.
This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.
Usage:
from llama_stack import logcat
logcat.info("server", "Starting up...")
logcat.debug("inference", "Processing request...")
Environment variable:
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
Example: "server=debug;inference=warning"
"""
import datetime
import logging
import os
from typing import Dict
# ANSI color codes for terminal output
COLORS = {
"RESET": "\033[0m",
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
"DIM": "\033[2m", # Dimmed text
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
}
# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
_logger = logging.getLogger("llama_stack")
_logger.propagate = False
_default_level = logging.INFO
# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}
class TerminalStreamHandler(logging.StreamHandler):
def __init__(self, stream=None):
super().__init__(stream)
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
def format(self, record):
record.is_tty = self.is_tty
return super().format(record)
class ColoredFormatter(logging.Formatter):
"""Custom formatter with colors and fixed-width level names"""
def format(self, record):
levelname = record.levelname
# Use only time with milliseconds, not date
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
file_info = f"{record.filename}:{record.lineno}"
# Get category from extra if available
category = getattr(record, "category", None)
msg = record.getMessage()
if getattr(record, "is_tty", False):
color = COLORS.get(levelname, COLORS["RESET"])
if category:
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
f"{file_info:<20} {category_formatted}{msg}"
)
else:
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
f"{file_info:<20} {msg}"
)
else:
if category:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
else:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
return formatted_msg
def init(default_level: int = logging.INFO) -> None:
global _default_level, _category_levels, _logger
_default_level = default_level
_logger.setLevel(logging.DEBUG)
_logger.handlers = [] # Clear existing handlers
# Add our custom handler with the colored formatter
handler = TerminalStreamHandler()
formatter = ColoredFormatter()
handler.setFormatter(formatter)
_logger.addHandler(handler)
for category in CATEGORIES:
_category_levels[category] = default_level
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().lower()
level_value = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"warn": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}.get(level)
if level_value is None:
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
continue
if category == "all":
for cat in CATEGORIES:
_category_levels[cat] = level_value
else:
if category in CATEGORIES:
_category_levels[category] = level_value
else:
_logger.warning(f"Unknown logging category: {category}")
except ValueError:
_logger.warning(f"Invalid logging configuration: {pair}")
def _should_log(level: int, category: str) -> bool:
category = category.lower()
if category not in _category_levels:
return False
category_level = _category_levels[category]
return level >= category_level
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
if _should_log(level, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
def debug(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
def info(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.INFO, "info", category, msg, *args, **kwargs)
def warning(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
def warn(category: str, msg: str, *args, **kwargs) -> None:
warning(category, msg, *args, **kwargs)
def error(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
def critical(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
def exception(category: str, msg: str, *args, **kwargs) -> None:
if _should_log(logging.ERROR, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
_logger.exception(msg, *args, stacklevel=2, **kwargs)

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -231,14 +230,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
params = {
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logcat.debug("inference", f"params to fireworks: {params}")
return params
async def embeddings(
self,

View file

@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
import httpx
from ollama import AsyncClient
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
@ -208,14 +207,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
params = {
return {
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
logcat.debug("inference", f"params to ollama: {params}")
return params
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -218,14 +217,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
params = {
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logcat.debug("inference", f"params to together: {params}")
return params
async def embeddings(
self,

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
import litellm
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -109,8 +108,6 @@ class LiteLLMOpenAIMixin(
)
params = await self._get_params(request)
logcat.debug("inference", f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
response = litellm.completion(**params)

View file

@ -151,7 +151,6 @@ exclude = [
"llama_stack/distribution",
"llama_stack/apis",
"llama_stack/cli",
"llama_stack/logcat.py",
"llama_stack/models",
"llama_stack/strong_typing",
"llama_stack/templates",

View file

@ -1,88 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
import logging
import os
import unittest
from llama_stack import logcat
class TestLogcat(unittest.TestCase):
def setUp(self):
self.original_env = os.environ.get("LLAMA_STACK_LOGGING")
self.log_output = io.StringIO()
self._init_logcat()
def tearDown(self):
if self.original_env is not None:
os.environ["LLAMA_STACK_LOGGING"] = self.original_env
else:
os.environ.pop("LLAMA_STACK_LOGGING", None)
def _init_logcat(self):
logcat.init(default_level=logging.DEBUG)
self.handler = logging.StreamHandler(self.log_output)
self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
logcat._logger.handlers.clear()
logcat._logger.addHandler(self.handler)
def test_basic_logging(self):
logcat.info("server", "Info message")
logcat.warning("server", "Warning message")
logcat.error("server", "Error message")
output = self.log_output.getvalue()
self.assertIn("[server] Info message", output)
self.assertIn("[server] Warning message", output)
self.assertIn("[server] Error message", output)
def test_different_categories(self):
# Log messages with different categories
logcat.info("server", "Server message")
logcat.info("inference", "Inference message")
logcat.info("router", "Router message")
output = self.log_output.getvalue()
self.assertIn("[server] Server message", output)
self.assertIn("[inference] Inference message", output)
self.assertIn("[router] Router message", output)
def test_env_var_control(self):
os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
self._init_logcat()
# These should be visible based on the environment settings
logcat.debug("server", "Server debug message")
logcat.info("server", "Server info message")
logcat.warning("inference", "Inference warning message")
logcat.error("inference", "Inference error message")
# These should be filtered out based on the environment settings
logcat.debug("inference", "Inference debug message")
logcat.info("inference", "Inference info message")
output = self.log_output.getvalue()
self.assertIn("[server] Server debug message", output)
self.assertIn("[server] Server info message", output)
self.assertIn("[inference] Inference warning message", output)
self.assertIn("[inference] Inference error message", output)
self.assertNotIn("[inference] Inference debug message", output)
self.assertNotIn("[inference] Inference info message", output)
def test_invalid_category(self):
logcat.info("nonexistent", "This message should not be logged")
# Check that the message was not logged
output = self.log_output.getvalue()
self.assertNotIn("[nonexistent] This message should not be logged", output)
if __name__ == "__main__":
unittest.main()