add tracing back to the lib cli (#595)

Adds back all the tracing logic that was removed from the library client. Also adds
back the logging to agent_instance.
This commit is contained in:
Dinesh Yeduguru 2024-12-11 08:44:20 -08:00 committed by GitHub
parent 1c03ba239e
commit e128f2547a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 76 additions and 117 deletions

View file

@ -24,6 +24,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
@ -32,6 +33,12 @@ from llama_stack.distribution.stack import (
replace_env_vars, replace_env_vars,
) )
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
start_trace,
)
T = TypeVar("T") T = TypeVar("T")
@ -240,6 +247,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return False return False
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
console = Console() console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
console.print(yaml.dump(self.config.model_dump(), indent=2)) console.print(yaml.dump(self.config.model_dump(), indent=2))
@ -276,14 +286,20 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def _call_non_streaming( async def _call_non_streaming(
self, path: str, body: dict = None, cast_to: Any = None self, path: str, body: dict = None, cast_to: Any = None
): ):
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path) func = self.endpoint_impls.get(path)
if not func: if not func:
raise ValueError(f"No endpoint found for {path}") raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) body = self._convert_body(path, body)
return convert_pydantic_to_json_value(await func(**body), cast_to) return convert_pydantic_to_json_value(await func(**body), cast_to)
finally:
await end_trace()
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None): async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
await start_trace(path, {"__location__": "library_client"})
try:
func = self.endpoint_impls.get(path) func = self.endpoint_impls.get(path)
if not func: if not func:
raise ValueError(f"No endpoint found for {path}") raise ValueError(f"No endpoint found for {path}")
@ -291,6 +307,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = self._convert_body(path, body) body = self._convert_body(path, body)
async for chunk in await func(**body): async for chunk in await func(**body):
yield convert_pydantic_to_json_value(chunk, cast_to) yield convert_pydantic_to_json_value(chunk, cast_to)
finally:
await end_trace()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body: if not body:

View file

@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
stream=request.stream, stream=request.stream,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
# log.info( log.info(
# f"{chunk.role.capitalize()}: {chunk.content}", f"{chunk.role.capitalize()}: {chunk.content}",
# ) )
output_message = chunk output_message = chunk
continue continue
@ -280,7 +280,6 @@ class ChatAgent(ShieldRunnerMixin):
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:
with tracing.span("run_shields") as span: with tracing.span("run_shields") as span:
span.set_attribute("turn_id", turn_id)
span.set_attribute("input", [m.model_dump_json() for m in messages]) span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0: if len(shields) == 0:
span.set_attribute("output", "no shields") span.set_attribute("output", "no shields")
@ -405,11 +404,6 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0 n_iter = 0
while True: while True:
msg = input_messages[-1] msg = input_messages[-1]
# if len(str(msg)) > 1000:
# msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
# else:
# msg_str = str(msg)
# log.info(f"{msg_str}")
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -514,12 +508,12 @@ class ChatAgent(ShieldRunnerMixin):
) )
if n_iter >= self.agent_config.max_infer_iters: if n_iter >= self.agent_config.max_infer_iters:
# log.info("Done with MAX iterations, exiting.") log.info("Done with MAX iterations, exiting.")
yield message yield message
break break
if stop_reason == StopReason.out_of_tokens: if stop_reason == StopReason.out_of_tokens:
# log.info("Out of token budget, exiting.") log.info("Out of token budget, exiting.")
yield message yield message
break break
@ -533,10 +527,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + attachments message.content = [message.content] + attachments
yield message yield message
else: else:
# log.info(f"Partial message: {str(message)}") log.info(f"Partial message: {str(message)}")
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
# log.info(f"{str(message)}") log.info(f"{str(message)}")
try: try:
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
@ -800,7 +794,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path path = urlparse(uri).path
basename = os.path.basename(path) basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}" filepath = f"{tempdir}/{make_random_string() + basename}"
# log.info(f"Downloading {url} -> {filepath}") log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(uri) r = await client.get(uri)

View file

