From 7cf1e24c4e248c8634f32f847a80101d030cb881 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Fri, 7 Mar 2025 20:34:30 +0100
Subject: [PATCH] feat(logging): implement category-based logging (#1362)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

This commit introduces a new logging system that allows loggers to be assigned a category while retaining the logger name based on the file name. The log format includes both the logger name and the category, producing output like:

```
INFO 2025-03-03 21:44:11,323 llama_stack.distribution.stack:103 [core]: Tool_groups: builtin::websearch served by tavily-search
```

Key features include:

- Category-based logging: loggers can be assigned a category (e.g., "core", "server") when they are created. A logger is obtained like this: `logger = get_logger(name=__name__, category="server")`
- Environment variable control: log levels can be configured per category using the `LLAMA_STACK_LOGGING` environment variable. For example, `LLAMA_STACK_LOGGING="server=DEBUG;core=debug"` enables DEBUG level for the "server" and "core" categories.
- `LLAMA_STACK_LOGGING="all=debug"` sets DEBUG level globally for all categories and third-party libraries.

This provides fine-grained control over logging levels while maintaining a clean and informative log format. A brief usage sketch appears after the patch below.

The formatter uses the rich library, which provides nice colors and better stack traces, like so:

```
ERROR 2025-03-03 21:49:37,124 asyncio:1758 [uncategorized]: unhandled exception during asyncio.run() shutdown
task: .shutdown() done, defined at
/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:146>
exception=UnboundLocalError("local variable 'loop' referenced before assignment")>
╭────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────╮
│ /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:178 in shutdown                │
│                                                                                                                 │
│   175 │   │   except asyncio.CancelledError:                                                                   │
│   176 │   │   │   pass                                                                                         │
│   177 │   │   finally:                                                                                         │
│ ❱ 178 │   │   │   loop.stop()                                                                                  │
│   179 │                                                                                                        │
│   180 │   loop = asyncio.get_running_loop()                                                                    │
│   181 │   loop.create_task(shutdown())                                                                         │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
UnboundLocalError: local variable 'loop' referenced before assignment
```

Co-authored-by: Ashwin Bharambe <@ashwinb>
Signed-off-by: Sébastien Han

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

```
python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml
INFO 2025-03-03 21:55:35,918 __main__:365 [server]: Using config file: llama_stack/templates/ollama/run.yaml
INFO 2025-03-03 21:55:35,925 __main__:378 [server]: Run configuration:
INFO 2025-03-03 21:55:35,928 __main__:380 [server]: apis:
- agents
```

[//]: # (## Documentation)

---------

Signed-off-by: Sébastien Han
Co-authored-by: Ashwin Bharambe
---
 llama_stack/cli/stack/run.py | 4 +-
 llama_stack/distribution/resolver.py | 18 +-
 llama_stack/distribution/routers/routers.py | 106 +++++----
 llama_stack/distribution/server/server.py | 53 +++--
 llama_stack/distribution/stack.py | 10 +-
 llama_stack/distribution/start_stack.sh | 5 +-
 llama_stack/log.py | 169 +++++++++++++++
 llama_stack/logcat.py | 204 ------------------
 .../agents/meta_reference/agent_instance.py | 24 +--
 .../remote/inference/fireworks/fireworks.py | 7 +-
.../remote/inference/ollama/ollama.py | 13 +- .../remote/inference/together/together.py | 7 +- .../utils/inference/litellm_openai_mixin.py | 7 +- .../utils/inference/prompt_adapter.py | 7 +- pyproject.toml | 5 +- tests/unit/server/test_logcat.py | 88 -------- 16 files changed, 296 insertions(+), 431 deletions(-) create mode 100644 llama_stack/log.py delete mode 100644 llama_stack/logcat.py delete mode 100644 tests/unit/server/test_logcat.py diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index ba2273003..e5686fb10 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -5,15 +5,15 @@ # the root directory of this source tree. import argparse -import logging import os from pathlib import Path from llama_stack.cli.subcommand import Subcommand +from llama_stack.log import get_logger REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="server") class StackRun(Subcommand): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index c24df384d..d7ca4414d 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -7,7 +7,6 @@ import importlib import inspect from typing import Any, Dict, List, Set, Tuple -from llama_stack import logcat from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO @@ -35,6 +34,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( Api, BenchmarksProtocolPrivate, @@ -50,6 +50,8 @@ from llama_stack.providers.datatypes import ( VectorDBsProtocolPrivate, ) +logger = get_logger(name=__name__, category="core") + class InvalidProviderError(Exception): pass @@ -184,7 +186,7 @@ def validate_and_prepare_providers( specs = {} for provider in providers: if not provider.provider_id or provider.provider_id == "__disabled__": - logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled") + logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") continue validate_provider(provider, api, provider_registry) @@ -206,11 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR p = provider_registry[api][provider.provider_type] if p.deprecation_error: - logcat.error("core", p.deprecation_error) + logger.error(p.deprecation_error) raise InvalidProviderError(p.deprecation_error) elif p.deprecation_warning: - logcat.warning( - "core", + logger.warning( f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", ) @@ -244,9 +245,10 @@ def sort_providers_by_deps( ) ) - logcat.debug("core", f"Resolved {len(sorted_providers)} providers") + logger.debug(f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: - logcat.debug("core", f" {api_str} => {provider.provider_id}") + logger.debug(f" {api_str} => {provider.provider_id}") + logger.debug("") return sorted_providers @@ -387,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: obj_params = set(obj_sig.parameters) obj_params.discard("self") if 
not (proto_params <= obj_params): - logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") + logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") missing_methods.append((name, "signature_mismatch")) else: # Check if the method is actually implemented in the class diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index f2c70e66f..28df67922 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack import logcat from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -52,8 +51,11 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable +logger = get_logger(name=__name__, category="core") + class VectorIORouter(VectorIO): """Routes to an provider based on the vector db identifier""" @@ -62,15 +64,15 @@ class VectorIORouter(VectorIO): self, routing_table: RoutingTable, ) -> None: - logcat.debug("core", "Initializing VectorIORouter") + logger.debug("Initializing VectorIORouter") self.routing_table = routing_table async def initialize(self) -> None: - logcat.debug("core", "VectorIORouter.initialize") + logger.debug("VectorIORouter.initialize") pass async def shutdown(self) -> None: - logcat.debug("core", "VectorIORouter.shutdown") + logger.debug("VectorIORouter.shutdown") pass async def register_vector_db( @@ -81,10 +83,7 @@ class VectorIORouter(VectorIO): provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: - logcat.debug( - "core", - f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}", - ) + logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -99,8 +98,7 @@ class VectorIORouter(VectorIO): chunks: List[Chunk], ttl_seconds: Optional[int] = None, ) -> None: - logcat.debug( - "core", + logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' 
if len(chunks) > 3 else ''}", ) return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) @@ -111,7 +109,7 @@ class VectorIORouter(VectorIO): query: InterleavedContent, params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: - logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}") + logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) @@ -122,15 +120,15 @@ class InferenceRouter(Inference): self, routing_table: RoutingTable, ) -> None: - logcat.debug("core", "Initializing InferenceRouter") + logger.debug("Initializing InferenceRouter") self.routing_table = routing_table async def initialize(self) -> None: - logcat.debug("core", "InferenceRouter.initialize") + logger.debug("InferenceRouter.initialize") pass async def shutdown(self) -> None: - logcat.debug("core", "InferenceRouter.shutdown") + logger.debug("InferenceRouter.shutdown") pass async def register_model( @@ -141,8 +139,7 @@ class InferenceRouter(Inference): metadata: Optional[Dict[str, Any]] = None, model_type: Optional[ModelType] = None, ) -> None: - logcat.debug( - "core", + logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) @@ -160,8 +157,7 @@ class InferenceRouter(Inference): logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: - logcat.debug( - "core", + logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) if sampling_params is None: @@ -226,8 +222,7 @@ class InferenceRouter(Inference): ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() - logcat.debug( - "core", + logger.debug( f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}", ) model = await self.routing_table.get_model(model_id) @@ -257,7 +252,7 @@ class InferenceRouter(Inference): output_dimension: Optional[int] = None, task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: - logcat.debug("core", f"InferenceRouter.embeddings: {model_id}") + logger.debug(f"InferenceRouter.embeddings: {model_id}") model = await self.routing_table.get_model(model_id) if model is None: raise ValueError(f"Model '{model_id}' not found") @@ -277,15 +272,15 @@ class SafetyRouter(Safety): self, routing_table: RoutingTable, ) -> None: - logcat.debug("core", "Initializing SafetyRouter") + logger.debug("Initializing SafetyRouter") self.routing_table = routing_table async def initialize(self) -> None: - logcat.debug("core", "SafetyRouter.initialize") + logger.debug("SafetyRouter.initialize") pass async def shutdown(self) -> None: - logcat.debug("core", "SafetyRouter.shutdown") + logger.debug("SafetyRouter.shutdown") pass async def register_shield( @@ -295,7 +290,7 @@ class SafetyRouter(Safety): provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: - logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}") + logger.debug(f"SafetyRouter.register_shield: {shield_id}") return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) async def run_shield( @@ -304,7 +299,7 @@ class SafetyRouter(Safety): messages: List[Message], 
params: Dict[str, Any] = None, ) -> RunShieldResponse: - logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}") + logger.debug(f"SafetyRouter.run_shield: {shield_id}") return await self.routing_table.get_provider_impl(shield_id).run_shield( shield_id=shield_id, messages=messages, @@ -317,15 +312,15 @@ class DatasetIORouter(DatasetIO): self, routing_table: RoutingTable, ) -> None: - logcat.debug("core", "Initializing DatasetIORouter") + logger.debug("Initializing DatasetIORouter") self.routing_table = routing_table async def initialize(self) -> None: - logcat.debug("core", "DatasetIORouter.initialize") + logger.debug("DatasetIORouter.initialize") pass async def shutdown(self) -> None: - logcat.debug("core", "DatasetIORouter.shutdown") + logger.debug("DatasetIORouter.shutdown") pass async def get_rows_paginated( @@ -335,8 +330,7 @@ class DatasetIORouter(DatasetIO): page_token: Optional[str] = None, filter_condition: Optional[str] = None, ) -> PaginatedRowsResult: - logcat.debug( - "core", + logger.debug( f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}", ) return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( @@ -347,7 +341,7 @@ class DatasetIORouter(DatasetIO): ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: - logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") + logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") return await self.routing_table.get_provider_impl(dataset_id).append_rows( dataset_id=dataset_id, rows=rows, @@ -359,15 +353,15 @@ class ScoringRouter(Scoring): self, routing_table: RoutingTable, ) -> None: - logcat.debug("core", "Initializing ScoringRouter") + logger.debug("Initializing ScoringRouter") self.routing_table = routing_table async def initialize(self) -> None: - logcat.debug("core", "ScoringRouter.initialize") + logger.debug("ScoringRouter.initialize") pass async def shutdown(self) -> None: - logcat.debug("core", "ScoringRouter.shutdown") + logger.debug("ScoringRouter.shutdown") pass async def score_batch( @@ -376,7 +370,7 @@ class ScoringRouter(Scoring): scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: - logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}") + logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( @@ -397,10 +391,7 @@ class ScoringRouter(Scoring): input_rows: List[Dict[str, Any]], scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: - logcat.debug( - "core", - f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions", - ) + logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): @@ -418,15 +409,15 @@ class EvalRouter(Eval): self, routing_table: RoutingTable, ) -> None: - logcat.debug("core", "Initializing EvalRouter") + logger.debug("Initializing EvalRouter") self.routing_table = routing_table async def initialize(self) -> None: - logcat.debug("core", "EvalRouter.initialize") + logger.debug("EvalRouter.initialize") pass async def shutdown(self) -> None: - logcat.debug("core", "EvalRouter.shutdown") + logger.debug("EvalRouter.shutdown") pass async def 
run_eval( @@ -434,7 +425,7 @@ class EvalRouter(Eval): benchmark_id: str, benchmark_config: BenchmarkConfig, ) -> Job: - logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}") + logger.debug(f"EvalRouter.run_eval: {benchmark_id}") return await self.routing_table.get_provider_impl(benchmark_id).run_eval( benchmark_id=benchmark_id, benchmark_config=benchmark_config, @@ -447,7 +438,7 @@ class EvalRouter(Eval): scoring_functions: List[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: - logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") + logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, @@ -460,7 +451,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> Optional[JobStatus]: - logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}") + logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) async def job_cancel( @@ -468,7 +459,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> None: - logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") + logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") await self.routing_table.get_provider_impl(benchmark_id).job_cancel( benchmark_id, job_id, @@ -479,7 +470,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> EvaluateResponse: - logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}") + logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") return await self.routing_table.get_provider_impl(benchmark_id).job_result( benchmark_id, job_id, @@ -492,7 +483,7 @@ class ToolRuntimeRouter(ToolRuntime): self, routing_table: RoutingTable, ) -> None: - logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl") + logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") self.routing_table = routing_table async def query( @@ -501,7 +492,7 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_ids: List[str], query_config: Optional[RAGQueryConfig] = None, ) -> RAGQueryResult: - logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") + logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") return await self.routing_table.get_provider_impl("knowledge_search").query( content, vector_db_ids, query_config ) @@ -512,9 +503,8 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: - logcat.debug( - "core", - f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}", + logger.debug( + f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) return await self.routing_table.get_provider_impl("insert_into_memory").insert( documents, vector_db_id, chunk_size_in_tokens @@ -524,7 +514,7 @@ class ToolRuntimeRouter(ToolRuntime): self, routing_table: RoutingTable, ) -> None: - logcat.debug("core", "Initializing ToolRuntimeRouter") + logger.debug("Initializing ToolRuntimeRouter") self.routing_table = routing_table # HACK ALERT this should be in sync with "get_all_api_endpoints()" @@ -533,15 +523,15 @@ class ToolRuntimeRouter(ToolRuntime): setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) async def initialize(self) -> None: - 
logcat.debug("core", "ToolRuntimeRouter.initialize") + logger.debug("ToolRuntimeRouter.initialize") pass async def shutdown(self) -> None: - logcat.debug("core", "ToolRuntimeRouter.shutdown") + logger.debug("ToolRuntimeRouter.shutdown") pass async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any: - logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}") + logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") return await self.routing_table.get_provider_impl(tool_name).invoke_tool( tool_name=tool_name, kwargs=kwargs, @@ -550,5 +540,5 @@ class ToolRuntimeRouter(ToolRuntime): async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: - logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") + logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 2fc36e58f..c4ef79a69 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -9,7 +9,6 @@ import asyncio import functools import inspect import json -import logging import os import signal import sys @@ -28,7 +27,6 @@ from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, ValidationError from typing_extensions import Annotated -from llama_stack import logcat from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import set_request_provider_data @@ -39,6 +37,7 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( @@ -54,8 +53,7 @@ from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent -logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s") -logcat.init() +logger = get_logger(name=__name__, category="server") def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -142,23 +140,23 @@ def handle_signal(app, signum, _) -> None: not block the current execution. """ signame = signal.Signals(signum).name - logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...") + logger.info(f"Received signal {signame} ({signum}). 
Exiting gracefully...") async def shutdown(): try: # Gracefully shut down implementations for impl in app.__llama_stack_impls__.values(): impl_name = impl.__class__.__name__ - logcat.info("server", f"Shutting down {impl_name}") + logger.info("Shutting down %s", impl_name) try: if hasattr(impl, "shutdown"): await asyncio.wait_for(impl.shutdown(), timeout=5) else: - logcat.warning("server", f"No shutdown method for {impl_name}") + logger.warning("No shutdown method for %s", impl_name) except asyncio.TimeoutError: - logcat.exception("server", f"Shutdown timeout for {impl_name}") + logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) except Exception as e: - logcat.exception("server", f"Failed to shutdown {impl_name}: {e}") + logger.exception("Failed to shutdown %s: %s", impl_name, {e}) # Gather all running tasks loop = asyncio.get_running_loop() @@ -172,7 +170,7 @@ def handle_signal(app, signum, _) -> None: try: await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10) except asyncio.TimeoutError: - logcat.exception("server", "Timeout while waiting for tasks to finish") + logger.exception("Timeout while waiting for tasks to finish") except asyncio.CancelledError: pass finally: @@ -184,9 +182,9 @@ def handle_signal(app, signum, _) -> None: @asynccontextmanager async def lifespan(app: FastAPI): - logcat.info("server", "Starting up") + logger.info("Starting up") yield - logcat.info("server", "Shutting down") + logger.info("Shutting down") for impl in app.__llama_stack_impls__.values(): await impl.shutdown() @@ -209,11 +207,11 @@ async def sse_generator(event_gen): yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: - logcat.info("server", "Generator cancelled") + logger.info("Generator cancelled") await event_gen.aclose() except Exception as e: - logcat.exception("server", f"Error in sse_generator: {e}") - logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}") + logger.exception(f"Error in sse_generator: {e}") + logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}") yield create_sse_event( { "error": { @@ -235,7 +233,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): value = func(**kwargs) return await maybe_await(value) except Exception as e: - logcat.exception("server", f"Error in {func.__name__}") + traceback.print_exception(e) raise translate_exception(e) from e sig = inspect.signature(func) @@ -314,8 +312,6 @@ class ClientVersionMiddleware: def main(): - logcat.init() - """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser.add_argument( @@ -355,10 +351,10 @@ def main(): for env_pair in args.env: try: key, value = validate_env_pair(env_pair) - logcat.info("server", f"Setting CLI environment variable {key} => {value}") + logger.info(f"Setting CLI environment variable {key} => {value}") os.environ[key] = value except ValueError as e: - logcat.error("server", f"Error: {str(e)}") + logger.error(f"Error: {str(e)}") sys.exit(1) if args.yaml_config: @@ -366,12 +362,12 @@ def main(): config_file = Path(args.yaml_config) if not config_file.exists(): raise ValueError(f"Config file {config_file} does not exist") - logcat.info("server", f"Using config file: {config_file}") + logger.info(f"Using config file: {config_file}") elif args.template: config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml" if not 
config_file.exists(): raise ValueError(f"Template {args.template} does not exist") - logcat.info("server", f"Using template {args.template} config file: {config_file}") + logger.info(f"Using template {args.template} config file: {config_file}") else: raise ValueError("Either --yaml-config or --template must be provided") @@ -379,10 +375,9 @@ def main(): config = replace_env_vars(yaml.safe_load(fp)) config = StackRunConfig(**config) - logcat.info("server", "Run configuration:") + logger.info("Run configuration:") safe_config = redact_sensitive_fields(config.model_dump()) - for log_line in yaml.dump(safe_config, indent=2).split("\n"): - logcat.info("server", log_line) + logger.info(yaml.dump(safe_config, indent=2)) app = FastAPI(lifespan=lifespan) app.add_middleware(TracingMiddleware) @@ -392,7 +387,7 @@ def main(): try: impls = asyncio.run(construct_stack(config)) except InvalidProviderError as e: - logcat.error("server", f"Error: {str(e)}") + logger.error(f"Error: {str(e)}") sys.exit(1) if Api.telemetry in impls: @@ -437,7 +432,7 @@ def main(): ) ) - logcat.debug("server", f"serving APIs: {apis_to_serve}") + logger.debug(f"serving APIs: {apis_to_serve}") app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) @@ -464,10 +459,10 @@ def main(): "ssl_keyfile": keyfile, "ssl_certfile": certfile, } - logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") + logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0" - logcat.info("server", f"Listening on {listen_host}:{port}") + logger.info(f"Listening on {listen_host}:{port}") uvicorn_config = { "app": app, diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index de74aa858..2b974739a 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -11,9 +11,7 @@ import tempfile from typing import Any, Dict, Optional import yaml -from termcolor import colored -from llama_stack import logcat from llama_stack.apis.agents import Agents from llama_stack.apis.batch_inference import BatchInference from llama_stack.apis.benchmarks import Benchmarks @@ -39,8 +37,11 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api +logger = get_logger(name=__name__, category="core") + class LlamaStack( VectorDBs, @@ -101,9 +102,8 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): objects_to_process = response.data if hasattr(response, "data") else response for obj in objects_to_process: - logcat.debug( - "core", - f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}", + logger.debug( + f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}", ) diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/distribution/start_stack.sh index a769bd66e..cfc078c27 100755 --- a/llama_stack/distribution/start_stack.sh +++ b/llama_stack/distribution/start_stack.sh @@ -100,12 +100,15 @@ esac if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then 
set -x + $PYTHON_BINARY -m llama_stack.distribution.server.server \ --yaml-config "$yaml_config" \ --port "$port" \ $env_vars \ $other_args elif [[ "$env_type" == "container" ]]; then + set -x + # Check if container command is available if ! is_command_available $CONTAINER_BINARY; then printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2 @@ -141,8 +144,6 @@ elif [[ "$env_type" == "container" ]]; then version_tag=$(curl -s $URL | jq -r '.info.version') fi - set -x - $CONTAINER_BINARY run $CONTAINER_OPTS -it \ -p $port:$port \ $env_vars \ diff --git a/llama_stack/log.py b/llama_stack/log.py new file mode 100644 index 000000000..11aa1bf7e --- /dev/null +++ b/llama_stack/log.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +import os +from logging.config import dictConfig +from typing import Dict + +from rich.console import Console +from rich.logging import RichHandler + +# Default log level +DEFAULT_LOG_LEVEL = logging.INFO + +# Predefined categories +CATEGORIES = [ + "core", + "server", + "router", + "inference", + "agents", + "safety", + "eval", + "tools", + "client", +] + +# Initialize category levels with default level +_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES} + + +def parse_environment_config(env_config: str) -> Dict[str, int]: + """ + Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels. + + Parameters: + env_config (str): The value of the LLAMA_STACK_LOGGING environment variable. + + Returns: + Dict[str, int]: A dictionary mapping categories to their log levels. + """ + category_levels = {} + for pair in env_config.split(";"): + if not pair.strip(): + continue + + try: + category, level = pair.split("=", 1) + category = category.strip().lower() + level = level.strip().upper() # Convert to uppercase for logging._nameToLevel + + level_value = logging._nameToLevel.get(level) + if level_value is None: + logging.warning( + f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'." + ) + continue + + if category == "all": + # Apply the log level to all categories and the root logger + for cat in CATEGORIES: + category_levels[cat] = level_value + # Set the root logger's level to the specified level + category_levels["root"] = level_value + elif category in CATEGORIES: + category_levels[category] = level_value + logging.info(f"Setting '{category}' category to level '{level}'.") + else: + logging.warning(f"Unknown logging category: {category}. No changes made.") + + except ValueError: + logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.") + + return category_levels + + +class CustomRichHandler(RichHandler): + def __init__(self, *args, **kwargs): + kwargs["console"] = Console(width=120) + super().__init__(*args, **kwargs) + + +def setup_logging(category_levels: Dict[str, int]) -> None: + """ + Configure logging based on the provided category log levels. + + Parameters: + category_levels (Dict[str, int]): A dictionary mapping categories to their log levels. 
+ """ + log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s" + + class CategoryFilter(logging.Filter): + """Ensure category is always present in log records.""" + + def filter(self, record): + if not hasattr(record, "category"): + record.category = "uncategorized" # Default to 'uncategorized' if no category found + return True + + # Determine the root logger's level (default to WARNING if not specified) + root_level = category_levels.get("root", logging.WARNING) + + logging_config = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "rich": { + "()": logging.Formatter, + "format": log_format, + } + }, + "handlers": { + "console": { + "()": CustomRichHandler, # Use our custom handler class + "formatter": "rich", + "rich_tracebacks": True, + "show_time": False, + "show_path": False, + "markup": True, + "filters": ["category_filter"], + } + }, + "filters": { + "category_filter": { + "()": CategoryFilter, + } + }, + "loggers": { + category: { + "handlers": ["console"], + "level": category_levels.get(category, DEFAULT_LOG_LEVEL), + "propagate": False, # Disable propagation to root logger + } + for category in CATEGORIES + }, + "root": { + "handlers": ["console"], + "level": root_level, # Set root logger's level dynamically + }, + } + dictConfig(logging_config) + + +def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter: + """ + Returns a logger with the specified name and category. + If no category is provided, defaults to 'uncategorized'. + + Parameters: + name (str): The name of the logger (e.g., module or filename). + category (str): The category of the logger (default 'uncategorized'). + + Returns: + logging.LoggerAdapter: Configured logger with category support. + """ + logger = logging.getLogger(name) + logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL)) + return logging.LoggerAdapter(logger, {"category": category}) + + +env_config = os.environ.get("LLAMA_STACK_LOGGING", "") +if env_config: + print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}") + _category_levels.update(parse_environment_config(env_config)) + +setup_logging(_category_levels) diff --git a/llama_stack/logcat.py b/llama_stack/logcat.py deleted file mode 100644 index 0e11cb782..000000000 --- a/llama_stack/logcat.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Category-based logging utility for llama-stack. - -This module provides a wrapper over the standard Python logging module that supports -categorized logging with environment variable control. 
- -Usage: - from llama_stack import logcat - logcat.info("server", "Starting up...") - logcat.debug("inference", "Processing request...") - -Environment variable: - LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs - Example: "server=debug;inference=warning" -""" - -import datetime -import logging -import os -from typing import Dict - -# ANSI color codes for terminal output -COLORS = { - "RESET": "\033[0m", - "DEBUG": "\033[36m", # Cyan - "INFO": "\033[32m", # Green - "WARNING": "\033[33m", # Yellow - "ERROR": "\033[31m", # Red - "CRITICAL": "\033[35m", # Magenta - "DIM": "\033[2m", # Dimmed text - "YELLOW_DIM": "\033[2;33m", # Dimmed yellow -} - -# Static list of valid categories representing various parts of the Llama Stack -# server codebase -CATEGORIES = [ - "core", - "server", - "router", - "inference", - "agents", - "safety", - "eval", - "tools", - "client", -] - -_logger = logging.getLogger("llama_stack") -_logger.propagate = False - -_default_level = logging.INFO - -# Category-level mapping (can be modified by environment variables) -_category_levels: Dict[str, int] = {} - - -class TerminalStreamHandler(logging.StreamHandler): - def __init__(self, stream=None): - super().__init__(stream) - self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty() - - def format(self, record): - record.is_tty = self.is_tty - return super().format(record) - - -class ColoredFormatter(logging.Formatter): - """Custom formatter with colors and fixed-width level names""" - - def format(self, record): - levelname = record.levelname - # Use only time with milliseconds, not date - timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format - - file_info = f"{record.filename}:{record.lineno}" - - # Get category from extra if available - category = getattr(record, "category", None) - msg = record.getMessage() - - if getattr(record, "is_tty", False): - color = COLORS.get(levelname, COLORS["RESET"]) - if category: - category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} " - formatted_msg = ( - f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} " - f"{file_info:<20} {category_formatted}{msg}" - ) - else: - formatted_msg = ( - f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] " - f"{file_info:<20} {msg}" - ) - else: - if category: - formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}" - else: - formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}" - - return formatted_msg - - -def init(default_level: int = logging.INFO) -> None: - global _default_level, _category_levels, _logger - - _default_level = default_level - - _logger.setLevel(logging.DEBUG) - _logger.handlers = [] # Clear existing handlers - - # Add our custom handler with the colored formatter - handler = TerminalStreamHandler() - formatter = ColoredFormatter() - handler.setFormatter(formatter) - _logger.addHandler(handler) - - for category in CATEGORIES: - _category_levels[category] = default_level - - env_config = os.environ.get("LLAMA_STACK_LOGGING", "") - if env_config: - for pair in env_config.split(";"): - if not pair.strip(): - continue - - try: - category, level = pair.split("=", 1) - category = category.strip().lower() - level = level.strip().lower() - - level_value = { - "debug": logging.DEBUG, - "info": logging.INFO, - "warning": logging.WARNING, - "warn": logging.WARNING, - "error": logging.ERROR, - "critical": logging.CRITICAL, - }.get(level) - - if 
level_value is None: - _logger.warning(f"Unknown log level '{level}' for category '{category}'") - continue - - if category == "all": - for cat in CATEGORIES: - _category_levels[cat] = level_value - else: - if category in CATEGORIES: - _category_levels[category] = level_value - else: - _logger.warning(f"Unknown logging category: {category}") - - except ValueError: - _logger.warning(f"Invalid logging configuration: {pair}") - - -def _should_log(level: int, category: str) -> bool: - category = category.lower() - if category not in _category_levels: - return False - category_level = _category_levels[category] - return level >= category_level - - -def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None: - if _should_log(level, category): - kwargs.setdefault("extra", {})["category"] = category.lower() - getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs) - - -def debug(category: str, msg: str, *args, **kwargs) -> None: - _log(logging.DEBUG, "debug", category, msg, *args, **kwargs) - - -def info(category: str, msg: str, *args, **kwargs) -> None: - _log(logging.INFO, "info", category, msg, *args, **kwargs) - - -def warning(category: str, msg: str, *args, **kwargs) -> None: - _log(logging.WARNING, "warning", category, msg, *args, **kwargs) - - -def warn(category: str, msg: str, *args, **kwargs) -> None: - warning(category, msg, *args, **kwargs) - - -def error(category: str, msg: str, *args, **kwargs) -> None: - _log(logging.ERROR, "error", category, msg, *args, **kwargs) - - -def critical(category: str, msg: str, *args, **kwargs) -> None: - _log(logging.CRITICAL, "critical", category, msg, *args, **kwargs) - - -def exception(category: str, msg: str, *args, **kwargs) -> None: - if _should_log(logging.ERROR, category): - kwargs.setdefault("extra", {})["category"] = category.lower() - _logger.exception(msg, *args, stacklevel=2, **kwargs) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 720e73503..3619b3f67 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -17,7 +17,6 @@ from urllib.parse import urlparse import httpx -from llama_stack import logcat from llama_stack.apis.agents import ( AgentConfig, AgentToolGroup, @@ -67,6 +66,7 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolCall, @@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" +logger = get_logger(name=__name__, category="agents") + class ChatAgent(ShieldRunnerMixin): def __init__( @@ -609,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin): ) if n_iter >= self.agent_config.max_infer_iters: - logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.") + logger.info(f"done with MAX iterations ({n_iter}), exiting.") # NOTE: mark end_of_turn to indicate to client that we are done with the turn # Do not continue the tool call loop after this point message.stop_reason = StopReason.end_of_turn @@ -617,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin): break if stop_reason == StopReason.out_of_tokens: - logcat.info("agents", "out of token budget, exiting.") + logger.info("out of token budget, exiting.") yield message break @@ -631,16 +633,10 @@ class 
ChatAgent(ShieldRunnerMixin): message.content = [message.content] + output_attachments yield message else: - logcat.debug( - "agents", - f"completion message with EOM (iter: {n_iter}): {str(message)}", - ) + logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") input_messages = input_messages + [message] else: - logcat.debug( - "agents", - f"completion message (iter: {n_iter}) from the model: {str(message)}", - ) + logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}") # 1. Start the tool execution step and progress step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -983,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa path = urlparse(uri).path basename = os.path.basename(path) filepath = f"{tempdir}/{make_random_string() + basename}" - logcat.info("agents", f"Downloading {url} -> {filepath}") + logger.info(f"Downloading {url} -> {filepath}") async with httpx.AsyncClient() as client: r = await client.get(uri) @@ -1023,7 +1019,7 @@ async def execute_tool_call_maybe( else: name = name.value - logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}") + logger.info(f"executing tool call: {name} with args: {tool_call.arguments}") result = await tool_runtime_api.invoke_tool( tool_name=name, kwargs={ @@ -1033,7 +1029,7 @@ async def execute_tool_call_maybe( **toolgroup_args.get(group_name, {}), }, ) - logcat.debug("agents", f"tool call {name} completed with result: {result}") + logger.info(f"tool call {name} completed with result: {result}") return result diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index a4cecf9f1..ec68fb556 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks -from llama_stack import logcat from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, @@ -33,6 +32,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -55,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig from .models import MODEL_ENTRIES +logger = get_logger(name=__name__, category="inference") + class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: FireworksImplConfig) -> None: @@ -237,7 +239,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv "stream": request.stream, **self._build_options(request.sampling_params, request.response_format, request.logprobs), } - logcat.debug("inference", f"params to fireworks: {params}") + logger.debug(f"params to fireworks: {params}") + return params async def embeddings( diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 4d7fef8ed..36941480c 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -4,13 +4,12 @@ # This source code is licensed under the terms described in the 
LICENSE file in # the root directory of this source tree. -import logging + from typing import AsyncGenerator, List, Optional, Union import httpx from ollama import AsyncClient -from llama_stack import logcat from llama_stack.apis.common.content_types import ( ImageContentItem, InterleavedContent, @@ -35,6 +34,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, @@ -59,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .models import model_entries -log = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): @@ -72,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return AsyncClient(host=self.url) async def initialize(self) -> None: - log.info(f"checking connectivity to Ollama at `{self.url}`...") + logger.info(f"checking connectivity to Ollama at `{self.url}`...") try: await self.client.ps() except httpx.ConnectError as e: @@ -214,7 +214,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): "options": sampling_options, "stream": request.stream, } - logcat.debug("inference", f"params to ollama: {params}") + logger.debug(f"params to ollama: {params}") + return params async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: @@ -290,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def register_model(self, model: Model) -> Model: model = await self.register_helper.register_model(model) if model.model_type == ModelType.embedding: - log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") + logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") await self.client.pull(model.provider_resource_id) response = await self.client.list() else: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 0c468cdbf..f701c0da7 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union from together import Together -from llama_stack import logcat from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, @@ -32,6 +31,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -54,6 +54,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig from .models import MODEL_ENTRIES +logger = get_logger(name=__name__, category="inference") + class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: TogetherImplConfig) -> None: @@ -224,8 +226,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi "stream": request.stream, **self._build_options(request.sampling_params, request.logprobs, request.response_format), } - logcat.debug("inference", 
f"params to together: {params}") - return params + logger.debug(f"params to together: {params}") async def embeddings( self, diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 9467996a6..d88dc5a9e 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -8,7 +8,6 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union import litellm -from llama_stack import logcat from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, @@ -33,6 +32,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.models.models import Model from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -47,6 +47,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) +logger = get_logger(name=__name__, category="inference") + class LiteLLMOpenAIMixin( ModelRegistryHelper, @@ -109,8 +111,7 @@ class LiteLLMOpenAIMixin( ) params = await self._get_params(request) - logcat.debug("inference", f"params to litellm (openai compat): {params}") - + logger.debug(f"params to litellm (openai compat): {params}") # unfortunately, we need to use synchronous litellm.completion here because litellm # caches various httpx.client objects in a non-eventloop aware manner response = litellm.completion(**params) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 37b1a8160..1edf445c0 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -8,14 +8,12 @@ import asyncio import base64 import io import json -import logging import re from typing import List, Optional, Tuple, Union import httpx from PIL import Image as PIL_Image -from llama_stack import logcat from llama_stack.apis.common.content_types import ( ImageContentItem, InterleavedContent, @@ -34,6 +32,7 @@ from llama_stack.apis.inference import ( ToolDefinition, UserMessage, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( ModelFamily, RawContent, @@ -58,7 +57,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.utils.inference import supported_inference_models -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class ChatCompletionRequestWithRawContent(ChatCompletionRequest): @@ -464,7 +463,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin def get_default_tool_prompt_format(model: str) -> ToolPromptFormat: llama_model = resolve_model(model) if llama_model is None: - logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format") + log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format") return ToolPromptFormat.json if llama_model.model_family == ModelFamily.llama3_1 or ( diff --git a/pyproject.toml b/pyproject.toml index d8f3718d8..0fa055a02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,7 +151,6 @@ exclude = [ "llama_stack/distribution", "llama_stack/apis", "llama_stack/cli", - 
"llama_stack/logcat.py", "llama_stack/models", "llama_stack/strong_typing", "llama_stack/templates", @@ -163,5 +162,5 @@ module = ["yaml", "fire"] ignore_missing_imports = true [[tool.mypy.overrides]] -module = "llama_stack.distribution.resolver" -follow_imports = "normal" # This will force type checking on this module +module = ["llama_stack.distribution.resolver", "llama_stack.log"] +follow_imports = "normal" # This will force type checking on this module diff --git a/tests/unit/server/test_logcat.py b/tests/unit/server/test_logcat.py deleted file mode 100644 index 4a116a08f..000000000 --- a/tests/unit/server/test_logcat.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import io -import logging -import os -import unittest - -from llama_stack import logcat - - -class TestLogcat(unittest.TestCase): - def setUp(self): - self.original_env = os.environ.get("LLAMA_STACK_LOGGING") - - self.log_output = io.StringIO() - self._init_logcat() - - def tearDown(self): - if self.original_env is not None: - os.environ["LLAMA_STACK_LOGGING"] = self.original_env - else: - os.environ.pop("LLAMA_STACK_LOGGING", None) - - def _init_logcat(self): - logcat.init(default_level=logging.DEBUG) - self.handler = logging.StreamHandler(self.log_output) - self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s")) - logcat._logger.handlers.clear() - logcat._logger.addHandler(self.handler) - - def test_basic_logging(self): - logcat.info("server", "Info message") - logcat.warning("server", "Warning message") - logcat.error("server", "Error message") - - output = self.log_output.getvalue() - self.assertIn("[server] Info message", output) - self.assertIn("[server] Warning message", output) - self.assertIn("[server] Error message", output) - - def test_different_categories(self): - # Log messages with different categories - logcat.info("server", "Server message") - logcat.info("inference", "Inference message") - logcat.info("router", "Router message") - - output = self.log_output.getvalue() - self.assertIn("[server] Server message", output) - self.assertIn("[inference] Inference message", output) - self.assertIn("[router] Router message", output) - - def test_env_var_control(self): - os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning" - self._init_logcat() - - # These should be visible based on the environment settings - logcat.debug("server", "Server debug message") - logcat.info("server", "Server info message") - logcat.warning("inference", "Inference warning message") - logcat.error("inference", "Inference error message") - - # These should be filtered out based on the environment settings - logcat.debug("inference", "Inference debug message") - logcat.info("inference", "Inference info message") - - output = self.log_output.getvalue() - self.assertIn("[server] Server debug message", output) - self.assertIn("[server] Server info message", output) - self.assertIn("[inference] Inference warning message", output) - self.assertIn("[inference] Inference error message", output) - - self.assertNotIn("[inference] Inference debug message", output) - self.assertNotIn("[inference] Inference info message", output) - - def test_invalid_category(self): - logcat.info("nonexistent", "This message should not be logged") - - # Check that the message was not logged - output = self.log_output.getvalue() - 
self.assertNotIn("[nonexistent] This message should not be logged", output) - - -if __name__ == "__main__": - unittest.main()
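
To make the new API concrete, here is a minimal usage sketch of the `get_logger` helper introduced in `llama_stack/log.py` above. The module, category, and log messages are illustrative; only the import path, the `name`/`category` parameters, and the `LLAMA_STACK_LOGGING` variable come from the patch.

```
from llama_stack.log import get_logger

# The logger keeps the module name (shown as "name:lineno" in the output),
# while the category appears as the bracketed tag, e.g. "[server]".
logger = get_logger(name=__name__, category="server")

logger.info("Using config file: %s", "run.yaml")
# Emitted only when the "server" category is at DEBUG, e.g. via
# LLAMA_STACK_LOGGING="server=DEBUG" or LLAMA_STACK_LOGGING="all=debug".
logger.debug("Run configuration loaded")
```

Since `llama_stack.log` parses `LLAMA_STACK_LOGGING` and calls `setup_logging()` at import time, the per-category levels are in place before the first record is emitted.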
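
The `LLAMA_STACK_LOGGING` grammar itself is handled by `parse_environment_config` in the same module. The sketch below, with hypothetical level strings, shows the mapping it produces; in normal use it is driven by the environment variable rather than called directly.

```
import logging

from llama_stack.log import parse_environment_config

# Pairs are semicolon-separated "category=level" entries; level names are case-insensitive.
levels = parse_environment_config("server=DEBUG;core=debug")
assert levels == {"server": logging.DEBUG, "core": logging.DEBUG}

# "all" fans the level out to every known category and also sets the root logger level.
levels = parse_environment_config("all=debug")
assert levels["root"] == logging.DEBUG
assert all(levels[cat] == logging.DEBUG for cat in ("core", "server", "inference"))
```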