add tracing back to the lib cli

This commit is contained in:
Dinesh Yeduguru 2024-12-10 11:31:41 -08:00
parent e2054d53e4
commit 84904914c2
5 changed files with 106 additions and 61 deletions

View file

@ -24,6 +24,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
@ -32,6 +33,12 @@ from llama_stack.distribution.stack import (
replace_env_vars, replace_env_vars,
) )
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
start_trace,
)
T = TypeVar("T") T = TypeVar("T")
@ -240,6 +247,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return False return False
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
console = Console() console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
console.print(yaml.dump(self.config.model_dump(), indent=2)) console.print(yaml.dump(self.config.model_dump(), indent=2))
@ -276,21 +286,29 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def _call_non_streaming( async def _call_non_streaming(
self, path: str, body: dict = None, cast_to: Any = None self, path: str, body: dict = None, cast_to: Any = None
): ):
func = self.endpoint_impls.get(path) await start_trace(path, {"__location__": "library_client"})
if not func: try:
raise ValueError(f"No endpoint found for {path}") func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) body = self._convert_body(path, body)
return convert_pydantic_to_json_value(await func(**body), cast_to) return convert_pydantic_to_json_value(await func(**body), cast_to)
finally:
await end_trace()
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None): async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
func = self.endpoint_impls.get(path) await start_trace(path, {"__location__": "library_client"})
if not func: try:
raise ValueError(f"No endpoint found for {path}") func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) body = self._convert_body(path, body)
async for chunk in await func(**body): async for chunk in await func(**body):
yield convert_pydantic_to_json_value(chunk, cast_to) yield convert_pydantic_to_json_value(chunk, cast_to)
finally:
await end_trace()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
if not body: if not body:

View file

