diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index d4e679e4b..c91cac048 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -5,15 +5,15 @@ # the root directory of this source tree. import argparse -import logging import os from pathlib import Path from llama_stack.cli.subcommand import Subcommand +from llama_stack.log import get_logger REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="server") class StackRun(Subcommand): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index e00518cf5..d7ca4414d 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -6,8 +6,6 @@ import importlib import inspect from typing import Any, Dict, List, Set, Tuple -import logging -from typing import Any, Dict, List, Set from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks @@ -36,6 +34,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( Api, BenchmarksProtocolPrivate, @@ -51,7 +50,7 @@ from llama_stack.providers.datatypes import ( VectorDBsProtocolPrivate, ) -log = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="core") class InvalidProviderError(Exception): @@ -187,7 +186,7 @@ def validate_and_prepare_providers( specs = {} for provider in providers: if not provider.provider_id or provider.provider_id == "__disabled__": - log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") + logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") continue validate_provider(provider, api, provider_registry) @@ -209,10 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR p = provider_registry[api][provider.provider_type] if p.deprecation_error: - log.error(p.deprecation_error) + logger.error(p.deprecation_error) raise InvalidProviderError(p.deprecation_error) elif p.deprecation_warning: - log.warning( + logger.warning( f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", ) @@ -246,10 +245,11 @@ def sort_providers_by_deps( ) ) - log.info(f"Resolved {len(sorted_providers)} providers") + logger.debug(f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: - log.debug(f" {api_str} => {provider.provider_id}") - log.debug("") + logger.debug(f" {api_str} => {provider.provider_id}") + logger.debug("") + return sorted_providers async def instantiate_providers( @@ -389,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: obj_params = set(obj_sig.parameters) obj_params.discard("self") if not (proto_params <= obj_params): - log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") + logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") missing_methods.append((name, "signature_mismatch")) else: # Check if the method is actually implemented in the class diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index bcb3997cf..28df67922 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -54,6 +54,8 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable +logger = get_logger(name=__name__, category="core") + class VectorIORouter(VectorIO): """Routes to an provider based on the vector db identifier""" @@ -62,12 +64,15 @@ class VectorIORouter(VectorIO): self, routing_table: RoutingTable, ) -> None: + logger.debug("Initializing VectorIORouter") self.routing_table = routing_table async def initialize(self) -> None: + logger.debug("VectorIORouter.initialize") pass async def shutdown(self) -> None: + logger.debug("VectorIORouter.shutdown") pass async def register_vector_db( @@ -78,9 +83,7 @@ class VectorIORouter(VectorIO): provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: - logger.debug( - f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}", - ) + logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -106,6 +109,7 @@ class VectorIORouter(VectorIO): query: InterleavedContent, params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: + logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) @@ -116,12 +120,15 @@ class InferenceRouter(Inference): self, routing_table: RoutingTable, ) -> None: + logger.debug("Initializing InferenceRouter") self.routing_table = routing_table async def initialize(self) -> None: + logger.debug("InferenceRouter.initialize") pass async def shutdown(self) -> None: + logger.debug("InferenceRouter.shutdown") pass async def register_model( @@ -132,6 +139,9 @@ class InferenceRouter(Inference): metadata: Optional[Dict[str, Any]] = None, model_type: Optional[ModelType] = None, ) -> None: + logger.debug( + f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", + ) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) async def chat_completion( @@ -242,6 +252,7 @@ class InferenceRouter(Inference): output_dimension: Optional[int] = None, task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: + logger.debug(f"InferenceRouter.embeddings: {model_id}") model = await self.routing_table.get_model(model_id) if model is None: raise ValueError(f"Model '{model_id}' not found") @@ -261,12 +272,15 @@ class SafetyRouter(Safety): self, routing_table: RoutingTable, ) -> None: + logger.debug("Initializing SafetyRouter") self.routing_table = routing_table async def initialize(self) -> None: + logger.debug("SafetyRouter.initialize") pass async def shutdown(self) -> None: + logger.debug("SafetyRouter.shutdown") pass async def register_shield( @@ -276,6 +290,7 @@ class SafetyRouter(Safety): provider_id: Optional[str] = None, params: Optional[Dict[str, Any]] = None, ) -> Shield: + logger.debug(f"SafetyRouter.register_shield: {shield_id}") return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) async def run_shield( @@ -284,6 +299,7 @@ class SafetyRouter(Safety): messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: + logger.debug(f"SafetyRouter.run_shield: {shield_id}") return await self.routing_table.get_provider_impl(shield_id).run_shield( shield_id=shield_id, messages=messages, @@ -296,12 +312,15 @@ class DatasetIORouter(DatasetIO): self, routing_table: RoutingTable, ) -> None: + logger.debug("Initializing DatasetIORouter") self.routing_table = routing_table async def initialize(self) -> None: + logger.debug("DatasetIORouter.initialize") pass async def shutdown(self) -> None: + logger.debug("DatasetIORouter.shutdown") pass async def get_rows_paginated( @@ -322,6 +341,7 @@ class DatasetIORouter(DatasetIO): ) async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") return await self.routing_table.get_provider_impl(dataset_id).append_rows( dataset_id=dataset_id, rows=rows, @@ -333,12 +353,15 @@ class ScoringRouter(Scoring): self, routing_table: RoutingTable, ) -> None: + logger.debug("Initializing ScoringRouter") self.routing_table = routing_table async def initialize(self) -> None: + logger.debug("ScoringRouter.initialize") pass async def shutdown(self) -> None: + logger.debug("ScoringRouter.shutdown") pass async def score_batch( @@ -347,6 +370,7 @@ class ScoringRouter(Scoring): scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: + logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( @@ -367,9 +391,7 @@ class ScoringRouter(Scoring): input_rows: List[Dict[str, Any]], scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: - logger.debug( - f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions", - ) + logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): @@ -387,12 +409,15 @@ class EvalRouter(Eval): self, routing_table: RoutingTable, ) -> None: + logger.debug("Initializing EvalRouter") self.routing_table = routing_table async def initialize(self) -> None: + logger.debug("EvalRouter.initialize") pass async def shutdown(self) -> None: + logger.debug("EvalRouter.shutdown") pass async def run_eval( @@ -400,6 +425,7 @@ class EvalRouter(Eval): benchmark_id: str, benchmark_config: BenchmarkConfig, ) -> Job: + logger.debug(f"EvalRouter.run_eval: {benchmark_id}") return await self.routing_table.get_provider_impl(benchmark_id).run_eval( benchmark_id=benchmark_id, benchmark_config=benchmark_config, @@ -412,6 +438,7 @@ class EvalRouter(Eval): scoring_functions: List[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: + logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, @@ -424,6 +451,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> Optional[JobStatus]: + logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) async def job_cancel( @@ -431,6 +459,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> None: + logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") await self.routing_table.get_provider_impl(benchmark_id).job_cancel( benchmark_id, job_id, @@ -441,6 +470,7 @@ class EvalRouter(Eval): benchmark_id: str, job_id: str, ) -> EvaluateResponse: + logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") return await self.routing_table.get_provider_impl(benchmark_id).job_result( benchmark_id, job_id, @@ -453,6 +483,7 @@ class ToolRuntimeRouter(ToolRuntime): self, routing_table: RoutingTable, ) -> None: + logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") self.routing_table = routing_table async def query( @@ -461,6 +492,7 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_ids: List[str], query_config: Optional[RAGQueryConfig] = None, ) -> RAGQueryResult: + logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") return await self.routing_table.get_provider_impl("knowledge_search").query( content, vector_db_ids, query_config ) @@ -471,6 +503,9 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: + logger.debug( + f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" + ) return await self.routing_table.get_provider_impl("insert_into_memory").insert( documents, vector_db_id, chunk_size_in_tokens ) @@ -479,6 +514,7 @@ class ToolRuntimeRouter(ToolRuntime): self, routing_table: RoutingTable, ) -> None: + logger.debug("Initializing ToolRuntimeRouter") self.routing_table = routing_table # HACK ALERT this should be in sync with "get_all_api_endpoints()" @@ -487,12 +523,15 @@ class ToolRuntimeRouter(ToolRuntime): setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) async def initialize(self) -> None: + logger.debug("ToolRuntimeRouter.initialize") pass async def shutdown(self) -> None: + logger.debug("ToolRuntimeRouter.shutdown") pass async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any: + logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") return await self.routing_table.get_provider_impl(tool_name).invoke_tool( tool_name=tool_name, kwargs=kwargs, @@ -501,4 +540,5 @@ class ToolRuntimeRouter(ToolRuntime): async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: + logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 7cfc38f30..68523218a 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -9,7 +9,6 @@ import asyncio import functools import inspect import json -import logging import os import signal import sys @@ -26,7 +25,6 @@ from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, ValidationError -from termcolor import cprint from typing_extensions import Annotated from llama_stack.distribution.datatypes import StackRunConfig @@ -39,6 +37,7 @@ from llama_stack.distribution.stack import ( replace_env_vars, validate_env_pair, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig from llama_stack.providers.inline.telemetry.meta_reference.telemetry import ( @@ -54,8 +53,7 @@ from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent -logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s") -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="server") def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -209,7 +207,7 @@ async def sse_generator(event_gen): yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: - print("Generator cancelled") + logger.info("Generator cancelled") await event_gen.aclose() except Exception as e: logger.exception(f"Error in sse_generator: {e}") diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 30ca97a49..2b974739a 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -5,14 +5,12 @@ # the root directory of this source tree. import importlib.resources -import logging import os import re import tempfile from typing import Any, Dict, Optional import yaml -from termcolor import colored from llama_stack.apis.agents import Agents from llama_stack.apis.batch_inference import BatchInference @@ -39,9 +37,10 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api -log = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="core") class LlamaStack( @@ -103,12 +102,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): objects_to_process = response.data if hasattr(response, "data") else response for obj in objects_to_process: - log.info( - f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}", + logger.debug( + f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}", ) - log.info("") - class EnvVarError(Exception): def __init__(self, var_name: str, path: str = ""): diff --git a/llama_stack/log.py b/llama_stack/log.py new file mode 100644 index 000000000..166e21604 --- /dev/null +++ b/llama_stack/log.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +import os +from logging.config import dictConfig +from typing import Dict + +# Default log level +DEFAULT_LOG_LEVEL = logging.INFO + +# Predefined categories +CATEGORIES = ["core", "server", "router", "inference", "agents", "safety", "eval", "tools", "client"] + +# Initialize category levels with default level +_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES} + + +def parse_environment_config(env_config: str) -> Dict[str, int]: + """ + Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels. + + Parameters: + env_config (str): The value of the LLAMA_STACK_LOGGING environment variable. + + Returns: + Dict[str, int]: A dictionary mapping categories to their log levels. + """ + category_levels = {} + for pair in env_config.split(";"): + if not pair.strip(): + continue + + try: + category, level = pair.split("=", 1) + category = category.strip().lower() + level = level.strip().upper() # Convert to uppercase for logging._nameToLevel + + level_value = logging._nameToLevel.get(level) + if level_value is None: + logging.warning( + f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'." + ) + continue + + if category == "all": + # Apply the log level to all categories and the root logger + for cat in CATEGORIES: + category_levels[cat] = level_value + # Set the root logger's level to the specified level + category_levels["root"] = level_value + elif category in CATEGORIES: + category_levels[category] = level_value + logging.info(f"Setting '{category}' category to level '{level}'.") + else: + logging.warning(f"Unknown logging category: {category}. No changes made.") + + except ValueError: + logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.") + + return category_levels + + +def setup_logging(category_levels: Dict[str, int]) -> None: + """ + Configure logging based on the provided category log levels. + + Parameters: + category_levels (Dict[str, int]): A dictionary mapping categories to their log levels. + """ + log_format = "%(asctime)s %(name)s:%(lineno)d [%(category)s]: %(message)s" + + class CategoryFilter(logging.Filter): + """Ensure category is always present in log records.""" + + def filter(self, record): + if not hasattr(record, "category"): + record.category = "uncategorized" # Default to 'uncategorized' if no category found + return True + + # Determine the root logger's level (default to WARNING if not specified) + root_level = category_levels.get("root", logging.WARNING) + + logging_config = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "rich": { + "()": logging.Formatter, # Standard formatter (RichHandler handles colors) + "format": log_format, + } + }, + "handlers": { + "console": { + "class": "rich.logging.RichHandler", + "formatter": "rich", + "rich_tracebacks": True, + "show_time": False, # We handle timestamps ourselves in the log_format + "show_path": False, + "filters": ["category_filter"], # Ensures category is included + } + }, + "filters": { + "category_filter": { + "()": CategoryFilter, + } + }, + "loggers": { + category: { + "handlers": ["console"], + "level": category_levels.get(category, DEFAULT_LOG_LEVEL), + "propagate": False, # Disable propagation to root logger + } + for category in CATEGORIES + }, + "root": { + "handlers": ["console"], + "level": root_level, # Set root logger's level dynamically + }, + } + dictConfig(logging_config) + + +def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter: + """ + Returns a logger with the specified name and category. + If no category is provided, defaults to 'uncategorized'. + + Parameters: + name (str): The name of the logger (e.g., module or filename). + category (str): The category of the logger (default 'uncategorized'). + + Returns: + logging.LoggerAdapter: Configured logger with category support. + """ + # Use the name as the logger's name + logger = logging.getLogger(name) + # Apply the category's log level to the logger + logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL)) + # Attach the category as extra context + return logging.LoggerAdapter(logger, {"category": category}) + + +# Parse environment variable and configure logging +env_config = os.environ.get("LLAMA_STACK_LOGGING", "") +if env_config: + print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}") + _category_levels.update(parse_environment_config(env_config)) + +setup_logging(_category_levels) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 720e73503..3619b3f67 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -17,7 +17,6 @@ from urllib.parse import urlparse import httpx -from llama_stack import logcat from llama_stack.apis.agents import ( AgentConfig, AgentToolGroup, @@ -67,6 +66,7 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolCall, @@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" +logger = get_logger(name=__name__, category="agents") + class ChatAgent(ShieldRunnerMixin): def __init__( @@ -609,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin): ) if n_iter >= self.agent_config.max_infer_iters: - logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.") + logger.info(f"done with MAX iterations ({n_iter}), exiting.") # NOTE: mark end_of_turn to indicate to client that we are done with the turn # Do not continue the tool call loop after this point message.stop_reason = StopReason.end_of_turn @@ -617,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin): break if stop_reason == StopReason.out_of_tokens: - logcat.info("agents", "out of token budget, exiting.") + logger.info("out of token budget, exiting.") yield message break @@ -631,16 +633,10 @@ class ChatAgent(ShieldRunnerMixin): message.content = [message.content] + output_attachments yield message else: - logcat.debug( - "agents", - f"completion message with EOM (iter: {n_iter}): {str(message)}", - ) + logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") input_messages = input_messages + [message] else: - logcat.debug( - "agents", - f"completion message (iter: {n_iter}) from the model: {str(message)}", - ) + logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}") # 1. Start the tool execution step and progress step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -983,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa path = urlparse(uri).path basename = os.path.basename(path) filepath = f"{tempdir}/{make_random_string() + basename}" - logcat.info("agents", f"Downloading {url} -> {filepath}") + logger.info(f"Downloading {url} -> {filepath}") async with httpx.AsyncClient() as client: r = await client.get(uri) @@ -1023,7 +1019,7 @@ async def execute_tool_call_maybe( else: name = name.value - logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}") + logger.info(f"executing tool call: {name} with args: {tool_call.arguments}") result = await tool_runtime_api.invoke_tool( tool_name=name, kwargs={ @@ -1033,7 +1029,7 @@ async def execute_tool_call_maybe( **toolgroup_args.get(group_name, {}), }, ) - logcat.debug("agents", f"tool call {name} completed with result: {result}") + logger.info(f"tool call {name} completed with result: {result}") return result diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 8c25f1a40..ec68fb556 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -32,6 +32,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -54,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig from .models import MODEL_ENTRIES +logger = get_logger(name=__name__, category="inference") + class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: FireworksImplConfig) -> None: @@ -230,12 +233,15 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv if input_dict["prompt"].startswith("<|begin_of_text|>"): input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :] - return { + params = { "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.response_format, request.logprobs), } + logger.debug(f"params to fireworks: {params}") + + return params async def embeddings( self, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 6530d8971..36941480c 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging + from typing import AsyncGenerator, List, Optional, Union import httpx @@ -34,6 +34,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, @@ -58,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .models import model_entries -log = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): @@ -71,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return AsyncClient(host=self.url) async def initialize(self) -> None: - log.info(f"checking connectivity to Ollama at `{self.url}`...") + logger.info(f"checking connectivity to Ollama at `{self.url}`...") try: await self.client.ps() except httpx.ConnectError as e: @@ -207,12 +208,15 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): else: raise ValueError(f"Unknown response format type: {fmt.type}") - return { + params = { "model": request.model, **input_dict, "options": sampling_options, "stream": request.stream, } + logger.debug(f"params to ollama: {params}") + + return params async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: params = await self._get_params(request) @@ -287,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def register_model(self, model: Model) -> Model: model = await self.register_helper.register_model(model) if model.model_type == ModelType.embedding: - log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") + logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") await self.client.pull(model.provider_resource_id) response = await self.client.list() else: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 6f2c5607f..f701c0da7 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -31,6 +31,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -53,6 +54,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig from .models import MODEL_ENTRIES +logger = get_logger(name=__name__, category="inference") + class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: TogetherImplConfig) -> None: @@ -217,12 +220,13 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi assert not media_present, "Together does not support media for Completion requests" input_dict["prompt"] = await completion_request_to_prompt(request) - return { + params = { "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.logprobs, request.response_format), } + logger.debug(f"params to together: {params}") async def embeddings( self, diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 6fe3bb042..d88dc5a9e 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -32,6 +32,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.models.models import Model from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -46,6 +47,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) +logger = get_logger(name=__name__, category="inference") + class LiteLLMOpenAIMixin( ModelRegistryHelper, @@ -108,6 +111,7 @@ class LiteLLMOpenAIMixin( ) params = await self._get_params(request) + logger.debug(f"params to litellm (openai compat): {params}") # unfortunately, we need to use synchronous litellm.completion here because litellm # caches various httpx.client objects in a non-eventloop aware manner response = litellm.completion(**params) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 37b1a8160..1edf445c0 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -8,14 +8,12 @@ import asyncio import base64 import io import json -import logging import re from typing import List, Optional, Tuple, Union import httpx from PIL import Image as PIL_Image -from llama_stack import logcat from llama_stack.apis.common.content_types import ( ImageContentItem, InterleavedContent, @@ -34,6 +32,7 @@ from llama_stack.apis.inference import ( ToolDefinition, UserMessage, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( ModelFamily, RawContent, @@ -58,7 +57,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.utils.inference import supported_inference_models -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class ChatCompletionRequestWithRawContent(ChatCompletionRequest): @@ -464,7 +463,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin def get_default_tool_prompt_format(model: str) -> ToolPromptFormat: llama_model = resolve_model(model) if llama_model is None: - logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format") + log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format") return ToolPromptFormat.json if llama_model.model_family == ModelFamily.llama3_1 or ( diff --git a/pyproject.toml b/pyproject.toml index 3ed7a1b56..0fa055a02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,5 +162,5 @@ module = ["yaml", "fire"] ignore_missing_imports = true [[tool.mypy.overrides]] -module = "llama_stack.distribution.resolver" -follow_imports = "normal" # This will force type checking on this module +module = ["llama_stack.distribution.resolver", "llama_stack.log"] +follow_imports = "normal" # This will force type checking on this module