From efe17727271992e180dc85e924b86f2033d7880e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Mon, 3 Mar 2025 11:58:40 +0100
Subject: [PATCH] Revert "feat: add a configurable category-based logger (#1352)"

This reverts commit 754feba61ff5d294235ab569b7b79ff963ab922a.
---
 llama_stack/distribution/resolver.py          |  20 +-
 llama_stack/distribution/routers/routers.py   |  64 +-----
 llama_stack/distribution/server/server.py     |  52 +++--
 llama_stack/distribution/stack.py             |   9 +-
 llama_stack/distribution/start_stack.sh       |   5 +-
 llama_stack/logcat.py                         | 204 ------------------
 .../remote/inference/fireworks/fireworks.py   |   5 +-
 .../remote/inference/ollama/ollama.py         |   5 +-
 .../remote/inference/together/together.py     |   5 +-
 .../utils/inference/litellm_openai_mixin.py   |   3 -
 pyproject.toml                                |   1 -
 tests/unit/server/test_logcat.py              |  88 --------
 12 files changed, 54 insertions(+), 407 deletions(-)
 delete mode 100644 llama_stack/logcat.py
 delete mode 100644 tests/unit/server/test_logcat.py

diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index c24df384d..e00518cf5 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -6,8 +6,9 @@
 import importlib
 import inspect
-from typing import Any, Dict, List, Set, Tuple
+import logging
+from typing import Any, Dict, List, Set

-from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
@@ -50,6 +51,8 @@ from llama_stack.providers.datatypes import (
     VectorDBsProtocolPrivate,
 )

+log = logging.getLogger(__name__)
+

 class InvalidProviderError(Exception):
     pass
@@ -184,7 +187,7 @@ def validate_and_prepare_providers(
     specs = {}
     for provider in providers:
         if not provider.provider_id or provider.provider_id == "__disabled__":
-            logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+            log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
             continue

         validate_provider(provider, api, provider_registry)
@@ -206,11 +209,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
     p = provider_registry[api][provider.provider_type]

     if p.deprecation_error:
-        logcat.error("core", p.deprecation_error)
+        log.error(p.deprecation_error)
         raise InvalidProviderError(p.deprecation_error)

     elif p.deprecation_warning:
-        logcat.warning(
-            "core",
+        log.warning(
             f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
         )
@@ -244,10 +246,10 @@
             )
         )

-    logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
+    log.info(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
-        logcat.debug("core", f"  {api_str} => {provider.provider_id}")
-    return sorted_providers
+        log.debug(f"  {api_str} => {provider.provider_id}")
+    log.debug("")


 async def instantiate_providers(
@@ -387,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
         obj_params = set(obj_sig.parameters)
         obj_params.discard("self")
         if not (proto_params <= obj_params):
-            logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+            log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
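The signature check touched in the last hunk is plain inspect usage: collect the protocol method's parameter names and verify the implementation accepts at least those. A minimal self-contained sketch of the same technique (the Ping/PingImpl names are illustrative, not from this patch):

    import inspect
    from typing import Protocol


    class Ping(Protocol):
        async def ping(self, host: str, timeout: float) -> bool: ...


    class PingImpl:
        async def ping(self, host: str) -> bool:  # missing `timeout`
            return True


    # Compare parameter sets, ignoring `self`, as check_protocol_compliance does.
    proto_params = set(inspect.signature(Ping.ping).parameters) - {"self"}
    obj_params = set(inspect.signature(PingImpl.ping).parameters) - {"self"}
    if not proto_params <= obj_params:
        # Mirrors the log.error() branch above.
        print(f"Method ping incompatible proto: {proto_params} vs. obj: {obj_params}")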
             missing_methods.append((name, "signature_mismatch"))
         else:
             # Check if the method is actually implemented in the class
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index f2c70e66f..bcb3997cf 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -6,7 +6,6 @@

 from typing import Any, AsyncGenerator, Dict, List, Optional

-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -52,6 +51,7 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
@@ -62,15 +62,12 @@ class VectorIORouter(VectorIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing VectorIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "VectorIORouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "VectorIORouter.shutdown")
         pass

     async def register_vector_db(
@@ -81,8 +78,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
         )
         await self.routing_table.register_vector_db(
@@ -99,8 +95,7 @@ class VectorIORouter(VectorIO):
         chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
         return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
@@ -111,7 +106,6 @@ class VectorIORouter(VectorIO):
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
-        logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
         return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@@ -122,15 +116,12 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "InferenceRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "InferenceRouter.shutdown")
         pass

     async def register_model(
@@ -141,10 +132,6 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
-        logcat.debug(
-            "core",
-            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
-        )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

     async def chat_completion(
@@ -160,8 +147,7 @@ class InferenceRouter(Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
         if sampling_params is None:
@@ -226,8 +212,7 @@ class InferenceRouter(Inference):
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        logcat.debug(
-            "core",
+        logger.debug(
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
         )
         model = await self.routing_table.get_model(model_id)
@@ -257,7 +242,6 @@ class InferenceRouter(Inference):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -277,15 +261,12 @@ class SafetyRouter(Safety):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing SafetyRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "SafetyRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "SafetyRouter.shutdown")
         pass

     async def register_shield(
@@ -295,7 +276,6 @@ class SafetyRouter(Safety):
         provider_id: Optional[str] = None,
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
-        logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
@@ -304,7 +284,6 @@ class SafetyRouter(Safety):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
-        logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
             shield_id=shield_id,
             messages=messages,
@@ -317,15 +296,12 @@ class DatasetIORouter(DatasetIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing DatasetIORouter")
         self.routing_table = routing_table
     async def initialize(self) -> None:
-        logcat.debug("core", "DatasetIORouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "DatasetIORouter.shutdown")
         pass

     async def get_rows_paginated(
@@ -335,8 +311,7 @@ class DatasetIORouter(DatasetIO):
         page_token: Optional[str] = None,
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
         )
         return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
@@ -347,7 +322,6 @@ class DatasetIORouter(DatasetIO):
         )

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
-        logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
             rows=rows,
@@ -359,15 +333,12 @@ class ScoringRouter(Scoring):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing ScoringRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "ScoringRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "ScoringRouter.shutdown")
         pass

     async def score_batch(
@@ -376,7 +347,6 @@ class ScoringRouter(Scoring):
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@@ -397,8 +367,7 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logcat.debug(
-            "core",
+        logger.debug(
             f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
         )
         res = {}
@@ -418,15 +387,12 @@ class EvalRouter(Eval):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing EvalRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logcat.debug("core", "EvalRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "EvalRouter.shutdown")
         pass

     async def run_eval(
@@ -434,7 +400,6 @@ class EvalRouter(Eval):
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
     ) -> Job:
-        logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
@@ -447,7 +412,6 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -460,7 +424,6 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> Optional[JobStatus]:
-        logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
@@ -468,7 +431,6 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> None:
-        logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
         await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
             benchmark_id,
             job_id,
@@ -479,7 +441,6 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> EvaluateResponse:
-        logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_result(
             benchmark_id,
             job_id,
@@ -492,7 +453,6 @@ class ToolRuntimeRouter(ToolRuntime):
             self,
             routing_table: RoutingTable,
         ) -> None:
-            logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
             self.routing_table = routing_table

         async def query(
@@ -501,7 +461,6 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_ids: List[str],
             query_config: Optional[RAGQueryConfig] = None,
         ) -> RAGQueryResult:
-            logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
             return await self.routing_table.get_provider_impl("knowledge_search").query(
                 content, vector_db_ids, query_config
             )
@@ -512,10 +471,6 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_id: str,
             chunk_size_in_tokens: int = 512,
         ) -> None:
-            logcat.debug(
-                "core",
-                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
-            )
             return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                 documents, vector_db_id, chunk_size_in_tokens
             )
@@ -524,7 +479,6 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logcat.debug("core", "Initializing ToolRuntimeRouter")
         self.routing_table = routing_table

         # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -533,15 +487,12 @@ class ToolRuntimeRouter(ToolRuntime):
             setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

     async def initialize(self) -> None:
-        logcat.debug("core", "ToolRuntimeRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logcat.debug("core", "ToolRuntimeRouter.shutdown")
         pass

     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
-        logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
             kwargs=kwargs,
@@ -550,5 +501,4 @@ class ToolRuntimeRouter(ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
-        logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
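Every router in the file above follows the same shape once the logging is stripped: resolve the provider implementation registered for a resource ID via the routing table, then forward the call unchanged. A schematic sketch of that pattern (the class below is illustrative, not code from this patch):

    from typing import Any, Dict


    class ExampleRouter:
        """Delegation pattern shared by VectorIORouter, SafetyRouter, EvalRouter, etc."""

        def __init__(self, routing_table) -> None:
            self.routing_table = routing_table

        async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
            # Look up the provider registered for this resource, then forward as-is.
            provider = self.routing_table.get_provider_impl(tool_name)
            return await provider.invoke_tool(tool_name=tool_name, kwargs=kwargs)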
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 2fc36e58f..7cfc38f30 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
+from termcolor import cprint
 from typing_extensions import Annotated

-from llama_stack import logcat
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import set_request_provider_data
@@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
 REPO_ROOT = Path(__file__).parent.parent.parent.parent

 logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
-logcat.init()
+logger = logging.getLogger(__name__)


 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
     not block the current execution.
     """
     signame = signal.Signals(signum).name
-    logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
+    logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")

     async def shutdown():
         try:
             # Gracefully shut down implementations
             for impl in app.__llama_stack_impls__.values():
                 impl_name = impl.__class__.__name__
-                logcat.info("server", f"Shutting down {impl_name}")
+                logger.info("Shutting down %s", impl_name)
                 try:
                     if hasattr(impl, "shutdown"):
                         await asyncio.wait_for(impl.shutdown(), timeout=5)
                     else:
-                        logcat.warning("server", f"No shutdown method for {impl_name}")
+                        logger.warning("No shutdown method for %s", impl_name)
                 except asyncio.TimeoutError:
-                    logcat.exception("server", f"Shutdown timeout for {impl_name}")
+                    logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
                 except Exception as e:
-                    logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
+                    logger.exception("Failed to shutdown %s: %s", impl_name, {e})

             # Gather all running tasks
             loop = asyncio.get_running_loop()
@@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
         try:
             await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
         except asyncio.TimeoutError:
-            logcat.exception("server", "Timeout while waiting for tasks to finish")
+            logger.exception("Timeout while waiting for tasks to finish")
         except asyncio.CancelledError:
             pass
         finally:
@@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logcat.info("server", "Starting up")
+    logger.info("Starting up")
     yield
-    logcat.info("server", "Shutting down")
+    logger.info("Shutting down")
     for impl in app.__llama_stack_impls__.values():
         await impl.shutdown()
@@ -209,11 +209,11 @@ async def sse_generator(event_gen):
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
-        logcat.info("server", "Generator cancelled")
+        print("Generator cancelled")
         await event_gen.aclose()
     except Exception as e:
-        logcat.exception("server", f"Error in sse_generator: {e}")
-        logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
+        logger.exception(f"Error in sse_generator: {e}")
+        logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
         yield create_sse_event(
             {
                 "error": {
@@ -235,7 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
                 value = func(**kwargs)
                 return await maybe_await(value)
             except Exception as e:
-                logcat.exception("server", f"Error in {func.__name__}")
+                traceback.print_exception(e)
                 raise translate_exception(e) from e

     sig = inspect.signature(func)
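The shutdown loop restored in handle_signal bounds each implementation's cleanup with a timeout so one hung provider cannot stall the whole exit. The core pattern, reduced to a standalone sketch:

    import asyncio


    async def shutdown_all(impls, per_impl_timeout: float = 5.0) -> None:
        # Same shape as the server's shutdown(): bound each impl, never block forever.
        for impl in impls:
            name = impl.__class__.__name__
            if not hasattr(impl, "shutdown"):
                print(f"No shutdown method for {name}")
                continue
            try:
                await asyncio.wait_for(impl.shutdown(), timeout=per_impl_timeout)
            except asyncio.TimeoutError:
                print(f"Shutdown timeout for {name}")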
@@ -314,8 +314,6 @@ class ClientVersionMiddleware:


 def main():
-    logcat.init()
-
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
     parser.add_argument(
@@ -355,10 +353,10 @@ def main():
         for env_pair in args.env:
             try:
                 key, value = validate_env_pair(env_pair)
-                logcat.info("server", f"Setting CLI environment variable {key} => {value}")
+                logger.info(f"Setting CLI environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
-                logcat.error("server", f"Error: {str(e)}")
+                logger.error(f"Error: {str(e)}")
                 sys.exit(1)

     if args.yaml_config:
@@ -366,12 +364,12 @@ def main():
         config_file = Path(args.yaml_config)
         if not config_file.exists():
             raise ValueError(f"Config file {config_file} does not exist")
-        logcat.info("server", f"Using config file: {config_file}")
+        logger.info(f"Using config file: {config_file}")
     elif args.template:
         config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
         if not config_file.exists():
             raise ValueError(f"Template {args.template} does not exist")
-        logcat.info("server", f"Using template {args.template} config file: {config_file}")
+        logger.info(f"Using template {args.template} config file: {config_file}")
     else:
         raise ValueError("Either --yaml-config or --template must be provided")
@@ -379,10 +377,9 @@ def main():
         config = replace_env_vars(yaml.safe_load(fp))
         config = StackRunConfig(**config)

-    logcat.info("server", "Run configuration:")
+    logger.info("Run configuration:")
     safe_config = redact_sensitive_fields(config.model_dump())
-    for log_line in yaml.dump(safe_config, indent=2).split("\n"):
-        logcat.info("server", log_line)
+    logger.info(yaml.dump(safe_config, indent=2))

     app = FastAPI(lifespan=lifespan)
     app.add_middleware(TracingMiddleware)
@@ -392,7 +389,7 @@ def main():
     try:
         impls = asyncio.run(construct_stack(config))
     except InvalidProviderError as e:
-        logcat.error("server", f"Error: {str(e)}")
+        logger.error(f"Error: {str(e)}")
         sys.exit(1)

     if Api.telemetry in impls:
@@ -437,8 +434,9 @@ def main():
             )
         )

-    logcat.debug("server", f"serving APIs: {apis_to_serve}")
+    logger.debug(f"serving APIs: {apis_to_serve}")

+    print("")
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
     signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
@@ -464,10 +462,10 @@ def main():
             "ssl_keyfile": keyfile,
             "ssl_certfile": certfile,
         }
-        logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
+        logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

     listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
-    logcat.info("server", f"Listening on {listen_host}:{port}")
+    logger.info(f"Listening on {listen_host}:{port}")

     uvicorn_config = {
         "app": app,
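main() accepts repeated --env KEY=VALUE flags and exports each pair before constructing the stack. validate_env_pair itself is not shown in these hunks; judging from its call site it behaves roughly like this sketch (an assumed shape, not the actual helper):

    def validate_env_pair(env_pair: str) -> tuple[str, str]:
        # Assumed behavior, inferred from the call site above: split on the first
        # '=' and reject malformed input with ValueError so main() can exit(1).
        if "=" not in env_pair:
            raise ValueError(f"Invalid --env format: {env_pair!r}, expected KEY=VALUE")
        key, value = env_pair.split("=", 1)
        if not key.strip():
            raise ValueError(f"Empty variable name in --env pair: {env_pair!r}")
        return key.strip(), value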
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index de74aa858..30ca97a49 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import importlib.resources
+import logging
 import os
 import re
 import tempfile
@@ -13,7 +14,6 @@ from typing import Any, Dict, Optional
 import yaml
 from termcolor import colored

-from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batch_inference import BatchInference
 from llama_stack.apis.benchmarks import Benchmarks
@@ -41,6 +41,8 @@ from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.providers.datatypes import Api

+log = logging.getLogger(__name__)
+

 class LlamaStack(
     VectorDBs,
@@ -101,11 +103,12 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
         objects_to_process = response.data if hasattr(response, "data") else response

         for obj in objects_to_process:
-            logcat.debug(
-                "core",
+            log.info(
                 f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
             )

+    log.info("")
+

 class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):
diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/distribution/start_stack.sh
index a769bd66e..713997331 100755
--- a/llama_stack/distribution/start_stack.sh
+++ b/llama_stack/distribution/start_stack.sh
@@ -98,8 +98,9 @@ case "$env_type" in
 *)
 esac

+set -x
+
 if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
-    set -x
     $PYTHON_BINARY -m llama_stack.distribution.server.server \
     --yaml-config "$yaml_config" \
     --port "$port" \
@@ -141,8 +142,6 @@ elif [[ "$env_type" == "container" ]]; then
         version_tag=$(curl -s $URL | jq -r '.info.version')
     fi

-    set -x
-
     $CONTAINER_BINARY run $CONTAINER_OPTS -it \
         -p $port:$port \
         $env_vars \
diff --git a/llama_stack/logcat.py b/llama_stack/logcat.py
deleted file mode 100644
index 0e11cb782..000000000
--- a/llama_stack/logcat.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""
-Category-based logging utility for llama-stack.
-
-This module provides a wrapper over the standard Python logging module that supports
-categorized logging with environment variable control.
-
-Usage:
-    from llama_stack import logcat
-    logcat.info("server", "Starting up...")
-    logcat.debug("inference", "Processing request...")
-
-Environment variable:
-    LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
-    Example: "server=debug;inference=warning"
-"""
-
-import datetime
-import logging
-import os
-from typing import Dict
-
-# ANSI color codes for terminal output
-COLORS = {
-    "RESET": "\033[0m",
-    "DEBUG": "\033[36m",  # Cyan
-    "INFO": "\033[32m",  # Green
-    "WARNING": "\033[33m",  # Yellow
-    "ERROR": "\033[31m",  # Red
-    "CRITICAL": "\033[35m",  # Magenta
-    "DIM": "\033[2m",  # Dimmed text
-    "YELLOW_DIM": "\033[2;33m",  # Dimmed yellow
-}
-
-# Static list of valid categories representing various parts of the Llama Stack
-# server codebase
-CATEGORIES = [
-    "core",
-    "server",
-    "router",
-    "inference",
-    "agents",
-    "safety",
-    "eval",
-    "tools",
-    "client",
-]
-
-_logger = logging.getLogger("llama_stack")
-_logger.propagate = False
-
-_default_level = logging.INFO
-
-# Category-level mapping (can be modified by environment variables)
-_category_levels: Dict[str, int] = {}
-
-
-class TerminalStreamHandler(logging.StreamHandler):
-    def __init__(self, stream=None):
-        super().__init__(stream)
-        self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
-
-    def format(self, record):
-        record.is_tty = self.is_tty
-        return super().format(record)
-
-
-class ColoredFormatter(logging.Formatter):
-    """Custom formatter with colors and fixed-width level names"""
-
-    def format(self, record):
-        levelname = record.levelname
-        # Use only time with milliseconds, not date
-        timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]  # HH:MM:SS.mmm format
-
-        file_info = f"{record.filename}:{record.lineno}"
-
-        # Get category from extra if available
-        category = getattr(record, "category", None)
-        msg = record.getMessage()
-
-        if getattr(record, "is_tty", False):
-            color = COLORS.get(levelname, COLORS["RESET"])
-            if category:
-                category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
-                formatted_msg = (
-                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
-                    f"{file_info:<20} {category_formatted}{msg}"
-                )
-            else:
-                formatted_msg = (
-                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
-                    f"{file_info:<20} {msg}"
-                )
-        else:
-            if category:
-                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
-            else:
-                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
-
-        return formatted_msg
-
-
-def init(default_level: int = logging.INFO) -> None:
-    global _default_level, _category_levels, _logger
-
-    _default_level = default_level
-
-    _logger.setLevel(logging.DEBUG)
-    _logger.handlers = []  # Clear existing handlers
-
-    # Add our custom handler with the colored formatter
-    handler = TerminalStreamHandler()
-    formatter = ColoredFormatter()
-    handler.setFormatter(formatter)
-    _logger.addHandler(handler)
-
-    for category in CATEGORIES:
-        _category_levels[category] = default_level
-
-    env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
-    if env_config:
-        for pair in env_config.split(";"):
-            if not pair.strip():
-                continue
-
-            try:
-                category, level = pair.split("=", 1)
-                category = category.strip().lower()
-                level = level.strip().lower()
-
-                level_value = {
-                    "debug": logging.DEBUG,
-                    "info": logging.INFO,
-                    "warning": logging.WARNING,
-                    "warn": logging.WARNING,
-                    "error": logging.ERROR,
-                    "critical": logging.CRITICAL,
-                }.get(level)
-
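For reference, the module deleted above was driven entirely by one environment variable; per its own docstring, usage boils down to this (the variable must be set before init() runs, as it is here):

    import logging
    import os

    # Per the module docstring: semicolon-separated category=level pairs.
    os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"

    from llama_stack import logcat

    logcat.init(default_level=logging.INFO)
    logcat.debug("server", "shown: server is at DEBUG")
    logcat.debug("inference", "filtered: inference is at WARNING")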
-                if level_value is None:
-                    _logger.warning(f"Unknown log level '{level}' for category '{category}'")
-                    continue
-
-                if category == "all":
-                    for cat in CATEGORIES:
-                        _category_levels[cat] = level_value
-                else:
-                    if category in CATEGORIES:
-                        _category_levels[category] = level_value
-                    else:
-                        _logger.warning(f"Unknown logging category: {category}")
-
-            except ValueError:
-                _logger.warning(f"Invalid logging configuration: {pair}")
-
-
-def _should_log(level: int, category: str) -> bool:
-    category = category.lower()
-    if category not in _category_levels:
-        return False
-    category_level = _category_levels[category]
-    return level >= category_level
-
-
-def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
-    if _should_log(level, category):
-        kwargs.setdefault("extra", {})["category"] = category.lower()
-        getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
-
-
-def debug(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
-
-
-def info(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.INFO, "info", category, msg, *args, **kwargs)
-
-
-def warning(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.WARNING, "warning", category, msg, *args, **kwargs)
-
-
-def warn(category: str, msg: str, *args, **kwargs) -> None:
-    warning(category, msg, *args, **kwargs)
-
-
-def error(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.ERROR, "error", category, msg, *args, **kwargs)
-
-
-def critical(category: str, msg: str, *args, **kwargs) -> None:
-    _log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
-
-
-def exception(category: str, msg: str, *args, **kwargs) -> None:
-    if _should_log(logging.ERROR, category):
-        kwargs.setdefault("extra", {})["category"] = category.lower()
-        _logger.exception(msg, *args, stacklevel=2, **kwargs)
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index a4cecf9f1..8c25f1a40 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks

-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -231,14 +230,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

-        params = {
+        return {
             "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
-        logcat.debug("inference", f"params to fireworks: {params}")
-        return params

     async def embeddings(
         self,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 4d7fef8ed..6530d8971 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
 import httpx
 from ollama import AsyncClient

-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
@@ -208,14 +207,12 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")

-        params = {
+        return {
             "model": request.model,
             **input_dict,
             "options": sampling_options,
             "stream": request.stream,
         }
-        logcat.debug("inference", f"params to ollama: {params}")
-        return params

     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 0c468cdbf..6f2c5607f 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union

 from together import Together

-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -218,14 +217,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             assert not media_present, "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

-        params = {
+        return {
             "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }
-        logcat.debug("inference", f"params to together: {params}")
-        return params

     async def embeddings(
         self,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 9467996a6..6fe3bb042 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union

 import litellm

-from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -109,8 +108,6 @@ class LiteLLMOpenAIMixin(
         )
         params = await self._get_params(request)

-        logcat.debug("inference", f"params to litellm (openai compat): {params}")
-
         # unfortunately, we need to use synchronous litellm.completion here because litellm
         # caches various httpx.client objects in a non-eventloop aware manner
         response = litellm.completion(**params)
diff --git a/pyproject.toml b/pyproject.toml
index d8f3718d8..3ed7a1b56 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -151,7 +151,6 @@ exclude = [
     "llama_stack/distribution",
     "llama_stack/apis",
     "llama_stack/cli",
-    "llama_stack/logcat.py",
     "llama_stack/models",
     "llama_stack/strong_typing",
     "llama_stack/templates",
diff --git a/tests/unit/server/test_logcat.py b/tests/unit/server/test_logcat.py
deleted file mode 100644
index 4a116a08f..000000000
--- a/tests/unit/server/test_logcat.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import io
-import logging
-import os
-import unittest
-
-from llama_stack import logcat
-
-
-class TestLogcat(unittest.TestCase):
-    def setUp(self):
-        self.original_env = os.environ.get("LLAMA_STACK_LOGGING")
-
-        self.log_output = io.StringIO()
-        self._init_logcat()
-
-    def tearDown(self):
-        if self.original_env is not None:
-            os.environ["LLAMA_STACK_LOGGING"] = self.original_env
-        else:
-            os.environ.pop("LLAMA_STACK_LOGGING", None)
-
-    def _init_logcat(self):
-        logcat.init(default_level=logging.DEBUG)
-        self.handler = logging.StreamHandler(self.log_output)
-        self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
-        logcat._logger.handlers.clear()
-        logcat._logger.addHandler(self.handler)
-
-    def test_basic_logging(self):
-        logcat.info("server", "Info message")
-        logcat.warning("server", "Warning message")
-        logcat.error("server", "Error message")
-
-        output = self.log_output.getvalue()
-        self.assertIn("[server] Info message", output)
-        self.assertIn("[server] Warning message", output)
-        self.assertIn("[server] Error message", output)
-
-    def test_different_categories(self):
-        # Log messages with different categories
-        logcat.info("server", "Server message")
-        logcat.info("inference", "Inference message")
-        logcat.info("router", "Router message")
-
-        output = self.log_output.getvalue()
-        self.assertIn("[server] Server message", output)
-        self.assertIn("[inference] Inference message", output)
-        self.assertIn("[router] Router message", output)
-
-    def test_env_var_control(self):
-        os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
-        self._init_logcat()
-
-        # These should be visible based on the environment settings
-        logcat.debug("server", "Server debug message")
-        logcat.info("server", "Server info message")
-        logcat.warning("inference", "Inference warning message")
-        logcat.error("inference", "Inference error message")
-
-        # These should be filtered out based on the environment settings
-        logcat.debug("inference", "Inference debug message")
-        logcat.info("inference", "Inference info message")
-
-        output = self.log_output.getvalue()
-        self.assertIn("[server] Server debug message", output)
-        self.assertIn("[server] Server info message", output)
-        self.assertIn("[inference] Inference warning message", output)
-        self.assertIn("[inference] Inference error message", output)
-
-        self.assertNotIn("[inference] Inference debug message", output)
-        self.assertNotIn("[inference] Inference info message", output)
-
-    def test_invalid_category(self):
-        logcat.info("nonexistent", "This message should not be logged")
-
-        # Check that the message was not logged
-        output = self.log_output.getvalue()
-        self.assertNotIn("[nonexistent] This message should not be logged", output)
-
-
-if __name__ == "__main__":
-    unittest.main()
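After the revert, per-area verbosity falls back to the standard library: hierarchical module loggers stand in for the fixed category list. Roughly equivalent control looks like this (a sketch using the basicConfig format restored in server.py above; not part of this patch):

    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s",
    )

    # Logger names follow the package path, so subtrees can be tuned individually.
    logging.getLogger("llama_stack.distribution.server").setLevel(logging.DEBUG)
    logging.getLogger("llama_stack.providers.remote.inference").setLevel(logging.WARNING)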