diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 3abcc3772..b3a24bb7b 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -5,9 +5,9 @@ # the root directory of this source tree. import importlib import inspect -import logging from typing import Any, Dict, List, Set +from llama_stack import logcat from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO @@ -50,8 +50,6 @@ from llama_stack.providers.datatypes import ( VectorDBsProtocolPrivate, ) -log = logging.getLogger(__name__) - class InvalidProviderError(Exception): pass @@ -128,7 +126,7 @@ async def resolve_impls( specs = {} for provider in providers: if not provider.provider_id or provider.provider_id == "__disabled__": - log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") + logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled") continue if provider.provider_type not in provider_registry[api]: @@ -136,11 +134,12 @@ async def resolve_impls( p = provider_registry[api][provider.provider_type] if p.deprecation_error: - log.error(p.deprecation_error, "red", attrs=["bold"]) + logcat.error("core", p.deprecation_error) raise InvalidProviderError(p.deprecation_error) elif p.deprecation_warning: - log.warning( + logcat.warning( + "core", f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", ) p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies] @@ -214,10 +213,10 @@ async def resolve_impls( ) ) - log.info(f"Resolved {len(sorted_providers)} providers") + logcat.debug("core", f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: - log.info(f" {api_str} => {provider.provider_id}") - log.info("") + logcat.debug("core", f" {api_str} => {provider.provider_id}") + logcat.debug("core", "") impls = {} inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} @@ -354,7 +353,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: obj_params = set(obj_sig.parameters) obj_params.discard("self") if not (proto_params <= obj_params): - log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") + logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. 
obj: {obj_params}") missing_methods.append((name, "signature_mismatch")) else: # Check if the method is actually implemented in the class diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index f85bbff97..e6d9b8060 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -7,6 +7,7 @@ import copy from typing import Any, AsyncGenerator, Dict, List, Optional +from llama_stack import logcat from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -63,12 +64,15 @@ class VectorIORouter(VectorIO): self, routing_table: RoutingTable, ) -> None: + logcat.debug("core", "Initializing VectorIORouter") self.routing_table = routing_table async def initialize(self) -> None: + logcat.debug("core", "VectorIORouter.initialize") pass async def shutdown(self) -> None: + logcat.debug("core", "VectorIORouter.shutdown") pass async def register_vector_db( @@ -79,6 +83,7 @@ class VectorIORouter(VectorIO): provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: + logcat.debug("core", f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -93,6 +98,10 @@ class VectorIORouter(VectorIO): chunks: List[Chunk], ttl_seconds: Optional[int] = None, ) -> None: + logcat.debug( + "core", + f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", + ) return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) async def query_chunks( @@ -101,6 +110,7 @@ class VectorIORouter(VectorIO): query: InterleavedContent, params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: + logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}") return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) @@ -111,12 +121,15 @@ class InferenceRouter(Inference): self, routing_table: RoutingTable, ) -> None: + logcat.debug("core", "Initializing InferenceRouter") self.routing_table = routing_table async def initialize(self) -> None: + logcat.debug("core", "InferenceRouter.initialize") pass async def shutdown(self) -> None: + logcat.debug("core", "InferenceRouter.shutdown") pass async def register_model( @@ -127,6 +140,10 @@ class InferenceRouter(Inference): metadata: Optional[Dict[str, Any]] = None, model_type: Optional[ModelType] = None, ) -> None: + logcat.debug( + "core", + f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", + ) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) async def chat_completion( @@ -142,6 +159,10 @@ class InferenceRouter(Inference): logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: + logcat.debug( + "core", + f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", + ) model = await self.routing_table.get_model(model_id) if model is None: raise ValueError(f"Model '{model_id}' not found") @@ -203,6 +224,10 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + logcat.debug( + "core", + f"InferenceRouter.completion: 
{model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}", + ) model = await self.routing_table.get_model(model_id) if model is None: raise ValueError(f"Model '{model_id}' not found") @@ -230,6 +255,7 @@ class InferenceRouter(Inference): output_dimension: Optional[int] = None, task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: + logcat.debug("core", f"InferenceRouter.embeddings: {model_id}") model = await self.routing_table.get_model(model_id) if model is None: raise ValueError(f"Model '{model_id}' not found") @@ -249,12 +275,15 @@ class SafetyRouter(Safety): self, routing_table: RoutingTable, ) -> None: + logcat.debug("core", "Initializing SafetyRouter") self.routing_table = routing_table async def initialize(self) -> None: + logcat.debug("core", "SafetyRouter.initialize") pass async def shutdown(self) -> None: + logcat.debug("core", "SafetyRouter.shutdown") pass async def register_shield( @@ -264,6 +293,7 @@ class SafetyRouter(Safety): provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: + logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}") return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) async def run_shield( @@ -272,6 +302,7 @@ class SafetyRouter(Safety): messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: + logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}") return await self.routing_table.get_provider_impl(shield_id).run_shield( shield_id=shield_id, messages=messages, @@ -284,12 +315,15 @@ class DatasetIORouter(DatasetIO): self, routing_table: RoutingTable, ) -> None: + logcat.debug("core", "Initializing DatasetIORouter") self.routing_table = routing_table async def initialize(self) -> None: + logcat.debug("core", "DatasetIORouter.initialize") pass async def shutdown(self) -> None: + logcat.debug("core", "DatasetIORouter.shutdown") pass async def get_rows_paginated( @@ -299,6 +333,7 @@ class DatasetIORouter(DatasetIO): page_token: Optional[str] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: + logcat.debug("core", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}") return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( dataset_id=dataset_id, rows_in_page=rows_in_page, @@ -307,6 +342,7 @@ class DatasetIORouter(DatasetIO): ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") return await self.routing_table.get_provider_impl(dataset_id).append_rows( dataset_id=dataset_id, rows=rows, @@ -318,12 +354,15 @@ class ScoringRouter(Scoring): self, routing_table: RoutingTable, ) -> None: + logcat.debug("core", "Initializing ScoringRouter") self.routing_table = routing_table async def initialize(self) -> None: + logcat.debug("core", "ScoringRouter.initialize") pass async def shutdown(self) -> None: + logcat.debug("core", "ScoringRouter.shutdown") pass async def score_batch( @@ -332,6 +371,7 @@ class ScoringRouter(Scoring): scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: + logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( @@ -352,6 +392,7 @@ class ScoringRouter(Scoring): input_rows: 
List[Dict[str, Any]], scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: + logcat.debug("core", f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): @@ -369,12 +410,15 @@ class EvalRouter(Eval): self, routing_table: RoutingTable, ) -> None: + logcat.debug("core", "Initializing EvalRouter") self.routing_table = routing_table async def initialize(self) -> None: + logcat.debug("core", "EvalRouter.initialize") pass async def shutdown(self) -> None: + logcat.debug("core", "EvalRouter.shutdown") pass async def run_eval( @@ -382,6 +426,7 @@ class EvalRouter(Eval): benchmark_id: str, task_config: BenchmarkConfig, ) -> Job: + logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}") return await self.routing_table.get_provider_impl(benchmark_id).run_eval( benchmark_id=benchmark_id, task_config=task_config, @@ -394,6 +439,7 @@ class EvalRouter(Eval): scoring_functions: List[str], task_config: BenchmarkConfig, ) -> EvaluateResponse: + logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, @@ -406,6 +452,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> Optional[JobStatus]: + logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}") return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) async def job_cancel( @@ -413,6 +460,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> None: + logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") await self.routing_table.get_provider_impl(benchmark_id).job_cancel( benchmark_id, job_id, @@ -423,6 +471,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> EvaluateResponse: + logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}") return await self.routing_table.get_provider_impl(benchmark_id).job_result( benchmark_id, job_id, @@ -435,6 +484,7 @@ class ToolRuntimeRouter(ToolRuntime): self, routing_table: RoutingTable, ) -> None: + logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl") self.routing_table = routing_table async def query( @@ -443,6 +493,7 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_ids: List[str], query_config: Optional[RAGQueryConfig] = None, ) -> RAGQueryResult: + logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") return await self.routing_table.get_provider_impl("knowledge_search").query( content, vector_db_ids, query_config ) @@ -453,6 +504,10 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: + logcat.debug( + "core", + f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}", + ) return await self.routing_table.get_provider_impl("insert_into_memory").insert( documents, vector_db_id, chunk_size_in_tokens ) @@ -461,6 +516,7 @@ class ToolRuntimeRouter(ToolRuntime): self, routing_table: RoutingTable, ) -> None: + logcat.debug("core", "Initializing ToolRuntimeRouter") self.routing_table = routing_table # HACK ALERT this should be in sync with "get_all_api_endpoints()" @@ -469,12 +525,15 @@ class ToolRuntimeRouter(ToolRuntime): setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) async def initialize(self) 
-> None: + logcat.debug("core", "ToolRuntimeRouter.initialize") pass async def shutdown(self) -> None: + logcat.debug("core", "ToolRuntimeRouter.shutdown") pass async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any: + logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}") return await self.routing_table.get_provider_impl(tool_name).invoke_tool( tool_name=tool_name, kwargs=kwargs, @@ -483,4 +542,5 @@ class ToolRuntimeRouter(ToolRuntime): async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: + logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index d12340e08..4b70e0087 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, ValidationError -from termcolor import cprint from typing_extensions import Annotated +from llama_stack import logcat from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import set_request_provider_data @@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s") -logger = logging.getLogger(__name__) +logcat.init() def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None: not block the current execution. """ signame = signal.Signals(signum).name - logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...") + logcat.info("server", f"Received signal {signame} ({signum}). 
Exiting gracefully...") async def shutdown(): try: # Gracefully shut down implementations for impl in app.__llama_stack_impls__.values(): impl_name = impl.__class__.__name__ - logger.info("Shutting down %s", impl_name) + logcat.info("server", f"Shutting down {impl_name}") try: if hasattr(impl, "shutdown"): await asyncio.wait_for(impl.shutdown(), timeout=5) else: - logger.warning("No shutdown method for %s", impl_name) + logcat.warning("server", f"No shutdown method for {impl_name}") except asyncio.TimeoutError: - logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) + logcat.exception("server", f"Shutdown timeout for {impl_name}") except Exception as e: - logger.exception("Failed to shutdown %s: %s", impl_name, {e}) + logcat.exception("server", f"Failed to shutdown {impl_name}: {e}") # Gather all running tasks loop = asyncio.get_running_loop() @@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None: try: await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10) except asyncio.TimeoutError: - logger.exception("Timeout while waiting for tasks to finish") + logcat.exception("server", "Timeout while waiting for tasks to finish") except asyncio.CancelledError: pass finally: @@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None: @asynccontextmanager async def lifespan(app: FastAPI): - logger.info("Starting up") + logcat.info("server", "Starting up") yield - logger.info("Shutting down") + logcat.info("server", "Shutting down") for impl in app.__llama_stack_impls__.values(): await impl.shutdown() @@ -209,10 +209,10 @@ async def sse_generator(event_gen): yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: - print("Generator cancelled") + logcat.info("server", "Generator cancelled") await event_gen.aclose() except Exception as e: - traceback.print_exception(e) + logcat.exception("server", "Error in sse_generator") yield create_sse_event( { "error": { @@ -234,7 +234,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): value = func(**kwargs) return await maybe_await(value) except Exception as e: - traceback.print_exception(e) + logcat.exception("server", f"Error in {func.__name__}") raise translate_exception(e) from e sig = inspect.signature(func) @@ -313,6 +313,8 @@ class ClientVersionMiddleware: def main(): + logcat.init() + """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser.add_argument( @@ -352,10 +354,10 @@ def main(): for env_pair in args.env: try: key, value = validate_env_pair(env_pair) - logger.info(f"Setting CLI environment variable {key} => {value}") + logcat.info("server", f"Setting CLI environment variable {key} => {value}") os.environ[key] = value except ValueError as e: - logger.error(f"Error: {str(e)}") + logcat.error("server", f"Error: {str(e)}") sys.exit(1) if args.yaml_config: @@ -363,12 +365,12 @@ def main(): config_file = Path(args.yaml_config) if not config_file.exists(): raise ValueError(f"Config file {config_file} does not exist") - logger.info(f"Using config file: {config_file}") + logcat.info("server", f"Using config file: {config_file}") elif args.template: config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml" if not config_file.exists(): raise ValueError(f"Template {args.template} does not exist") - logger.info(f"Using template {args.template} config file: {config_file}") + logcat.info("server", f"Using template {args.template} config file: {config_file}") else: 
raise ValueError("Either --yaml-config or --template must be provided") @@ -376,9 +378,10 @@ def main(): config = replace_env_vars(yaml.safe_load(fp)) config = StackRunConfig(**config) - logger.info("Run configuration:") + logcat.info("server", "Run configuration:") safe_config = redact_sensitive_fields(config.model_dump()) - logger.info(yaml.dump(safe_config, indent=2)) + for log_line in yaml.dump(safe_config, indent=2).split("\n"): + logcat.info("server", log_line) app = FastAPI(lifespan=lifespan) app.add_middleware(TracingMiddleware) @@ -388,7 +391,7 @@ def main(): try: impls = asyncio.run(construct_stack(config)) except InvalidProviderError as e: - logger.error(f"Error: {str(e)}") + logcat.error("server", f"Error: {str(e)}") sys.exit(1) if Api.telemetry in impls: @@ -433,11 +436,8 @@ def main(): ) ) - logger.info(f"Serving API {api_str}") - for endpoint in endpoints: - cprint(f" {endpoint.method.upper()} {endpoint.route}", "white") + logcat.debug("server", f"Serving API {api_str}") - print("") app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) signal.signal(signal.SIGINT, functools.partial(handle_signal, app)) @@ -463,10 +463,10 @@ def main(): "ssl_keyfile": keyfile, "ssl_certfile": certfile, } - logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") + logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0" - logger.info(f"Listening on {listen_host}:{port}") + logcat.info("server", f"Listening on {listen_host}:{port}") uvicorn_config = { "app": app, diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 8f895e170..49942716a 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import importlib.resources -import logging import os import re from typing import Any, Dict, Optional @@ -13,6 +12,7 @@ from typing import Any, Dict, Optional import yaml from termcolor import colored +from llama_stack import logcat from llama_stack.apis.agents import Agents from llama_stack.apis.batch_inference import BatchInference from llama_stack.apis.benchmarks import Benchmarks @@ -39,8 +39,6 @@ from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.datatypes import Api -log = logging.getLogger(__name__) - class LlamaStack( VectorDBs, @@ -101,12 +99,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): objects_to_process = response.data if hasattr(response, "data") else response for obj in objects_to_process: - log.info( + logcat.debug( + "core", f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}", ) - log.info("") - class EnvVarError(Exception): def __init__(self, var_name: str, path: str = ""): diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/distribution/start_stack.sh index 713997331..a769bd66e 100755 --- a/llama_stack/distribution/start_stack.sh +++ b/llama_stack/distribution/start_stack.sh @@ -98,9 +98,8 @@ case "$env_type" in *) esac -set -x - if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then + set -x $PYTHON_BINARY -m llama_stack.distribution.server.server \ --yaml-config "$yaml_config" \ --port "$port" \ @@ -142,6 +141,8 @@ elif [[ "$env_type" == "container" ]]; then version_tag=$(curl -s $URL | jq -r '.info.version') fi + set -x + $CONTAINER_BINARY run $CONTAINER_OPTS -it \ -p $port:$port \ $env_vars \ diff --git a/llama_stack/logcat.py b/llama_stack/logcat.py new file mode 100644 index 000000000..0e11cb782 --- /dev/null +++ b/llama_stack/logcat.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Category-based logging utility for llama-stack. + +This module provides a wrapper over the standard Python logging module that supports +categorized logging with environment variable control. 
+ +Usage: + from llama_stack import logcat + logcat.info("server", "Starting up...") + logcat.debug("inference", "Processing request...") + +Environment variable: + LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs + Example: "server=debug;inference=warning" +""" + +import datetime +import logging +import os +from typing import Dict + +# ANSI color codes for terminal output +COLORS = { + "RESET": "\033[0m", + "DEBUG": "\033[36m", # Cyan + "INFO": "\033[32m", # Green + "WARNING": "\033[33m", # Yellow + "ERROR": "\033[31m", # Red + "CRITICAL": "\033[35m", # Magenta + "DIM": "\033[2m", # Dimmed text + "YELLOW_DIM": "\033[2;33m", # Dimmed yellow +} + +# Static list of valid categories representing various parts of the Llama Stack +# server codebase +CATEGORIES = [ + "core", + "server", + "router", + "inference", + "agents", + "safety", + "eval", + "tools", + "client", +] + +_logger = logging.getLogger("llama_stack") +_logger.propagate = False + +_default_level = logging.INFO + +# Category-level mapping (can be modified by environment variables) +_category_levels: Dict[str, int] = {} + + +class TerminalStreamHandler(logging.StreamHandler): + def __init__(self, stream=None): + super().__init__(stream) + self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty() + + def format(self, record): + record.is_tty = self.is_tty + return super().format(record) + + +class ColoredFormatter(logging.Formatter): + """Custom formatter with colors and fixed-width level names""" + + def format(self, record): + levelname = record.levelname + # Use only time with milliseconds, not date + timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format + + file_info = f"{record.filename}:{record.lineno}" + + # Get category from extra if available + category = getattr(record, "category", None) + msg = record.getMessage() + + if getattr(record, "is_tty", False): + color = COLORS.get(levelname, COLORS["RESET"]) + if category: + category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} " + formatted_msg = ( + f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} " + f"{file_info:<20} {category_formatted}{msg}" + ) + else: + formatted_msg = ( + f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] " + f"{file_info:<20} {msg}" + ) + else: + if category: + formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}" + else: + formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}" + + return formatted_msg + + +def init(default_level: int = logging.INFO) -> None: + global _default_level, _category_levels, _logger + + _default_level = default_level + + _logger.setLevel(logging.DEBUG) + _logger.handlers = [] # Clear existing handlers + + # Add our custom handler with the colored formatter + handler = TerminalStreamHandler() + formatter = ColoredFormatter() + handler.setFormatter(formatter) + _logger.addHandler(handler) + + for category in CATEGORIES: + _category_levels[category] = default_level + + env_config = os.environ.get("LLAMA_STACK_LOGGING", "") + if env_config: + for pair in env_config.split(";"): + if not pair.strip(): + continue + + try: + category, level = pair.split("=", 1) + category = category.strip().lower() + level = level.strip().lower() + + level_value = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "warn": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, + }.get(level) + + if 
level_value is None: + _logger.warning(f"Unknown log level '{level}' for category '{category}'") + continue + + if category == "all": + for cat in CATEGORIES: + _category_levels[cat] = level_value + else: + if category in CATEGORIES: + _category_levels[category] = level_value + else: + _logger.warning(f"Unknown logging category: {category}") + + except ValueError: + _logger.warning(f"Invalid logging configuration: {pair}") + + +def _should_log(level: int, category: str) -> bool: + category = category.lower() + if category not in _category_levels: + return False + category_level = _category_levels[category] + return level >= category_level + + +def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None: + if _should_log(level, category): + kwargs.setdefault("extra", {})["category"] = category.lower() + getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs) + + +def debug(category: str, msg: str, *args, **kwargs) -> None: + _log(logging.DEBUG, "debug", category, msg, *args, **kwargs) + + +def info(category: str, msg: str, *args, **kwargs) -> None: + _log(logging.INFO, "info", category, msg, *args, **kwargs) + + +def warning(category: str, msg: str, *args, **kwargs) -> None: + _log(logging.WARNING, "warning", category, msg, *args, **kwargs) + + +def warn(category: str, msg: str, *args, **kwargs) -> None: + warning(category, msg, *args, **kwargs) + + +def error(category: str, msg: str, *args, **kwargs) -> None: + _log(logging.ERROR, "error", category, msg, *args, **kwargs) + + +def critical(category: str, msg: str, *args, **kwargs) -> None: + _log(logging.CRITICAL, "critical", category, msg, *args, **kwargs) + + +def exception(category: str, msg: str, *args, **kwargs) -> None: + if _should_log(logging.ERROR, category): + kwargs.setdefault("extra", {})["category"] = category.lower() + _logger.exception(msg, *args, stacklevel=2, **kwargs) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 90fe70cbf..e264fa434 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks +from llama_stack import logcat from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, @@ -226,12 +227,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv if input_dict["prompt"].startswith("<|begin_of_text|>"): input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :] - return { + params = { "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.response_format, request.logprobs), } + logcat.debug("inference", f"params to fireworks: {params}") + return params async def embeddings( self, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 6fcfd2e99..5a520f3b9 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union import httpx from ollama import AsyncClient +from llama_stack import logcat from llama_stack.apis.common.content_types import ( ImageContentItem, InterleavedContent, @@ -203,12 +204,14 @@ class 
OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): else: raise ValueError(f"Unknown response format type: {fmt.type}") - return { + params = { "model": request.model, **input_dict, "options": sampling_options, "stream": request.stream, } + logcat.debug("inference", f"params to ollama: {params}") + return params async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: params = await self._get_params(request) diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 040f04e77..6fe1bd03d 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union from together import Together +from llama_stack import logcat from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, @@ -213,12 +214,14 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi assert not media_present, "Together does not support media for Completion requests" input_dict["prompt"] = await completion_request_to_prompt(request) - return { + params = { "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.logprobs, request.response_format), } + logcat.debug("inference", f"params to together: {params}") + return params async def embeddings( self, diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index ddf7f193f..92199baa9 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union import litellm +from llama_stack import logcat from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, @@ -106,6 +107,8 @@ class LiteLLMOpenAIMixin( ) params = await self._get_params(request) + logcat.debug("inference", f"params to litellm (openai compat): {params}") + # unfortunately, we need to use synchronous litellm.completion here because litellm # caches various httpx.client objects in a non-eventloop aware manner response = litellm.completion(**params) diff --git a/pyproject.toml b/pyproject.toml index 21f443cb2..730af5888 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -153,6 +153,7 @@ exclude = [ "llama_stack/distribution", "llama_stack/apis", "llama_stack/cli", + "llama_stack/logcat.py", "llama_stack/models", "llama_stack/strong_typing", "llama_stack/templates", diff --git a/tests/test_logcat.py b/tests/test_logcat.py new file mode 100644 index 000000000..4a116a08f --- /dev/null +++ b/tests/test_logcat.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import io +import logging +import os +import unittest + +from llama_stack import logcat + + +class TestLogcat(unittest.TestCase): + def setUp(self): + self.original_env = os.environ.get("LLAMA_STACK_LOGGING") + + self.log_output = io.StringIO() + self._init_logcat() + + def tearDown(self): + if self.original_env is not None: + os.environ["LLAMA_STACK_LOGGING"] = self.original_env + else: + os.environ.pop("LLAMA_STACK_LOGGING", None) + + def _init_logcat(self): + logcat.init(default_level=logging.DEBUG) + self.handler = logging.StreamHandler(self.log_output) + self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s")) + logcat._logger.handlers.clear() + logcat._logger.addHandler(self.handler) + + def test_basic_logging(self): + logcat.info("server", "Info message") + logcat.warning("server", "Warning message") + logcat.error("server", "Error message") + + output = self.log_output.getvalue() + self.assertIn("[server] Info message", output) + self.assertIn("[server] Warning message", output) + self.assertIn("[server] Error message", output) + + def test_different_categories(self): + # Log messages with different categories + logcat.info("server", "Server message") + logcat.info("inference", "Inference message") + logcat.info("router", "Router message") + + output = self.log_output.getvalue() + self.assertIn("[server] Server message", output) + self.assertIn("[inference] Inference message", output) + self.assertIn("[router] Router message", output) + + def test_env_var_control(self): + os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning" + self._init_logcat() + + # These should be visible based on the environment settings + logcat.debug("server", "Server debug message") + logcat.info("server", "Server info message") + logcat.warning("inference", "Inference warning message") + logcat.error("inference", "Inference error message") + + # These should be filtered out based on the environment settings + logcat.debug("inference", "Inference debug message") + logcat.info("inference", "Inference info message") + + output = self.log_output.getvalue() + self.assertIn("[server] Server debug message", output) + self.assertIn("[server] Server info message", output) + self.assertIn("[inference] Inference warning message", output) + self.assertIn("[inference] Inference error message", output) + + self.assertNotIn("[inference] Inference debug message", output) + self.assertNotIn("[inference] Inference info message", output) + + def test_invalid_category(self): + logcat.info("nonexistent", "This message should not be logged") + + # Check that the message was not logged + output = self.log_output.getvalue() + self.assertNotIn("[nonexistent] This message should not be logged", output) + + +if __name__ == "__main__": + unittest.main()
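Usage sketch (not part of the patch): the new llama_stack/logcat module added above is driven by the LLAMA_STACK_LOGGING environment variable, a semicolon-separated list of category=level pairs read by init(), with per-category filtering layered on top of the default level passed to init(). A minimal, hypothetical script exercising the API exactly as introduced in this diff:

    import logging
    import os

    # Raise "inference" to WARNING before init() reads the environment;
    # all other known categories keep the default level passed to init().
    os.environ["LLAMA_STACK_LOGGING"] = "inference=warning"

    from llama_stack import logcat

    logcat.init(default_level=logging.INFO)

    logcat.info("server", "starting up")                # emitted: server stays at INFO
    logcat.debug("core", "resolved providers")          # dropped: DEBUG below the INFO default
    logcat.info("inference", "params to ollama: ...")   # dropped: inference raised to WARNING
    logcat.warning("inference", "retrying request")     # emitted
    logcat.info("unknown", "unlisted categories are dropped entirely")

Setting LLAMA_STACK_LOGGING="all=debug" lowers every category listed in CATEGORIES to DEBUG, which is one way to surface the router, core, and inference debug lines this patch adds.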