From 84904914c2fcf24c3013abeebc77ff264ae3fb4b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 10 Dec 2024 11:31:41 -0800 Subject: [PATCH] add tracing back to the lib cli --- llama_stack/distribution/library_client.py | 40 +++++++++---- .../agents/meta_reference/agent_instance.py | 39 ++++++------- .../meta_reference/sqlite_span_processor.py | 26 +++++++-- .../utils/telemetry/trace_protocol.py | 58 ++++++++++++------- .../providers/utils/telemetry/tracing.py | 4 -- 5 files changed, 106 insertions(+), 61 deletions(-) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 8766f7a72..ee483f2bc 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -24,6 +24,7 @@ from termcolor import cprint from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.datatypes import Api from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.stack import ( @@ -32,6 +33,12 @@ from llama_stack.distribution.stack import ( replace_env_vars, ) +from llama_stack.providers.utils.telemetry.tracing import ( + end_trace, + setup_logger, + start_trace, +) + T = TypeVar("T") @@ -240,6 +247,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return False + if Api.telemetry in self.impls: + setup_logger(self.impls[Api.telemetry]) + console = Console() console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") console.print(yaml.dump(self.config.model_dump(), indent=2)) @@ -276,21 +286,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): async def _call_non_streaming( self, path: str, body: dict = None, cast_to: Any = None ): - func = self.endpoint_impls.get(path) - if not func: - raise ValueError(f"No endpoint found for {path}") + await start_trace(path, {"__location__": "library_client"}) + try: + func = self.endpoint_impls.get(path) + if not func: + raise ValueError(f"No endpoint found for {path}") - body = self._convert_body(path, body) - return convert_pydantic_to_json_value(await func(**body), cast_to) + body = self._convert_body(path, body) + return convert_pydantic_to_json_value(await func(**body), cast_to) + finally: + await end_trace() async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None): - func = self.endpoint_impls.get(path) - if not func: - raise ValueError(f"No endpoint found for {path}") + await start_trace(path, {"__location__": "library_client"}) + try: + func = self.endpoint_impls.get(path) + if not func: + raise ValueError(f"No endpoint found for {path}") - body = self._convert_body(path, body) - async for chunk in await func(**body): - yield convert_pydantic_to_json_value(chunk, cast_to) + body = self._convert_body(path, body) + async for chunk in await func(**body): + yield convert_pydantic_to_json_value(chunk, cast_to) + finally: + await end_trace() def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: if not body: diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f08bdb032..df2b55b92 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -142,7 +142,7 @@ class ChatAgent(ShieldRunnerMixin): async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: - with tracing.span("create_and_execute_turn") as span: + with tracing.SpanContextManager("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) span.set_attribute("agent_id", self.agent_id) span.set_attribute("request", request.model_dump_json()) @@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin): stream=request.stream, ): if isinstance(chunk, CompletionMessage): - # log.info( - # f"{chunk.role.capitalize()}: {chunk.content}", - # ) + log.info( + f"{chunk.role.capitalize()}: {chunk.content}", + ) output_message = chunk continue @@ -279,8 +279,7 @@ class ChatAgent(ShieldRunnerMixin): shields: List[str], touchpoint: str, ) -> AsyncGenerator: - with tracing.span("run_shields") as span: - span.set_attribute("turn_id", turn_id) + with tracing.SpanContextManager("run_shields") as span: span.set_attribute("input", [m.model_dump_json() for m in messages]) if len(shields) == 0: span.set_attribute("output", "no shields") @@ -360,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin): # TODO: find older context from the session and either replace it # or append with a sliding window. this is really a very simplistic implementation - with tracing.span("retrieve_rag_context") as span: + with tracing.SpanContextManager("retrieve_rag_context") as span: rag_context, bank_ids = await self._retrieve_context( session_id, input_messages, attachments ) @@ -405,11 +404,11 @@ class ChatAgent(ShieldRunnerMixin): n_iter = 0 while True: msg = input_messages[-1] - # if len(str(msg)) > 1000: - # msg_str = f"{str(msg)[:500]}......{str(msg)[-500:]}" - # else: - # msg_str = str(msg) - # log.info(f"{msg_str}") + if len(str(msg)) > 1000: + msg_str = f"{str(msg)[:500]}......{str(msg)[-500:]}" + else: + msg_str = str(msg) + log.info(f"{msg_str}") step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -425,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin): content = "" stop_reason = None - with tracing.span("inference") as span: + with tracing.SpanContextManager("inference") as span: async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, @@ -514,12 +513,12 @@ class ChatAgent(ShieldRunnerMixin): ) if n_iter >= self.agent_config.max_infer_iters: - # log.info("Done with MAX iterations, exiting.") + log.info("Done with MAX iterations, exiting.") yield message break if stop_reason == StopReason.out_of_tokens: - # log.info("Out of token budget, exiting.") + log.info("Out of token budget, exiting.") yield message break @@ -533,10 +532,10 @@ class ChatAgent(ShieldRunnerMixin): message.content = [message.content] + attachments yield message else: - # log.info(f"Partial message: {str(message)}") + log.info(f"Partial message: {str(message)}") input_messages = input_messages + [message] else: - # log.info(f"{str(message)}") + log.info(f"{str(message)}") try: tool_call = message.tool_calls[0] @@ -564,7 +563,7 @@ class ChatAgent(ShieldRunnerMixin): ) ) - with tracing.span( + with tracing.SpanContextManager( "tool_execution", { "tool_name": tool_call.tool_name, @@ -713,7 +712,7 @@ class ChatAgent(ShieldRunnerMixin): ) for a in attachments ] - with tracing.span("insert_documents"): + with tracing.SpanContextManager("insert_documents"): await self.memory_api.insert_documents(bank_id, documents) else: session_info = await self.storage.get_session_info(session_id) @@ -800,7 +799,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa path = urlparse(uri).path basename = os.path.basename(path) filepath = f"{tempdir}/{make_random_string() + basename}" - # log.info(f"Downloading {url} -> {filepath}") + log.info(f"Downloading {url} -> {filepath}") async with httpx.AsyncClient() as client: r = await client.get(uri) diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py index 553dd5000..f8fdbc12f 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -20,6 +20,7 @@ class SQLiteSpanProcessor(SpanProcessor): """Initialize the SQLite span processor with a connection string.""" self.conn_string = conn_string self.ttl_days = ttl_days + self._shutdown_event = threading.Event() self.cleanup_task = None self._thread_local = threading.local() self._connections: Dict[int, sqlite3.Connection] = {} @@ -144,9 +145,10 @@ class SQLiteSpanProcessor(SpanProcessor): """Run cleanup periodically.""" import time - while True: + while not self._shutdown_event.is_set(): time.sleep(3600) # Sleep for 1 hour - self._cleanup_old_data() + if not self._shutdown_event.is_set(): + self._cleanup_old_data() def on_start(self, span: Span, parent_context=None): """Called when a span starts.""" @@ -231,11 +233,23 @@ class SQLiteSpanProcessor(SpanProcessor): def shutdown(self): """Cleanup any resources.""" + self._shutdown_event.set() + + # Wait for cleanup thread to finish if it exists + if self.cleanup_task and self.cleanup_task.is_alive(): + self.cleanup_task.join(timeout=5.0) + current_thread_id = threading.get_ident() + with self._lock: - for conn in self._connections.values(): - if conn: - conn.close() - self._connections.clear() + # Close all connections from the current thread + for thread_id, conn in list(self._connections.items()): + if thread_id == current_thread_id: + try: + if conn: + conn.close() + del self._connections[thread_id] + except sqlite3.Error: + pass # Ignore errors during shutdown def force_flush(self, timeout_millis=30000): """Force export of spans.""" diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py index 3fcce08e9..57f58d50f 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/providers/utils/telemetry/trace_protocol.py @@ -6,29 +6,31 @@ import asyncio import inspect -import json +from datetime import datetime from functools import wraps from typing import Any, AsyncGenerator, Callable, Type, TypeVar +from uuid import UUID from pydantic import BaseModel T = TypeVar("T") -def serialize_value(value: Any) -> str: - """Helper function to serialize values to string representation.""" - try: - if isinstance(value, BaseModel): - return value.model_dump_json() - elif isinstance(value, list) and value and isinstance(value[0], BaseModel): - return json.dumps([item.model_dump_json() for item in value]) - elif hasattr(value, "to_dict"): - return json.dumps(value.to_dict()) - elif isinstance(value, (dict, list, int, float, str, bool)): - return json.dumps(value) - else: - return str(value) - except Exception: +def serialize_value(value: Any) -> Any: + """Serialize a single value into JSON-compatible format.""" + if value is None: + return None + elif isinstance(value, (str, int, float, bool)): + return value + elif isinstance(value, BaseModel): + return value.model_dump() + elif isinstance(value, (list, tuple, set)): + return [serialize_value(item) for item in value] + elif isinstance(value, dict): + return {str(k): serialize_value(v) for k, v in value.items()} + elif isinstance(value, (datetime, UUID)): + return str(value) + else: return str(value) @@ -47,16 +49,26 @@ def trace_protocol(cls: Type[T]) -> Type[T]: def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple: class_name = self.__class__.__name__ method_name = method.__name__ - span_type = ( "async_generator" if is_async_gen else "async" if is_async else "sync" ) + sig = inspect.signature(method) + param_names = list(sig.parameters.keys())[1:] # Skip 'self' + combined_args = {} + for i, arg in enumerate(args): + param_name = ( + param_names[i] if i < len(param_names) else f"position_{i+1}" + ) + combined_args[param_name] = serialize_value(arg) + for k, v in kwargs.items(): + combined_args[str(k)] = serialize_value(v) + span_attributes = { "__autotraced__": True, "__class__": class_name, "__method__": method_name, "__type__": span_type, - "__args__": serialize_value(args), + "__args__": str(combined_args), } return class_name, method_name, span_attributes @@ -69,7 +81,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]: self, *args, **kwargs ) - with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + with tracing.SpanContextManager( + f"{class_name}.{method_name}", span_attributes + ) as span: try: count = 0 async for item in method(self, *args, **kwargs): @@ -84,7 +98,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]: self, *args, **kwargs ) - with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + with tracing.SpanContextManager( + f"{class_name}.{method_name}", span_attributes + ) as span: try: result = await method(self, *args, **kwargs) span.set_attribute("output", serialize_value(result)) @@ -99,7 +115,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]: self, *args, **kwargs ) - with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: + with tracing.SpanContextManager( + f"{class_name}.{method_name}", span_attributes + ) as span: try: result = method(self, *args, **kwargs) span.set_attribute("output", serialize_value(result)) diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 54558afdc..326a7c023 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -259,10 +259,6 @@ class SpanContextManager: return wrapper -def span(name: str, attributes: Dict[str, Any] = None): - return SpanContextManager(name, attributes) - - def get_current_span() -> Optional[Span]: global CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT