feat(logging): implement category-based logging

This commit introduces a new logging system that allows loggers to be assigned
a category while retaining the logger name based on the file name. The log
format includes both the logger name and the category, producing output
like:

```
INFO     2025-03-03 21:44:11,323 llama_stack.distribution.stack:103 [core]: Tool_groups: builtin::websearch served by
         tavily-search
```

Key features include:

- Category-based logging: Loggers can be assigned a category (e.g.,
  "core", "server") when programming. The logger can be loaded like
  this: `logger = get_logger(name=__name__, category="server")`
- Environment variable control: Log levels can be configured per-category using the
  `LLAMA_STACK_LOGGING` environment variable. For example:
  `LLAMA_STACK_LOGGING="server=DEBUG;core=debug"` enables DEBUG level for the "server"
    and "core" categories.
- `LLAMA_STACK_LOGGING="all=debug"` sets DEBUG level globally for all categories and
    third-party libraries.

This provides fine-grained control over logging levels while maintaining a clean and
informative log format.

The formatter uses the rich library which provides nice colors better
stack traces like so:

```
ERROR    2025-03-03 21:49:37,124 asyncio:1758 [uncategorized]: unhandled exception during asyncio.run() shutdown
         task: <Task finished name='Task-16' coro=<handle_signal.<locals>.shutdown() done, defined at
         /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:146>
         exception=UnboundLocalError("local variable 'loop' referenced before assignment")>
         ╭────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────╮
         │ /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:178 in shutdown                │
         │                                                                                                                │
         │   175 │   │   except asyncio.CancelledError:                                                                   │
         │   176 │   │   │   pass                                                                                         │
         │   177 │   │   finally:                                                                                         │
         │ ❱ 178 │   │   │   loop.stop()                                                                                  │
         │   179 │                                                                                                        │
         │   180 │   loop = asyncio.get_running_loop()                                                                    │
         │   181 │   loop.create_task(shutdown())                                                                         │
         ╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
         UnboundLocalError: local variable 'loop' referenced before assignment
```

Co-authored-by: Ashwin Bharambe <@ashwinb>
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-03-03 13:59:48 +01:00 committed by Ashwin Bharambe
parent efe1772727
commit 11fffe7b95
13 changed files with 258 additions and 57 deletions

View file

@ -5,15 +5,15 @@
# the root directory of this source tree.
import argparse
import logging
import os
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="server")
class StackRun(Subcommand):

View file

@ -6,8 +6,6 @@
import importlib
import inspect
from typing import Any, Dict, List, Set, Tuple
import logging
from typing import Any, Dict, List, Set
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
@ -36,6 +34,7 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
@ -51,7 +50,7 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate,
)
log = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception):
@ -187,7 +186,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
@ -209,10 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
log.error(p.deprecation_error)
logger.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
log.warning(
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
@ -246,10 +245,11 @@ def sort_providers_by_deps(
)
)
log.info(f"Resolved {len(sorted_providers)} providers")
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
log.debug(f" {api_str} => {provider.provider_id}")
log.debug("")
logger.debug(f" {api_str} => {provider.provider_id}")
logger.debug("")
return sorted_providers
async def instantiate_providers(
@ -389,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class

View file

@ -54,6 +54,8 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""
@ -62,12 +64,15 @@ class VectorIORouter(VectorIO):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
pass
async def register_vector_db(
@ -78,9 +83,7 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
logger.debug(
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
)
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
@ -106,6 +109,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -116,12 +120,15 @@ class InferenceRouter(Inference):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("InferenceRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
pass
async def register_model(
@ -132,6 +139,9 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
async def chat_completion(
@ -242,6 +252,7 @@ class InferenceRouter(Inference):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -261,12 +272,15 @@ class SafetyRouter(Safety):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
@ -276,6 +290,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
@ -284,6 +299,7 @@ class SafetyRouter(Safety):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
@ -296,12 +312,15 @@ class DatasetIORouter(DatasetIO):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
@ -322,6 +341,7 @@ class DatasetIORouter(DatasetIO):
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
@ -333,12 +353,15 @@ class ScoringRouter(Scoring):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
@ -347,6 +370,7 @@ class ScoringRouter(Scoring):
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@ -367,9 +391,7 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
logger.debug(
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
)
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
@ -387,12 +409,15 @@ class EvalRouter(Eval):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
@ -400,6 +425,7 @@ class EvalRouter(Eval):
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
@ -412,6 +438,7 @@ class EvalRouter(Eval):
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
@ -424,6 +451,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
@ -431,6 +459,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
@ -441,6 +470,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
@ -453,6 +483,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
@ -461,6 +492,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
@ -471,6 +503,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
@ -479,6 +514,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -487,12 +523,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
@ -501,4 +540,5 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -9,7 +9,6 @@ import asyncio
import functools
import inspect
import json
import logging
import os
import signal
import sys
@ -26,7 +25,6 @@ from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import StackRunConfig
@ -39,6 +37,7 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
@ -54,8 +53,7 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -209,7 +207,7 @@ async def sse_generator(event_gen):
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
logger.info("Generator cancelled")
await event_gen.aclose()
except Exception as e:
logger.exception(f"Error in sse_generator: {e}")

View file

@ -5,14 +5,12 @@
# the root directory of this source tree.
import importlib.resources
import logging
import os
import re
import tempfile
from typing import Any, Dict, Optional
import yaml
from termcolor import colored
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
@ -39,9 +37,10 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="core")
class LlamaStack(
@ -103,12 +102,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
log.info(
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
logger.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
log.info("")
class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):

153
llama_stack/log.py Normal file
View file

@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from logging.config import dictConfig
from typing import Dict
# Default log level
DEFAULT_LOG_LEVEL = logging.INFO
# Predefined categories
CATEGORIES = ["core", "server", "router", "inference", "agents", "safety", "eval", "tools", "client"]
# Initialize category levels with default level
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
def parse_environment_config(env_config: str) -> Dict[str, int]:
"""
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
Parameters:
env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.
Returns:
Dict[str, int]: A dictionary mapping categories to their log levels.
"""
category_levels = {}
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
level_value = logging._nameToLevel.get(level)
if level_value is None:
logging.warning(
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
)
continue
if category == "all":
# Apply the log level to all categories and the root logger
for cat in CATEGORIES:
category_levels[cat] = level_value
# Set the root logger's level to the specified level
category_levels["root"] = level_value
elif category in CATEGORIES:
category_levels[category] = level_value
logging.info(f"Setting '{category}' category to level '{level}'.")
else:
logging.warning(f"Unknown logging category: {category}. No changes made.")
except ValueError:
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
return category_levels
def setup_logging(category_levels: Dict[str, int]) -> None:
"""
Configure logging based on the provided category log levels.
Parameters:
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
"""
log_format = "%(asctime)s %(name)s:%(lineno)d [%(category)s]: %(message)s"
class CategoryFilter(logging.Filter):
"""Ensure category is always present in log records."""
def filter(self, record):
if not hasattr(record, "category"):
record.category = "uncategorized" # Default to 'uncategorized' if no category found
return True
# Determine the root logger's level (default to WARNING if not specified)
root_level = category_levels.get("root", logging.WARNING)
logging_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"rich": {
"()": logging.Formatter, # Standard formatter (RichHandler handles colors)
"format": log_format,
}
},
"handlers": {
"console": {
"class": "rich.logging.RichHandler",
"formatter": "rich",
"rich_tracebacks": True,
"show_time": False, # We handle timestamps ourselves in the log_format
"show_path": False,
"filters": ["category_filter"], # Ensures category is included
}
},
"filters": {
"category_filter": {
"()": CategoryFilter,
}
},
"loggers": {
category: {
"handlers": ["console"],
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
"propagate": False, # Disable propagation to root logger
}
for category in CATEGORIES
},
"root": {
"handlers": ["console"],
"level": root_level, # Set root logger's level dynamically
},
}
dictConfig(logging_config)
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
"""
Returns a logger with the specified name and category.
If no category is provided, defaults to 'uncategorized'.
Parameters:
name (str): The name of the logger (e.g., module or filename).
category (str): The category of the logger (default 'uncategorized').
Returns:
logging.LoggerAdapter: Configured logger with category support.
"""
# Use the name as the logger's name
logger = logging.getLogger(name)
# Apply the category's log level to the logger
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
# Attach the category as extra context
return logging.LoggerAdapter(logger, {"category": category})
# Parse environment variable and configure logging
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
_category_levels.update(parse_environment_config(env_config))
setup_logging(_category_levels)

View file

@ -17,7 +17,6 @@ from urllib.parse import urlparse
import httpx
from llama_stack import logcat
from llama_stack.apis.agents import (
AgentConfig,
AgentToolGroup,
@ -67,6 +66,7 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
class ChatAgent(ShieldRunnerMixin):
def __init__(
@ -609,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
message.stop_reason = StopReason.end_of_turn
@ -617,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
break
if stop_reason == StopReason.out_of_tokens:
logcat.info("agents", "out of token budget, exiting.")
logger.info("out of token budget, exiting.")
yield message
break
@ -631,16 +633,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments
yield message
else:
logcat.debug(
"agents",
f"completion message with EOM (iter: {n_iter}): {str(message)}",
)
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
else:
logcat.debug(
"agents",
f"completion message (iter: {n_iter}) from the model: {str(message)}",
)
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -983,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
logcat.info("agents", f"Downloading {url} -> {filepath}")
logger.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
@ -1023,7 +1019,7 @@ async def execute_tool_call_maybe(
else:
name = name.value
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
result = await tool_runtime_api.invoke_tool(
tool_name=name,
kwargs={
@ -1033,7 +1029,7 @@ async def execute_tool_call_maybe(
**toolgroup_args.get(group_name, {}),
},
)
logcat.debug("agents", f"tool call {name} completed with result: {result}")
logger.info(f"tool call {name} completed with result: {result}")
return result

View file

@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -54,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
@ -230,12 +233,15 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
return {
params = {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logger.debug(f"params to fireworks: {params}")
return params
async def embeddings(
self,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
@ -34,6 +34,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
@ -58,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import model_entries
log = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
@ -71,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncClient(host=self.url)
async def initialize(self) -> None:
log.info(f"checking connectivity to Ollama at `{self.url}`...")
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
try:
await self.client.ps()
except httpx.ConnectError as e:
@ -207,12 +208,15 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
return {
params = {
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
logger.debug(f"params to ollama: {params}")
return params
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
@ -287,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.model_type == ModelType.embedding:
log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
response = await self.client.list()
else:

View file

@ -31,6 +31,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -53,6 +54,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
@ -217,12 +220,13 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
return {
params = {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logger.debug(f"params to together: {params}")
async def embeddings(
self,

View file

@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -46,6 +47,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
@ -108,6 +111,7 @@ class LiteLLMOpenAIMixin(
)
params = await self._get_params(request)
logger.debug(f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
response = litellm.completion(**params)

View file

@ -8,14 +8,12 @@ import asyncio
import base64
import io
import json
import logging
import re
from typing import List, Optional, Tuple, Union
import httpx
from PIL import Image as PIL_Image
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
@ -34,6 +32,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
UserMessage,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
ModelFamily,
RawContent,
@ -58,7 +57,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
@ -464,7 +463,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
llama_model = resolve_model(model)
if llama_model is None:
logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
return ToolPromptFormat.json
if llama_model.model_family == ModelFamily.llama3_1 or (

View file

@ -162,5 +162,5 @@ module = ["yaml", "fire"]
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "llama_stack.distribution.resolver"
follow_imports = "normal" # This will force type checking on this module
module = ["llama_stack.distribution.resolver", "llama_stack.log"]
follow_imports = "normal" # This will force type checking on this module