@ -7,33 +7,24 @@
import json import json
import os import os
import sqlite3 import sqlite3
import threading from datetime import datetime
from datetime import datetime, timedelta
from typing import Dict
from opentelemetry.sdk.trace import SpanProcessor from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span from opentelemetry.trace import Span
class SQLiteSpanProcessor(SpanProcessor): class SQLiteSpanProcessor(SpanProcessor):
def __init__(self, conn_string, ttl_days=30): def __init__(self, conn_string):
"""Initialize the SQLite span processor with a connection string.""" """Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string self.conn_string = conn_string
self.ttl_days = ttl_days self.conn = None
self.cleanup_task = None
self._thread_local = threading.local()
self._connections: Dict[int, sqlite3.Connection] = {}
self._lock = threading.Lock()
self.setup_database() self.setup_database()
def _get_connection(self) -> sqlite3.Connection: def _get_connection(self) -> sqlite3.Connection:
"""Get a thread-specific database connection.""" """Get the database connection."""
thread_id = threading.get_ident() if self.conn is None:
with self._lock: self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
if thread_id not in self._connections: return self.conn
conn = sqlite3.connect(self.conn_string)
self._connections[thread_id] = conn
return self._connections[thread_id]
def setup_database(self): def setup_database(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
@ -94,60 +85,6 @@ class SQLiteSpanProcessor(SpanProcessor):
conn.commit() conn.commit()
cursor.close() cursor.close()
# Start periodic cleanup in a separate thread
self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True)
self.cleanup_task.start()
def _cleanup_old_data(self):
"""Delete records older than TTL."""
try:
conn = self._get_connection()
cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat()
cursor = conn.cursor()
# Delete old span events
cursor.execute(
"""
DELETE FROM span_events
WHERE span_id IN (
SELECT span_id FROM spans
WHERE trace_id IN (
SELECT trace_id FROM traces
WHERE created_at < ?
)
)
""",
(cutoff_date,),
)
# Delete old spans
cursor.execute(
"""
DELETE FROM spans
WHERE trace_id IN (
SELECT trace_id FROM traces
WHERE created_at < ?
)
""",
(cutoff_date,),
)
# Delete old traces
cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,))
conn.commit()
cursor.close()
except Exception as e:
print(f"Error during cleanup: {e}")
def _periodic_cleanup(self):
"""Run cleanup periodically."""
import time
while True:
time.sleep(3600) # Sleep for 1 hour
self._cleanup_old_data()
def on_start(self, span: Span, parent_context=None): def on_start(self, span: Span, parent_context=None):
"""Called when a span starts.""" """Called when a span starts."""
pass pass
@ -231,11 +168,9 @@ class SQLiteSpanProcessor(SpanProcessor):
def shutdown(self): def shutdown(self):
"""Cleanup any resources.""" """Cleanup any resources."""
with self._lock: if self.conn:
for conn in self._connections.values(): self.conn.close()
if conn: self.conn = None
conn.close()
self._connections.clear()
def force_flush(self, timeout_millis=30000): def force_flush(self, timeout_millis=30000):
"""Force export of spans.""" """Force export of spans."""

View file

@ -6,29 +6,31 @@
import asyncio import asyncio
import inspect import inspect
import json from datetime import datetime
from functools import wraps from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
T = TypeVar("T") T = TypeVar("T")
def serialize_value(value: Any) -> str: def serialize_value(value: Any) -> Any:
"""Helper function to serialize values to string representation.""" """Serialize a single value into JSON-compatible format."""
try: if value is None:
if isinstance(value, BaseModel): return None
return value.model_dump_json() elif isinstance(value, (str, int, float, bool)):
elif isinstance(value, list) and value and isinstance(value[0], BaseModel): return value
return json.dumps([item.model_dump_json() for item in value]) elif isinstance(value, BaseModel):
elif hasattr(value, "to_dict"): return value.model_dump()
return json.dumps(value.to_dict()) elif isinstance(value, (list, tuple, set)):
elif isinstance(value, (dict, list, int, float, str, bool)): return [serialize_value(item) for item in value]
return json.dumps(value) elif isinstance(value, dict):
else: return {str(k): serialize_value(v) for k, v in value.items()}
elif isinstance(value, (datetime, UUID)):
return str(value) return str(value)
except Exception: else:
return str(value) return str(value)
@ -47,16 +49,26 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple: def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
method_name = method.__name__ method_name = method.__name__
span_type = ( span_type = (
"async_generator" if is_async_gen else "async" if is_async else "sync" "async_generator" if is_async_gen else "async" if is_async else "sync"
) )
sig = inspect.signature(method)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
combined_args = {}
for i, arg in enumerate(args):
param_name = (
param_names[i] if i < len(param_names) else f"position_{i+1}"
)
combined_args[param_name] = serialize_value(arg)
for k, v in kwargs.items():
combined_args[str(k)] = serialize_value(v)
span_attributes = { span_attributes = {
"__autotraced__": True, "__autotraced__": True,
"__class__": class_name, "__class__": class_name,
"__method__": method_name, "__method__": method_name,
"__type__": span_type, "__type__": span_type,
"__args__": serialize_value(args), "__args__": str(combined_args),
} }
return class_name, method_name, span_attributes return class_name, method_name, span_attributes