@ -142,7 +142,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn( async def create_and_execute_turn(
self, request: AgentTurnCreateRequest self, request: AgentTurnCreateRequest
) -> AsyncGenerator: ) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span: with tracing.SpanContextManager("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
stream=request.stream, stream=request.stream,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
# log.info( log.info(
# f"{chunk.role.capitalize()}: {chunk.content}", f"{chunk.role.capitalize()}: {chunk.content}",
# ) )
output_message = chunk output_message = chunk
continue continue
@ -279,8 +279,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str], shields: List[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:
with tracing.span("run_shields") as span: with tracing.SpanContextManager("run_shields") as span:
span.set_attribute("turn_id", turn_id)
span.set_attribute("input", [m.model_dump_json() for m in messages]) span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0: if len(shields) == 0:
span.set_attribute("output", "no shields") span.set_attribute("output", "no shields")
@ -360,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it # TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation # or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context") as span: with tracing.SpanContextManager("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context( rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments session_id, input_messages, attachments
) )
@ -405,11 +404,11 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0 n_iter = 0
while True: while True:
msg = input_messages[-1] msg = input_messages[-1]
# if len(str(msg)) > 1000: if len(str(msg)) > 1000:
# msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}" msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
# else: else:
# msg_str = str(msg) msg_str = str(msg)
# log.info(f"{msg_str}") log.info(f"{msg_str}")
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -425,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
content = "" content = ""
stop_reason = None stop_reason = None
with tracing.span("inference") as span: with tracing.SpanContextManager("inference") as span:
async for chunk in await self.inference_api.chat_completion( async for chunk in await self.inference_api.chat_completion(
self.agent_config.model, self.agent_config.model,
input_messages, input_messages,
@ -514,12 +513,12 @@ class ChatAgent(ShieldRunnerMixin):
) )
if n_iter >= self.agent_config.max_infer_iters: if n_iter >= self.agent_config.max_infer_iters:
# log.info("Done with MAX iterations, exiting.") log.info("Done with MAX iterations, exiting.")
yield message yield message
break break
if stop_reason == StopReason.out_of_tokens: if stop_reason == StopReason.out_of_tokens:
# log.info("Out of token budget, exiting.") log.info("Out of token budget, exiting.")
yield message yield message
break break
@ -533,10 +532,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + attachments message.content = [message.content] + attachments
yield message yield message
else: else:
# log.info(f"Partial message: {str(message)}") log.info(f"Partial message: {str(message)}")
input_messages = input_messages + [message] input_messages = input_messages + [message]
else: else:
# log.info(f"{str(message)}") log.info(f"{str(message)}")
try: try:
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
@ -564,7 +563,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
with tracing.span( with tracing.SpanContextManager(
"tool_execution", "tool_execution",
{ {
"tool_name": tool_call.tool_name, "tool_name": tool_call.tool_name,
@ -713,7 +712,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
for a in attachments for a in attachments
] ]
with tracing.span("insert_documents"): with tracing.SpanContextManager("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents) await self.memory_api.insert_documents(bank_id, documents)
else: else:
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
@ -800,7 +799,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path path = urlparse(uri).path
basename = os.path.basename(path) basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}" filepath = f"{tempdir}/{make_random_string() + basename}"
# log.info(f"Downloading {url} -> {filepath}") log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(uri) r = await client.get(uri)

View file

@ -20,6 +20,7 @@ class SQLiteSpanProcessor(SpanProcessor):
"""Initialize the SQLite span processor with a connection string.""" """Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string self.conn_string = conn_string
self.ttl_days = ttl_days self.ttl_days = ttl_days
self._shutdown_event = threading.Event()
self.cleanup_task = None self.cleanup_task = None
self._thread_local = threading.local() self._thread_local = threading.local()
self._connections: Dict[int, sqlite3.Connection] = {} self._connections: Dict[int, sqlite3.Connection] = {}
@ -144,9 +145,10 @@ class SQLiteSpanProcessor(SpanProcessor):
"""Run cleanup periodically.""" """Run cleanup periodically."""
import time import time
while True: while not self._shutdown_event.is_set():
time.sleep(3600) # Sleep for 1 hour time.sleep(3600) # Sleep for 1 hour
self._cleanup_old_data() if not self._shutdown_event.is_set():
self._cleanup_old_data()
def on_start(self, span: Span, parent_context=None): def on_start(self, span: Span, parent_context=None):
"""Called when a span starts.""" """Called when a span starts."""
@ -231,11 +233,23 @@ class SQLiteSpanProcessor(SpanProcessor):
def shutdown(self): def shutdown(self):
"""Cleanup any resources.""" """Cleanup any resources."""
self._shutdown_event.set()
# Wait for cleanup thread to finish if it exists
if self.cleanup_task and self.cleanup_task.is_alive():
self.cleanup_task.join(timeout=5.0)
current_thread_id = threading.get_ident()
with self._lock: with self._lock:
for conn in self._connections.values(): # Close all connections from the current thread
if conn: for thread_id, conn in list(self._connections.items()):
conn.close() if thread_id == current_thread_id:
self._connections.clear() try:
if conn:
conn.close()
del self._connections[thread_id]
except sqlite3.Error:
pass # Ignore errors during shutdown
def force_flush(self, timeout_millis=30000): def force_flush(self, timeout_millis=30000):
"""Force export of spans.""" """Force export of spans."""

View file

@ -6,29 +6,31 @@
import asyncio import asyncio
import inspect import inspect
import json from datetime import datetime
from functools import wraps from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
T = TypeVar("T") T = TypeVar("T")
def serialize_value(value: Any) -> str: def serialize_value(value: Any) -> Any:
"""Helper function to serialize values to string representation.""" """Serialize a single value into JSON-compatible format."""
try: if value is None:
if isinstance(value, BaseModel): return None
return value.model_dump_json() elif isinstance(value, (str, int, float, bool)):
elif isinstance(value, list) and value and isinstance(value[0], BaseModel): return value
return json.dumps([item.model_dump_json() for item in value]) elif isinstance(value, BaseModel):
elif hasattr(value, "to_dict"): return value.model_dump()
return json.dumps(value.to_dict()) elif isinstance(value, (list, tuple, set)):
elif isinstance(value, (dict, list, int, float, str, bool)): return [serialize_value(item) for item in value]
return json.dumps(value) elif isinstance(value, dict):
else: return {str(k): serialize_value(v) for k, v in value.items()}
return str(value) elif isinstance(value, (datetime, UUID)):
except Exception: return str(value)
else:
return str(value) return str(value)
@ -47,16 +49,26 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple: def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
method_name = method.__name__ method_name = method.__name__
span_type = ( span_type = (
"async_generator" if is_async_gen else "async" if is_async else "sync" "async_generator" if is_async_gen else "async" if is_async else "sync"
) )
sig = inspect.signature(method)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
combined_args = {}
for i, arg in enumerate(args):
param_name = (
param_names[i] if i < len(param_names) else f"position_{i+1}"
)
combined_args[param_name] = serialize_value(arg)
for k, v in kwargs.items():
combined_args[str(k)] = serialize_value(v)
span_attributes = { span_attributes = {
"__autotraced__": True, "__autotraced__": True,
"__class__": class_name, "__class__": class_name,
"__method__": method_name, "__method__": method_name,
"__type__": span_type, "__type__": span_type,
"__args__": serialize_value(args), "__args__": str(combined_args),
} }
return class_name, method_name, span_attributes return class_name, method_name, span_attributes
@ -69,7 +81,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs self, *args, **kwargs
) )
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: with tracing.SpanContextManager(
f"{class_name}.{method_name}", span_attributes
) as span:
try: try:
count = 0 count = 0
async for item in method(self, *args, **kwargs): async for item in method(self, *args, **kwargs):
@ -84,7 +98,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs self, *args, **kwargs
) )
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: with tracing.SpanContextManager(
f"{class_name}.{method_name}", span_attributes
) as span:
try: try:
result = await method(self, *args, **kwargs) result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result)) span.set_attribute("output", serialize_value(result))
@ -99,7 +115,9 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs self, *args, **kwargs
) )
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span: with tracing.SpanContextManager(
f"{class_name}.{method_name}", span_attributes
) as span:
try: try:
result = method(self, *args, **kwargs) result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result)) span.set_attribute("output", serialize_value(result))

View file

@ -259,10 +259,6 @@ class SpanContextManager:
return wrapper return wrapper
def span(name: str, attributes: Dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Optional[Span]: def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT