From e128f2547a748fecba29ef33435ddef2e9328ef7 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 11 Dec 2024 08:44:20 -0800 Subject: [PATCH] add tracing back to the lib cli (#595) Adds back all the tracing logic removed from library client. also adds back the logging to agent_instance. --- llama_stack/distribution/library_client.py | 40 ++++++--- .../agents/meta_reference/agent_instance.py | 22 ++--- .../meta_reference/sqlite_span_processor.py | 85 +++---------------- .../utils/telemetry/trace_protocol.py | 46 ++++++---- 4 files changed, 76 insertions(+), 117 deletions(-) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 8766f7a72..ee483f2bc 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -24,6 +24,7 @@ from termcolor import cprint from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.datatypes import Api from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.stack import ( @@ -32,6 +33,12 @@ from llama_stack.distribution.stack import ( replace_env_vars, ) +from llama_stack.providers.utils.telemetry.tracing import ( + end_trace, + setup_logger, + start_trace, +) + T = TypeVar("T") @@ -240,6 +247,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return False + if Api.telemetry in self.impls: + setup_logger(self.impls[Api.telemetry]) + console = Console() console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") console.print(yaml.dump(self.config.model_dump(), indent=2)) @@ -276,21 +286,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): async def _call_non_streaming( self, path: str, body: dict = None, cast_to: Any = None ): - func = self.endpoint_impls.get(path) - if not func: - raise ValueError(f"No endpoint found for {path}") + await start_trace(path, {"__location__": "library_client"}) + try: + func = self.endpoint_impls.get(path) + if not func: + raise ValueError(f"No endpoint found for {path}") - body = self._convert_body(path, body) - return convert_pydantic_to_json_value(await func(**body), cast_to) + body = self._convert_body(path, body) + return convert_pydantic_to_json_value(await func(**body), cast_to) + finally: + await end_trace() async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None): - func = self.endpoint_impls.get(path) - if not func: - raise ValueError(f"No endpoint found for {path}") + await start_trace(path, {"__location__": "library_client"}) + try: + func = self.endpoint_impls.get(path) + if not func: + raise ValueError(f"No endpoint found for {path}") - body = self._convert_body(path, body) - async for chunk in await func(**body): - yield convert_pydantic_to_json_value(chunk, cast_to) + body = self._convert_body(path, body) + async for chunk in await func(**body): + yield convert_pydantic_to_json_value(chunk, cast_to) + finally: + await end_trace() def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: if not body: diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f08bdb032..b403b9203 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin): stream=request.stream, ): if isinstance(chunk, CompletionMessage): - # log.info( - # f"{chunk.role.capitalize()}: {chunk.content}", - # ) + log.info( + f"{chunk.role.capitalize()}: {chunk.content}", + ) output_message = chunk continue @@ -280,7 +280,6 @@ class ChatAgent(ShieldRunnerMixin): touchpoint: str, ) -> AsyncGenerator: with tracing.span("run_shields") as span: - span.set_attribute("turn_id", turn_id) span.set_attribute("input", [m.model_dump_json() for m in messages]) if len(shields) == 0: span.set_attribute("output", "no shields") @@ -405,11 +404,6 @@ class ChatAgent(ShieldRunnerMixin): n_iter = 0 while True: msg = input_messages[-1] - # if len(str(msg)) > 1000: - # msg_str = f"{str(msg)[:500]}......{str(msg)[-500:]}" - # else: - # msg_str = str(msg) - # log.info(f"{msg_str}") step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -514,12 +508,12 @@ class ChatAgent(ShieldRunnerMixin): ) if n_iter >= self.agent_config.max_infer_iters: - # log.info("Done with MAX iterations, exiting.") + log.info("Done with MAX iterations, exiting.") yield message break if stop_reason == StopReason.out_of_tokens: - # log.info("Out of token budget, exiting.") + log.info("Out of token budget, exiting.") yield message break @@ -533,10 +527,10 @@ class ChatAgent(ShieldRunnerMixin): message.content = [message.content] + attachments yield message else: - # log.info(f"Partial message: {str(message)}") + log.info(f"Partial message: {str(message)}") input_messages = input_messages + [message] else: - # log.info(f"{str(message)}") + log.info(f"{str(message)}") try: tool_call = message.tool_calls[0] @@ -800,7 +794,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa path = urlparse(uri).path basename = os.path.basename(path) filepath = f"{tempdir}/{make_random_string() + basename}" - # log.info(f"Downloading {url} -> {filepath}") + log.info(f"Downloading {url} -> {filepath}") async with httpx.AsyncClient() as client: r = await client.get(uri) diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py index 553dd5000..3455c2236 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -7,33 +7,24 @@ import json import os import sqlite3 -import threading -from datetime import datetime, timedelta -from typing import Dict +from datetime import datetime from opentelemetry.sdk.trace import SpanProcessor from opentelemetry.trace import Span class SQLiteSpanProcessor(SpanProcessor): - def __init__(self, conn_string, ttl_days=30): + def __init__(self, conn_string): """Initialize the SQLite span processor with a connection string.""" self.conn_string = conn_string - self.ttl_days = ttl_days - self.cleanup_task = None - self._thread_local = threading.local() - self._connections: Dict[int, sqlite3.Connection] = {} - self._lock = threading.Lock() + self.conn = None self.setup_database() def _get_connection(self) -> sqlite3.Connection: - """Get a thread-specific database connection.""" - thread_id = threading.get_ident() - with self._lock: - if thread_id not in self._connections: - conn = sqlite3.connect(self.conn_string) - self._connections[thread_id] = conn - return self._connections[thread_id] + """Get the database connection.""" + if self.conn is None: + self.conn = sqlite3.connect(self.conn_string, check_same_thread=False) + return self.conn def setup_database(self): """Create the necessary tables if they don't exist.""" @@ -94,60 +85,6 @@ class SQLiteSpanProcessor(SpanProcessor): conn.commit() cursor.close() - # Start periodic cleanup in a separate thread - self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True) - self.cleanup_task.start() - - def _cleanup_old_data(self): - """Delete records older than TTL.""" - try: - conn = self._get_connection() - cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat() - cursor = conn.cursor() - - # Delete old span events - cursor.execute( - """ - DELETE FROM span_events - WHERE span_id IN ( - SELECT span_id FROM spans - WHERE trace_id IN ( - SELECT trace_id FROM traces - WHERE created_at < ? - ) - ) - """, - (cutoff_date,), - ) - - # Delete old spans - cursor.execute( - """ - DELETE FROM spans - WHERE trace_id IN ( - SELECT trace_id FROM traces - WHERE created_at < ? - ) - """, - (cutoff_date,), - ) - - # Delete old traces - cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,)) - - conn.commit() - cursor.close() - except Exception as e: - print(f"Error during cleanup: {e}") - - def _periodic_cleanup(self): - """Run cleanup periodically.""" - import time - - while True: - time.sleep(3600) # Sleep for 1 hour - self._cleanup_old_data() - def on_start(self, span: Span, parent_context=None): """Called when a span starts.""" pass @@ -231,11 +168,9 @@ class SQLiteSpanProcessor(SpanProcessor): def shutdown(self): """Cleanup any resources.""" - with self._lock: - for conn in self._connections.values(): - if conn: - conn.close() - self._connections.clear() + if self.conn: + self.conn.close() + self.conn = None def force_flush(self, timeout_millis=30000): """Force export of spans.""" diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py index 3fcce08e9..938d333fa 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/providers/utils/telemetry/trace_protocol.py @@ -6,29 +6,31 @@ import asyncio import inspect -import json +from datetime import datetime from functools import wraps from typing import Any, AsyncGenerator, Callable, Type, TypeVar +from uuid import UUID from pydantic import BaseModel T = TypeVar("T") -def serialize_value(value: Any) -> str: - """Helper function to serialize values to string representation.""" - try: - if isinstance(value, BaseModel): - return value.model_dump_json() - elif isinstance(value, list) and value and isinstance(value[0], BaseModel): - return json.dumps([item.model_dump_json() for item in value]) - elif hasattr(value, "to_dict"): - return json.dumps(value.to_dict()) - elif isinstance(value, (dict, list, int, float, str, bool)): - return json.dumps(value) - else: - return str(value) - except Exception: +def serialize_value(value: Any) -> Any: + """Serialize a single value into JSON-compatible format.""" + if value is None: + return None + elif isinstance(value, (str, int, float, bool)): + return value + elif isinstance(value, BaseModel): + return value.model_dump() + elif isinstance(value, (list, tuple, set)): + return [serialize_value(item) for item in value] + elif isinstance(value, dict): + return {str(k): serialize_value(v) for k, v in value.items()} + elif isinstance(value, (datetime, UUID)): + return str(value) + else: return str(value) @@ -47,16 +49,26 @@ def trace_protocol(cls: Type[T]) -> Type[T]: def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple: class_name = self.__class__.__name__ method_name = method.__name__ - span_type = ( "async_generator" if is_async_gen else "async" if is_async else "sync" ) + sig = inspect.signature(method) + param_names = list(sig.parameters.keys())[1:] # Skip 'self' + combined_args = {} + for i, arg in enumerate(args): + param_name = ( + param_names[i] if i < len(param_names) else f"position_{i+1}" + ) + combined_args[param_name] = serialize_value(arg) + for k, v in kwargs.items(): + combined_args[str(k)] = serialize_value(v) + span_attributes = { "__autotraced__": True, "__class__": class_name, "__method__": method_name, "__type__": span_type, - "__args__": serialize_value(args), + "__args__": str(combined_args), } return class_name, method_name, span_attributes