mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat(logging): implement category-based logging (#1362)
# What does this PR do? This commit introduces a new logging system that allows loggers to be assigned a category while retaining the logger name based on the file name. The log format includes both the logger name and the category, producing output like: ``` INFO 2025-03-03 21:44:11,323 llama_stack.distribution.stack:103 [core]: Tool_groups: builtin::websearch served by tavily-search ``` Key features include: - Category-based logging: Loggers can be assigned a category (e.g., "core", "server") when programming. The logger can be loaded like this: `logger = get_logger(name=__name__, category="server")` - Environment variable control: Log levels can be configured per-category using the `LLAMA_STACK_LOGGING` environment variable. For example: `LLAMA_STACK_LOGGING="server=DEBUG;core=debug"` enables DEBUG level for the "server" and "core" categories. - `LLAMA_STACK_LOGGING="all=debug"` sets DEBUG level globally for all categories and third-party libraries. This provides fine-grained control over logging levels while maintaining a clean and informative log format. The formatter uses the rich library which provides nice colors better stack traces like so: ``` ERROR 2025-03-03 21:49:37,124 asyncio:1758 [uncategorized]: unhandled exception during asyncio.run() shutdown task: <Task finished name='Task-16' coro=<handle_signal.<locals>.shutdown() done, defined at /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:146> exception=UnboundLocalError("local variable 'loop' referenced before assignment")> ╭────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────╮ │ /Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py:178 in shutdown │ │ │ │ 175 │ │ except asyncio.CancelledError: │ │ 176 │ │ │ pass │ │ 177 │ │ finally: │ │ ❱ 178 │ │ │ loop.stop() │ │ 179 │ │ │ 180 │ loop = asyncio.get_running_loop() │ │ 181 │ loop.create_task(shutdown()) │ ╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ UnboundLocalError: local variable 'loop' referenced before assignment ``` Co-authored-by: Ashwin Bharambe <@ashwinb> Signed-off-by: Sébastien Han <seb@redhat.com> [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml INFO 2025-03-03 21:55:35,918 __main__:365 [server]: Using config file: llama_stack/templates/ollama/run.yaml INFO 2025-03-03 21:55:35,925 __main__:378 [server]: Run configuration: INFO 2025-03-03 21:55:35,928 __main__:380 [server]: apis: - agents ``` [//]: # (## Documentation) --------- Signed-off-by: Sébastien Han <seb@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
parent
bad12ee21f
commit
7cf1e24c4e
16 changed files with 296 additions and 431 deletions
|
@ -5,15 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
|
|
|
@ -7,7 +7,6 @@ import importlib
|
|||
import inspect
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
@ -35,6 +34,7 @@ from llama_stack.distribution.datatypes import (
|
|||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
BenchmarksProtocolPrivate,
|
||||
|
@ -50,6 +50,8 @@ from llama_stack.providers.datatypes import (
|
|||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class InvalidProviderError(Exception):
|
||||
pass
|
||||
|
@ -184,7 +186,7 @@ def validate_and_prepare_providers(
|
|||
specs = {}
|
||||
for provider in providers:
|
||||
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
continue
|
||||
|
||||
validate_provider(provider, api, provider_registry)
|
||||
|
@ -206,11 +208,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
|
|||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
logcat.error("core", p.deprecation_error)
|
||||
logger.error(p.deprecation_error)
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
elif p.deprecation_warning:
|
||||
logcat.warning(
|
||||
"core",
|
||||
logger.warning(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
|
||||
|
@ -244,9 +245,10 @@ def sort_providers_by_deps(
|
|||
)
|
||||
)
|
||||
|
||||
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logcat.debug("core", f" {api_str} => {provider.provider_id}")
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
logger.debug("")
|
||||
return sorted_providers
|
||||
|
||||
|
||||
|
@ -387,7 +389,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
|
@ -52,8 +51,11 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
"""Routes to an provider based on the vector db identifier"""
|
||||
|
@ -62,15 +64,15 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing VectorIORouter")
|
||||
logger.debug("Initializing VectorIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "VectorIORouter.initialize")
|
||||
logger.debug("VectorIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "VectorIORouter.shutdown")
|
||||
logger.debug("VectorIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_vector_db(
|
||||
|
@ -81,10 +83,7 @@ class VectorIORouter(VectorIO):
|
|||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
|
||||
)
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
|
@ -99,8 +98,7 @@ class VectorIORouter(VectorIO):
|
|||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
@ -111,7 +109,7 @@ class VectorIORouter(VectorIO):
|
|||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
|
||||
|
||||
|
@ -122,15 +120,15 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing InferenceRouter")
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "InferenceRouter.initialize")
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "InferenceRouter.shutdown")
|
||||
logger.debug("InferenceRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_model(
|
||||
|
@ -141,8 +139,7 @@ class InferenceRouter(Inference):
|
|||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
@ -160,8 +157,7 @@ class InferenceRouter(Inference):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
|
@ -226,8 +222,7 @@ class InferenceRouter(Inference):
|
|||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||
)
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
|
@ -257,7 +252,7 @@ class InferenceRouter(Inference):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
@ -277,15 +272,15 @@ class SafetyRouter(Safety):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing SafetyRouter")
|
||||
logger.debug("Initializing SafetyRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "SafetyRouter.initialize")
|
||||
logger.debug("SafetyRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "SafetyRouter.shutdown")
|
||||
logger.debug("SafetyRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_shield(
|
||||
|
@ -295,7 +290,7 @@ class SafetyRouter(Safety):
|
|||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
|
@ -304,7 +299,7 @@ class SafetyRouter(Safety):
|
|||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
|
@ -317,15 +312,15 @@ class DatasetIORouter(DatasetIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing DatasetIORouter")
|
||||
logger.debug("Initializing DatasetIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "DatasetIORouter.initialize")
|
||||
logger.debug("DatasetIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "DatasetIORouter.shutdown")
|
||||
logger.debug("DatasetIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def get_rows_paginated(
|
||||
|
@ -335,8 +330,7 @@ class DatasetIORouter(DatasetIO):
|
|||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult:
|
||||
logcat.debug(
|
||||
"core",
|
||||
logger.debug(
|
||||
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||
|
@ -347,7 +341,7 @@ class DatasetIORouter(DatasetIO):
|
|||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
|
@ -359,15 +353,15 @@ class ScoringRouter(Scoring):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ScoringRouter")
|
||||
logger.debug("Initializing ScoringRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "ScoringRouter.initialize")
|
||||
logger.debug("ScoringRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "ScoringRouter.shutdown")
|
||||
logger.debug("ScoringRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def score_batch(
|
||||
|
@ -376,7 +370,7 @@ class ScoringRouter(Scoring):
|
|||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
|
||||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
|
@ -397,10 +391,7 @@ class ScoringRouter(Scoring):
|
|||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
|
||||
)
|
||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
|
@ -418,15 +409,15 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing EvalRouter")
|
||||
logger.debug("Initializing EvalRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "EvalRouter.initialize")
|
||||
logger.debug("EvalRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "EvalRouter.shutdown")
|
||||
logger.debug("EvalRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def run_eval(
|
||||
|
@ -434,7 +425,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
|
||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
benchmark_config=benchmark_config,
|
||||
|
@ -447,7 +438,7 @@ class EvalRouter(Eval):
|
|||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
|
@ -460,7 +451,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
|
@ -468,7 +459,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
@ -479,7 +470,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
|
@ -492,7 +483,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
|
@ -501,7 +492,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
) -> RAGQueryResult:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
|
@ -512,9 +503,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
|
||||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
|
@ -524,7 +514,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logcat.debug("core", "Initializing ToolRuntimeRouter")
|
||||
logger.debug("Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
|
@ -533,15 +523,15 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logcat.debug("core", "ToolRuntimeRouter.initialize")
|
||||
logger.debug("ToolRuntimeRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logcat.debug("core", "ToolRuntimeRouter.shutdown")
|
||||
logger.debug("ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
|
@ -550,5 +540,5 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
@ -9,7 +9,6 @@ import asyncio
|
|||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
@ -28,7 +27,6 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
|
@ -39,6 +37,7 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
|
@ -54,8 +53,7 @@ from .endpoints import get_all_api_endpoints
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
||||
logcat.init()
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
@ -142,23 +140,23 @@ def handle_signal(app, signum, _) -> None:
|
|||
not block the current execution.
|
||||
"""
|
||||
signame = signal.Signals(signum).name
|
||||
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||
|
||||
async def shutdown():
|
||||
try:
|
||||
# Gracefully shut down implementations
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logcat.info("server", f"Shutting down {impl_name}")
|
||||
logger.info("Shutting down %s", impl_name)
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logcat.warning("server", f"No shutdown method for {impl_name}")
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
logcat.exception("server", f"Shutdown timeout for {impl_name}")
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except Exception as e:
|
||||
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
||||
# Gather all running tasks
|
||||
loop = asyncio.get_running_loop()
|
||||
|
@ -172,7 +170,7 @@ def handle_signal(app, signum, _) -> None:
|
|||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logcat.exception("server", "Timeout while waiting for tasks to finish")
|
||||
logger.exception("Timeout while waiting for tasks to finish")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
|
@ -184,9 +182,9 @@ def handle_signal(app, signum, _) -> None:
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logcat.info("server", "Starting up")
|
||||
logger.info("Starting up")
|
||||
yield
|
||||
logcat.info("server", "Shutting down")
|
||||
logger.info("Shutting down")
|
||||
for impl in app.__llama_stack_impls__.values():
|
||||
await impl.shutdown()
|
||||
|
||||
|
@ -209,11 +207,11 @@ async def sse_generator(event_gen):
|
|||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logcat.info("server", "Generator cancelled")
|
||||
logger.info("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logcat.exception("server", f"Error in sse_generator: {e}")
|
||||
logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
|
||||
logger.exception(f"Error in sse_generator: {e}")
|
||||
logger.exception(f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -235,7 +233,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
except Exception as e:
|
||||
logcat.exception("server", f"Error in {func.__name__}")
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
@ -314,8 +312,6 @@ class ClientVersionMiddleware:
|
|||
|
||||
|
||||
def main():
|
||||
logcat.init()
|
||||
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
|
@ -355,10 +351,10 @@ def main():
|
|||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logcat.error("server", f"Error: {str(e)}")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.yaml_config:
|
||||
|
@ -366,12 +362,12 @@ def main():
|
|||
config_file = Path(args.yaml_config)
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
logcat.info("server", f"Using config file: {config_file}")
|
||||
logger.info(f"Using config file: {config_file}")
|
||||
elif args.template:
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
logcat.info("server", f"Using template {args.template} config file: {config_file}")
|
||||
logger.info(f"Using template {args.template} config file: {config_file}")
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
|
@ -379,10 +375,9 @@ def main():
|
|||
config = replace_env_vars(yaml.safe_load(fp))
|
||||
config = StackRunConfig(**config)
|
||||
|
||||
logcat.info("server", "Run configuration:")
|
||||
logger.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
|
||||
logcat.info("server", log_line)
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
@ -392,7 +387,7 @@ def main():
|
|||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError as e:
|
||||
logcat.error("server", f"Error: {str(e)}")
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
|
@ -437,7 +432,7 @@ def main():
|
|||
)
|
||||
)
|
||||
|
||||
logcat.debug("server", f"serving APIs: {apis_to_serve}")
|
||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
@ -464,10 +459,10 @@ def main():
|
|||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
||||
logcat.info("server", f"Listening on {listen_host}:{port}")
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
|
|
|
@ -11,9 +11,7 @@ import tempfile
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
|
@ -39,8 +37,11 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
VectorDBs,
|
||||
|
@ -101,9 +102,8 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
|||
objects_to_process = response.data if hasattr(response, "data") else response
|
||||
|
||||
for obj in objects_to_process:
|
||||
logcat.debug(
|
||||
"core",
|
||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||
logger.debug(
|
||||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -100,12 +100,15 @@ esac
|
|||
|
||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
||||
set -x
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
||||
--yaml-config "$yaml_config" \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
set -x
|
||||
|
||||
# Check if container command is available
|
||||
if ! is_command_available $CONTAINER_BINARY; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
|
@ -141,8 +144,6 @@ elif [[ "$env_type" == "container" ]]; then
|
|||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||
fi
|
||||
|
||||
set -x
|
||||
|
||||
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||
-p $port:$port \
|
||||
$env_vars \
|
||||
|
|
169
llama_stack/log.py
Normal file
169
llama_stack/log.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
from typing import Dict
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
# Default log level
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
||||
# Predefined categories
|
||||
CATEGORIES = [
|
||||
"core",
|
||||
"server",
|
||||
"router",
|
||||
"inference",
|
||||
"agents",
|
||||
"safety",
|
||||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
]
|
||||
|
||||
# Initialize category levels with default level
|
||||
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
|
||||
|
||||
|
||||
def parse_environment_config(env_config: str) -> Dict[str, int]:
|
||||
"""
|
||||
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
|
||||
|
||||
Parameters:
|
||||
env_config (str): The value of the LLAMA_STACK_LOGGING environment variable.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
category_levels = {}
|
||||
for pair in env_config.split(";"):
|
||||
if not pair.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().upper() # Convert to uppercase for logging._nameToLevel
|
||||
|
||||
level_value = logging._nameToLevel.get(level)
|
||||
if level_value is None:
|
||||
logging.warning(
|
||||
f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'."
|
||||
)
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
# Apply the log level to all categories and the root logger
|
||||
for cat in CATEGORIES:
|
||||
category_levels[cat] = level_value
|
||||
# Set the root logger's level to the specified level
|
||||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid logging configuration: '{pair}'. Expected format: 'category=level'.")
|
||||
|
||||
return category_levels
|
||||
|
||||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=120)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def setup_logging(category_levels: Dict[str, int]) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels.
|
||||
|
||||
Parameters:
|
||||
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
|
||||
"""
|
||||
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
|
||||
|
||||
class CategoryFilter(logging.Filter):
|
||||
"""Ensure category is always present in log records."""
|
||||
|
||||
def filter(self, record):
|
||||
if not hasattr(record, "category"):
|
||||
record.category = "uncategorized" # Default to 'uncategorized' if no category found
|
||||
return True
|
||||
|
||||
# Determine the root logger's level (default to WARNING if not specified)
|
||||
root_level = category_levels.get("root", logging.WARNING)
|
||||
|
||||
logging_config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"rich": {
|
||||
"()": logging.Formatter,
|
||||
"format": log_format,
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"()": CustomRichHandler, # Use our custom handler class
|
||||
"formatter": "rich",
|
||||
"rich_tracebacks": True,
|
||||
"show_time": False,
|
||||
"show_path": False,
|
||||
"markup": True,
|
||||
"filters": ["category_filter"],
|
||||
}
|
||||
},
|
||||
"filters": {
|
||||
"category_filter": {
|
||||
"()": CategoryFilter,
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
category: {
|
||||
"handlers": ["console"],
|
||||
"level": category_levels.get(category, DEFAULT_LOG_LEVEL),
|
||||
"propagate": False, # Disable propagation to root logger
|
||||
}
|
||||
for category in CATEGORIES
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"level": root_level, # Set root logger's level dynamically
|
||||
},
|
||||
}
|
||||
dictConfig(logging_config)
|
||||
|
||||
|
||||
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
|
||||
"""
|
||||
Returns a logger with the specified name and category.
|
||||
If no category is provided, defaults to 'uncategorized'.
|
||||
|
||||
Parameters:
|
||||
name (str): The name of the logger (e.g., module or filename).
|
||||
category (str): The category of the logger (default 'uncategorized').
|
||||
|
||||
Returns:
|
||||
logging.LoggerAdapter: Configured logger with category support.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_category_levels.get(category, DEFAULT_LOG_LEVEL))
|
||||
return logging.LoggerAdapter(logger, {"category": category})
|
||||
|
||||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
print(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}")
|
||||
_category_levels.update(parse_environment_config(env_config))
|
||||
|
||||
setup_logging(_category_levels)
|
|
@ -1,204 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Category-based logging utility for llama-stack.
|
||||
|
||||
This module provides a wrapper over the standard Python logging module that supports
|
||||
categorized logging with environment variable control.
|
||||
|
||||
Usage:
|
||||
from llama_stack import logcat
|
||||
logcat.info("server", "Starting up...")
|
||||
logcat.debug("inference", "Processing request...")
|
||||
|
||||
Environment variable:
|
||||
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
|
||||
Example: "server=debug;inference=warning"
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
COLORS = {
|
||||
"RESET": "\033[0m",
|
||||
"DEBUG": "\033[36m", # Cyan
|
||||
"INFO": "\033[32m", # Green
|
||||
"WARNING": "\033[33m", # Yellow
|
||||
"ERROR": "\033[31m", # Red
|
||||
"CRITICAL": "\033[35m", # Magenta
|
||||
"DIM": "\033[2m", # Dimmed text
|
||||
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
|
||||
}
|
||||
|
||||
# Static list of valid categories representing various parts of the Llama Stack
|
||||
# server codebase
|
||||
CATEGORIES = [
|
||||
"core",
|
||||
"server",
|
||||
"router",
|
||||
"inference",
|
||||
"agents",
|
||||
"safety",
|
||||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
]
|
||||
|
||||
_logger = logging.getLogger("llama_stack")
|
||||
_logger.propagate = False
|
||||
|
||||
_default_level = logging.INFO
|
||||
|
||||
# Category-level mapping (can be modified by environment variables)
|
||||
_category_levels: Dict[str, int] = {}
|
||||
|
||||
|
||||
class TerminalStreamHandler(logging.StreamHandler):
|
||||
def __init__(self, stream=None):
|
||||
super().__init__(stream)
|
||||
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
|
||||
|
||||
def format(self, record):
|
||||
record.is_tty = self.is_tty
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Custom formatter with colors and fixed-width level names"""
|
||||
|
||||
def format(self, record):
|
||||
levelname = record.levelname
|
||||
# Use only time with milliseconds, not date
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
|
||||
|
||||
file_info = f"{record.filename}:{record.lineno}"
|
||||
|
||||
# Get category from extra if available
|
||||
category = getattr(record, "category", None)
|
||||
msg = record.getMessage()
|
||||
|
||||
if getattr(record, "is_tty", False):
|
||||
color = COLORS.get(levelname, COLORS["RESET"])
|
||||
if category:
|
||||
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
|
||||
formatted_msg = (
|
||||
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
|
||||
f"{file_info:<20} {category_formatted}{msg}"
|
||||
)
|
||||
else:
|
||||
formatted_msg = (
|
||||
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
|
||||
f"{file_info:<20} {msg}"
|
||||
)
|
||||
else:
|
||||
if category:
|
||||
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
|
||||
else:
|
||||
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
|
||||
|
||||
return formatted_msg
|
||||
|
||||
|
||||
def init(default_level: int = logging.INFO) -> None:
|
||||
global _default_level, _category_levels, _logger
|
||||
|
||||
_default_level = default_level
|
||||
|
||||
_logger.setLevel(logging.DEBUG)
|
||||
_logger.handlers = [] # Clear existing handlers
|
||||
|
||||
# Add our custom handler with the colored formatter
|
||||
handler = TerminalStreamHandler()
|
||||
formatter = ColoredFormatter()
|
||||
handler.setFormatter(formatter)
|
||||
_logger.addHandler(handler)
|
||||
|
||||
for category in CATEGORIES:
|
||||
_category_levels[category] = default_level
|
||||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
for pair in env_config.split(";"):
|
||||
if not pair.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
category, level = pair.split("=", 1)
|
||||
category = category.strip().lower()
|
||||
level = level.strip().lower()
|
||||
|
||||
level_value = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"warn": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
"critical": logging.CRITICAL,
|
||||
}.get(level)
|
||||
|
||||
if level_value is None:
|
||||
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
|
||||
continue
|
||||
|
||||
if category == "all":
|
||||
for cat in CATEGORIES:
|
||||
_category_levels[cat] = level_value
|
||||
else:
|
||||
if category in CATEGORIES:
|
||||
_category_levels[category] = level_value
|
||||
else:
|
||||
_logger.warning(f"Unknown logging category: {category}")
|
||||
|
||||
except ValueError:
|
||||
_logger.warning(f"Invalid logging configuration: {pair}")
|
||||
|
||||
|
||||
def _should_log(level: int, category: str) -> bool:
|
||||
category = category.lower()
|
||||
if category not in _category_levels:
|
||||
return False
|
||||
category_level = _category_levels[category]
|
||||
return level >= category_level
|
||||
|
||||
|
||||
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
|
||||
if _should_log(level, category):
|
||||
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
|
||||
|
||||
|
||||
def debug(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def info(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.INFO, "info", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warn(category: str, msg: str, *args, **kwargs) -> None:
|
||||
warning(category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def critical(category: str, msg: str, *args, **kwargs) -> None:
|
||||
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def exception(category: str, msg: str, *args, **kwargs) -> None:
|
||||
if _should_log(logging.ERROR, category):
|
||||
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||
_logger.exception(msg, *args, stacklevel=2, **kwargs)
|
|
@ -17,7 +17,6 @@ from urllib.parse import urlparse
|
|||
|
||||
import httpx
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentToolGroup,
|
||||
|
@ -67,6 +66,7 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolCall,
|
||||
|
@ -88,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search"
|
|||
WEB_SEARCH_TOOL = "web_search"
|
||||
RAG_TOOL_GROUP = "builtin::rag"
|
||||
|
||||
logger = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
def __init__(
|
||||
|
@ -609,7 +611,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
|
||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
message.stop_reason = StopReason.end_of_turn
|
||||
|
@ -617,7 +619,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
logcat.info("agents", "out of token budget, exiting.")
|
||||
logger.info("out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
@ -631,16 +633,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message with EOM (iter: {n_iter}): {str(message)}",
|
||||
)
|
||||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
logcat.debug(
|
||||
"agents",
|
||||
f"completion message (iter: {n_iter}) from the model: {str(message)}",
|
||||
)
|
||||
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
|
||||
# 1. Start the tool execution step and progress
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -983,7 +979,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
path = urlparse(uri).path
|
||||
basename = os.path.basename(path)
|
||||
filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
logcat.info("agents", f"Downloading {url} -> {filepath}")
|
||||
logger.info(f"Downloading {url} -> {filepath}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
|
@ -1023,7 +1019,7 @@ async def execute_tool_call_maybe(
|
|||
else:
|
||||
name = name.value
|
||||
|
||||
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
logger.info(f"executing tool call: {name} with args: {tool_call.arguments}")
|
||||
result = await tool_runtime_api.invoke_tool(
|
||||
tool_name=name,
|
||||
kwargs={
|
||||
|
@ -1033,7 +1029,7 @@ async def execute_tool_call_maybe(
|
|||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logcat.debug("agents", f"tool call {name} completed with result: {result}")
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
|
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -55,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
|
@ -237,7 +239,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||
}
|
||||
logcat.debug("inference", f"params to fireworks: {params}")
|
||||
logger.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@ -4,13 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
|
@ -35,6 +34,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
|
@ -59,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .models import model_entries
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
|
@ -72,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return AsyncClient(host=self.url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError as e:
|
||||
|
@ -214,7 +214,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
"options": sampling_options,
|
||||
"stream": request.stream,
|
||||
}
|
||||
logcat.debug("inference", f"params to ollama: {params}")
|
||||
logger.debug(f"params to ollama: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
|
@ -290,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
if model.model_type == ModelType.embedding:
|
||||
log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
response = await self.client.list()
|
||||
else:
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from together import Together
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
|
@ -32,6 +31,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -54,6 +54,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
|
@ -224,8 +226,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
|
||||
}
|
||||
logcat.debug("inference", f"params to together: {params}")
|
||||
return params
|
||||
logger.debug(f"params to together: {params}")
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
|
|||
|
||||
import litellm
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
|
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.models.models import Model
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
@ -47,6 +47,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class LiteLLMOpenAIMixin(
|
||||
ModelRegistryHelper,
|
||||
|
@ -109,8 +111,7 @@ class LiteLLMOpenAIMixin(
|
|||
)
|
||||
|
||||
params = await self._get_params(request)
|
||||
logcat.debug("inference", f"params to litellm (openai compat): {params}")
|
||||
|
||||
logger.debug(f"params to litellm (openai compat): {params}")
|
||||
# unfortunately, we need to use synchronous litellm.completion here because litellm
|
||||
# caches various httpx.client objects in a non-eventloop aware manner
|
||||
response = litellm.completion(**params)
|
||||
|
|
|
@ -8,14 +8,12 @@ import asyncio
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack import logcat
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
|
@ -34,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
ModelFamily,
|
||||
RawContent,
|
||||
|
@ -58,7 +57,7 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
|||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
||||
|
@ -464,7 +463,7 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
|
|||
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
|
||||
llama_model = resolve_model(model)
|
||||
if llama_model is None:
|
||||
logcat.warning("inference", f"Could not resolve model {model}, defaulting to json tool prompt format")
|
||||
log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
|
||||
return ToolPromptFormat.json
|
||||
|
||||
if llama_model.model_family == ModelFamily.llama3_1 or (
|
||||
|
|
|
@ -151,7 +151,6 @@ exclude = [
|
|||
"llama_stack/distribution",
|
||||
"llama_stack/apis",
|
||||
"llama_stack/cli",
|
||||
"llama_stack/logcat.py",
|
||||
"llama_stack/models",
|
||||
"llama_stack/strong_typing",
|
||||
"llama_stack/templates",
|
||||
|
@ -163,5 +162,5 @@ module = ["yaml", "fire"]
|
|||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "llama_stack.distribution.resolver"
|
||||
follow_imports = "normal" # This will force type checking on this module
|
||||
module = ["llama_stack.distribution.resolver", "llama_stack.log"]
|
||||
follow_imports = "normal" # This will force type checking on this module
|
||||
|
|
|
@ -1,88 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from llama_stack import logcat
|
||||
|
||||
|
||||
class TestLogcat(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.original_env = os.environ.get("LLAMA_STACK_LOGGING")
|
||||
|
||||
self.log_output = io.StringIO()
|
||||
self._init_logcat()
|
||||
|
||||
def tearDown(self):
|
||||
if self.original_env is not None:
|
||||
os.environ["LLAMA_STACK_LOGGING"] = self.original_env
|
||||
else:
|
||||
os.environ.pop("LLAMA_STACK_LOGGING", None)
|
||||
|
||||
def _init_logcat(self):
|
||||
logcat.init(default_level=logging.DEBUG)
|
||||
self.handler = logging.StreamHandler(self.log_output)
|
||||
self.handler.setFormatter(logging.Formatter("[%(category)s] %(message)s"))
|
||||
logcat._logger.handlers.clear()
|
||||
logcat._logger.addHandler(self.handler)
|
||||
|
||||
def test_basic_logging(self):
|
||||
logcat.info("server", "Info message")
|
||||
logcat.warning("server", "Warning message")
|
||||
logcat.error("server", "Error message")
|
||||
|
||||
output = self.log_output.getvalue()
|
||||
self.assertIn("[server] Info message", output)
|
||||
self.assertIn("[server] Warning message", output)
|
||||
self.assertIn("[server] Error message", output)
|
||||
|
||||
def test_different_categories(self):
|
||||
# Log messages with different categories
|
||||
logcat.info("server", "Server message")
|
||||
logcat.info("inference", "Inference message")
|
||||
logcat.info("router", "Router message")
|
||||
|
||||
output = self.log_output.getvalue()
|
||||
self.assertIn("[server] Server message", output)
|
||||
self.assertIn("[inference] Inference message", output)
|
||||
self.assertIn("[router] Router message", output)
|
||||
|
||||
def test_env_var_control(self):
|
||||
os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"
|
||||
self._init_logcat()
|
||||
|
||||
# These should be visible based on the environment settings
|
||||
logcat.debug("server", "Server debug message")
|
||||
logcat.info("server", "Server info message")
|
||||
logcat.warning("inference", "Inference warning message")
|
||||
logcat.error("inference", "Inference error message")
|
||||
|
||||
# These should be filtered out based on the environment settings
|
||||
logcat.debug("inference", "Inference debug message")
|
||||
logcat.info("inference", "Inference info message")
|
||||
|
||||
output = self.log_output.getvalue()
|
||||
self.assertIn("[server] Server debug message", output)
|
||||
self.assertIn("[server] Server info message", output)
|
||||
self.assertIn("[inference] Inference warning message", output)
|
||||
self.assertIn("[inference] Inference error message", output)
|
||||
|
||||
self.assertNotIn("[inference] Inference debug message", output)
|
||||
self.assertNotIn("[inference] Inference info message", output)
|
||||
|
||||
def test_invalid_category(self):
|
||||
logcat.info("nonexistent", "This message should not be logged")
|
||||
|
||||
# Check that the message was not logged
|
||||
output = self.log_output.getvalue()
|
||||
self.assertNotIn("[nonexistent] This message should not be logged", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue