add tracing back to the lib client (#595)

Adds back all the tracing logic removed from the library client. Also adds
back the logging to agent_instance.
Dinesh Yeduguru 2024-12-11 08:44:20 -08:00 committed by GitHub
parent 1c03ba239e
commit e128f2547a
4 changed files with 76 additions and 117 deletions

View file

@@ -24,6 +24,7 @@ from termcolor import cprint
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
@@ -32,6 +33,12 @@ from llama_stack.distribution.stack import (
     replace_env_vars,
 )

+from llama_stack.providers.utils.telemetry.tracing import (
+    end_trace,
+    setup_logger,
+    start_trace,
+)
+
 T = TypeVar("T")
@@ -240,6 +247,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             return False

+        if Api.telemetry in self.impls:
+            setup_logger(self.impls[Api.telemetry])
+
         console = Console()
         console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
         console.print(yaml.dump(self.config.model_dump(), indent=2))
@@ -276,21 +286,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
     async def _call_non_streaming(
         self, path: str, body: dict = None, cast_to: Any = None
     ):
-        func = self.endpoint_impls.get(path)
-        if not func:
-            raise ValueError(f"No endpoint found for {path}")
+        await start_trace(path, {"__location__": "library_client"})
+        try:
+            func = self.endpoint_impls.get(path)
+            if not func:
+                raise ValueError(f"No endpoint found for {path}")

-        body = self._convert_body(path, body)
-        return convert_pydantic_to_json_value(await func(**body), cast_to)
+            body = self._convert_body(path, body)
+            return convert_pydantic_to_json_value(await func(**body), cast_to)
+        finally:
+            await end_trace()

     async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
-        func = self.endpoint_impls.get(path)
-        if not func:
-            raise ValueError(f"No endpoint found for {path}")
+        await start_trace(path, {"__location__": "library_client"})
+        try:
+            func = self.endpoint_impls.get(path)
+            if not func:
+                raise ValueError(f"No endpoint found for {path}")

-        body = self._convert_body(path, body)
-        async for chunk in await func(**body):
-            yield convert_pydantic_to_json_value(chunk, cast_to)
+            body = self._convert_body(path, body)
+            async for chunk in await func(**body):
+                yield convert_pydantic_to_json_value(chunk, cast_to)
+        finally:
+            await end_trace()

     def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
         if not body:
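Taken together, these hunks bracket every library-client request with a trace that is closed even when the handler raises. A minimal sketch of the pattern, where `handler` is a hypothetical stand-in for the function resolved from `self.endpoint_impls`:

    # Minimal sketch of the try/finally trace bracketing shown above.
    # `handler` is a hypothetical stand-in for self.endpoint_impls[path].
    from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace

    async def traced_call(path: str, handler, body: dict):
        await start_trace(path, {"__location__": "library_client"})
        try:
            # Exceptions still propagate to the caller, but the finally
            # block guarantees the trace is closed first.
            return await handler(**body)
        finally:
            await end_trace()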

View file

@@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
             stream=request.stream,
         ):
             if isinstance(chunk, CompletionMessage):
-                # log.info(
-                #     f"{chunk.role.capitalize()}: {chunk.content}",
-                # )
+                log.info(
+                    f"{chunk.role.capitalize()}: {chunk.content}",
+                )
                 output_message = chunk
                 continue
@@ -280,7 +280,6 @@ class ChatAgent(ShieldRunnerMixin):
         touchpoint: str,
     ) -> AsyncGenerator:
         with tracing.span("run_shields") as span:
-            span.set_attribute("turn_id", turn_id)
             span.set_attribute("input", [m.model_dump_json() for m in messages])
             if len(shields) == 0:
                 span.set_attribute("output", "no shields")
@@ -405,11 +404,6 @@ class ChatAgent(ShieldRunnerMixin):
         n_iter = 0
         while True:
             msg = input_messages[-1]
-            # if len(str(msg)) > 1000:
-            #     msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
-            # else:
-            #     msg_str = str(msg)
-            # log.info(f"{msg_str}")

             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
@@ -514,12 +508,12 @@ class ChatAgent(ShieldRunnerMixin):
                 )

             if n_iter >= self.agent_config.max_infer_iters:
-                # log.info("Done with MAX iterations, exiting.")
+                log.info("Done with MAX iterations, exiting.")
                 yield message
                 break

             if stop_reason == StopReason.out_of_tokens:
-                # log.info("Out of token budget, exiting.")
+                log.info("Out of token budget, exiting.")
                 yield message
                 break
@@ -533,10 +527,10 @@ class ChatAgent(ShieldRunnerMixin):
                        message.content = [message.content] + attachments
                    yield message
                else:
-                    # log.info(f"Partial message: {str(message)}")
+                    log.info(f"Partial message: {str(message)}")
                    input_messages = input_messages + [message]
            else:
-                # log.info(f"{str(message)}")
+                log.info(f"{str(message)}")

                try:
                    tool_call = message.tool_calls[0]
@@ -800,7 +794,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
        path = urlparse(uri).path
        basename = os.path.basename(path)
        filepath = f"{tempdir}/{make_random_string() + basename}"
-        # log.info(f"Downloading {url} -> {filepath}")
+        log.info(f"Downloading {url} -> {filepath}")

        async with httpx.AsyncClient() as client:
            r = await client.get(uri)
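The agent_instance hunks simply re-enable the previously commented-out `log.info` calls. A small sketch, assuming `log` in that module is a standard `logging.Logger` (e.g. `logging.getLogger(__name__)`), of how a host application would surface the restored messages:

    # Sketch, assuming `log` in agent_instance is a stdlib logging.Logger:
    # enabling INFO output in the host application makes the messages visible.
    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("agent_instance")
    log.info("Done with MAX iterations, exiting.")  # mirrors a re-enabled call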

View file

@@ -7,33 +7,24 @@
 import json
 import os
 import sqlite3
-import threading
-from datetime import datetime, timedelta
-from typing import Dict
+from datetime import datetime

 from opentelemetry.sdk.trace import SpanProcessor
 from opentelemetry.trace import Span


 class SQLiteSpanProcessor(SpanProcessor):
-    def __init__(self, conn_string, ttl_days=30):
+    def __init__(self, conn_string):
         """Initialize the SQLite span processor with a connection string."""
         self.conn_string = conn_string
-        self.ttl_days = ttl_days
-        self.cleanup_task = None
-        self._thread_local = threading.local()
-        self._connections: Dict[int, sqlite3.Connection] = {}
-        self._lock = threading.Lock()
+        self.conn = None
         self.setup_database()

     def _get_connection(self) -> sqlite3.Connection:
-        """Get a thread-specific database connection."""
-        thread_id = threading.get_ident()
-        with self._lock:
-            if thread_id not in self._connections:
-                conn = sqlite3.connect(self.conn_string)
-                self._connections[thread_id] = conn
-            return self._connections[thread_id]
+        """Get the database connection."""
+        if self.conn is None:
+            self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
+        return self.conn

     def setup_database(self):
         """Create the necessary tables if they don't exist."""
@@ -94,60 +85,6 @@ class SQLiteSpanProcessor(SpanProcessor):
         conn.commit()
         cursor.close()

-        # Start periodic cleanup in a separate thread
-        self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True)
-        self.cleanup_task.start()
-
-    def _cleanup_old_data(self):
-        """Delete records older than TTL."""
-        try:
-            conn = self._get_connection()
-            cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat()
-            cursor = conn.cursor()
-
-            # Delete old span events
-            cursor.execute(
-                """
-                DELETE FROM span_events
-                WHERE span_id IN (
-                    SELECT span_id FROM spans
-                    WHERE trace_id IN (
-                        SELECT trace_id FROM traces
-                        WHERE created_at < ?
-                    )
-                )
-                """,
-                (cutoff_date,),
-            )
-
-            # Delete old spans
-            cursor.execute(
-                """
-                DELETE FROM spans
-                WHERE trace_id IN (
-                    SELECT trace_id FROM traces
-                    WHERE created_at < ?
-                )
-                """,
-                (cutoff_date,),
-            )
-
-            # Delete old traces
-            cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,))
-
-            conn.commit()
-            cursor.close()
-        except Exception as e:
-            print(f"Error during cleanup: {e}")
-
-    def _periodic_cleanup(self):
-        """Run cleanup periodically."""
-        import time
-
-        while True:
-            time.sleep(3600)  # Sleep for 1 hour
-            self._cleanup_old_data()
-
     def on_start(self, span: Span, parent_context=None):
         """Called when a span starts."""
         pass
@@ -231,11 +168,9 @@ class SQLiteSpanProcessor(SpanProcessor):

     def shutdown(self):
         """Cleanup any resources."""
-        with self._lock:
-            for conn in self._connections.values():
-                if conn:
-                    conn.close()
-            self._connections.clear()
+        if self.conn:
+            self.conn.close()
+            self.conn = None

     def force_flush(self, timeout_millis=30000):
         """Force export of spans."""

View file

@@ -6,29 +6,31 @@
 import asyncio
 import inspect
 import json
+from datetime import datetime
 from functools import wraps
 from typing import Any, AsyncGenerator, Callable, Type, TypeVar
+from uuid import UUID

 from pydantic import BaseModel

 T = TypeVar("T")


-def serialize_value(value: Any) -> str:
-    """Helper function to serialize values to string representation."""
-    try:
-        if isinstance(value, BaseModel):
-            return value.model_dump_json()
-        elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
-            return json.dumps([item.model_dump_json() for item in value])
-        elif hasattr(value, "to_dict"):
-            return json.dumps(value.to_dict())
-        elif isinstance(value, (dict, list, int, float, str, bool)):
-            return json.dumps(value)
-        else:
-            return str(value)
-    except Exception:
-        return str(value)
+def serialize_value(value: Any) -> Any:
+    """Serialize a single value into JSON-compatible format."""
+    if value is None:
+        return None
+    elif isinstance(value, (str, int, float, bool)):
+        return value
+    elif isinstance(value, BaseModel):
+        return value.model_dump()
+    elif isinstance(value, (list, tuple, set)):
+        return [serialize_value(item) for item in value]
+    elif isinstance(value, dict):
+        return {str(k): serialize_value(v) for k, v in value.items()}
+    elif isinstance(value, (datetime, UUID)):
+        return str(value)
+    else:
+        return str(value)
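A quick illustration of the rewritten serializer on nested input, which now returns JSON-compatible structures rather than strings; `Point` is a hypothetical model used only for this example:

    # Illustration of serialize_value on nested input (hypothetical Point model).
    from pydantic import BaseModel

    class Point(BaseModel):
        x: int
        y: int

    serialize_value({"points": [Point(x=1, y=2)], "tags": ("a", "b")})
    # -> {'points': [{'x': 1, 'y': 2}], 'tags': ['a', 'b']}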
@@ -47,16 +49,26 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
         def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
             class_name = self.__class__.__name__
             method_name = method.__name__
             span_type = (
                 "async_generator" if is_async_gen else "async" if is_async else "sync"
             )

+            sig = inspect.signature(method)
+            param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
+            combined_args = {}
+            for i, arg in enumerate(args):
+                param_name = (
+                    param_names[i] if i < len(param_names) else f"position_{i+1}"
+                )
+                combined_args[param_name] = serialize_value(arg)
+            for k, v in kwargs.items():
+                combined_args[str(k)] = serialize_value(v)

             span_attributes = {
                 "__autotraced__": True,
                 "__class__": class_name,
                 "__method__": method_name,
                 "__type__": span_type,
-                "__args__": serialize_value(args),
+                "__args__": str(combined_args),
             }

             return class_name, method_name, span_attributes
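A small, self-contained sketch of the positional-argument mapping introduced above; `sample_method` is a hypothetical stand-in for the traced `method`:

    # Map positional args to parameter names via inspect.signature, falling
    # back to position_N when more args arrive than the signature declares.
    import inspect

    def sample_method(self, query: str, limit: int = 10):
        ...

    param_names = list(inspect.signature(sample_method).parameters.keys())[1:]  # skip 'self'
    args, kwargs = ("llama",), {"limit": 5}
    combined_args = {}
    for i, arg in enumerate(args):
        name = param_names[i] if i < len(param_names) else f"position_{i+1}"
        combined_args[name] = arg
    for k, v in kwargs.items():
        combined_args[str(k)] = v
    print(combined_args)  # {'query': 'llama', 'limit': 5}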