add tracing back to the lib cli

This commit is contained in:
Dinesh Yeduguru 2024-12-10 11:31:41 -08:00
parent e2054d53e4
commit 84904914c2
5 changed files with 106 additions and 61 deletions

View file

@ -142,7 +142,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
with tracing.SpanContextManager("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
@ -185,9 +185,9 @@ class ChatAgent(ShieldRunnerMixin):
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
# log.info(
# f"{chunk.role.capitalize()}: {chunk.content}",
# )
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
)
output_message = chunk
continue
@ -279,8 +279,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
with tracing.span("run_shields") as span:
span.set_attribute("turn_id", turn_id)
with tracing.SpanContextManager("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
@ -360,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context") as span:
with tracing.SpanContextManager("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
@ -405,11 +404,11 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0
while True:
msg = input_messages[-1]
# if len(str(msg)) > 1000:
# msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
# else:
# msg_str = str(msg)
# log.info(f"{msg_str}")
if len(str(msg)) > 1000:
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else:
msg_str = str(msg)
log.info(f"{msg_str}")
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -425,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
content = ""
stop_reason = None
with tracing.span("inference") as span:
with tracing.SpanContextManager("inference") as span:
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
@ -514,12 +513,12 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
# log.info("Done with MAX iterations, exiting.")
log.info("Done with MAX iterations, exiting.")
yield message
break
if stop_reason == StopReason.out_of_tokens:
# log.info("Out of token budget, exiting.")
log.info("Out of token budget, exiting.")
yield message
break
@ -533,10 +532,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + attachments
yield message
else:
# log.info(f"Partial message: {str(message)}")
log.info(f"Partial message: {str(message)}")
input_messages = input_messages + [message]
else:
# log.info(f"{str(message)}")
log.info(f"{str(message)}")
try:
tool_call = message.tool_calls[0]
@ -564,7 +563,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
with tracing.span(
with tracing.SpanContextManager(
"tool_execution",
{
"tool_name": tool_call.tool_name,
@ -713,7 +712,7 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in attachments
]
with tracing.span("insert_documents"):
with tracing.SpanContextManager("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self.storage.get_session_info(session_id)
@ -800,7 +799,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
# log.info(f"Downloading {url} -> {filepath}")
log.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)

View file

@ -20,6 +20,7 @@ class SQLiteSpanProcessor(SpanProcessor):
"""Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string
self.ttl_days = ttl_days
self._shutdown_event = threading.Event()
self.cleanup_task = None
self._thread_local = threading.local()
self._connections: Dict[int, sqlite3.Connection] = {}
@ -144,9 +145,10 @@ class SQLiteSpanProcessor(SpanProcessor):
"""Run cleanup periodically."""
import time
while True:
while not self._shutdown_event.is_set():
time.sleep(3600) # Sleep for 1 hour
self._cleanup_old_data()
if not self._shutdown_event.is_set():
self._cleanup_old_data()
def on_start(self, span: Span, parent_context=None):
"""Called when a span starts."""
@ -231,11 +233,23 @@ class SQLiteSpanProcessor(SpanProcessor):
def shutdown(self):
"""Cleanup any resources."""
self._shutdown_event.set()
# Wait for cleanup thread to finish if it exists
if self.cleanup_task and self.cleanup_task.is_alive():
self.cleanup_task.join(timeout=5.0)
current_thread_id = threading.get_ident()
with self._lock:
for conn in self._connections.values():
if conn:
conn.close()
self._connections.clear()
# Close all connections from the current thread
for thread_id, conn in list(self._connections.items()):
if thread_id == current_thread_id:
try:
if conn:
conn.close()
del self._connections[thread_id]
except sqlite3.Error:
pass # Ignore errors during shutdown
def force_flush(self, timeout_millis=30000):
"""Force export of spans."""