enhance the tracing span utility to also work as a context manager

Ashwin Bharambe 2024-09-22 20:55:15 -07:00
parent 484dc2e4f5
commit 5d75c2437b
2 changed files with 96 additions and 77 deletions
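In short: span(name, attributes) now returns a single object that works as a decorator, as a with-block, and as an async with-block, where it was previously decorator-only. The sketch below is a self-contained illustration of that pattern; the Span stub, its print-based bookkeeping, and the iscoroutinefunction dispatch are assumptions for demonstration, not the commit's exact code (which appears in the second file below).

import asyncio
from functools import wraps
from typing import Any, Callable, Dict, Optional


class Span:
    # Illustrative stand-in: one object usable as a decorator, `with`, or `async with`.

    def __init__(self, name: str, attributes: Optional[Dict[str, Any]] = None):
        self.name = name
        self.attributes = attributes

    def __enter__(self):
        print(f"push span: {self.name}")  # the real utility pushes onto a trace context
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print(f"pop span: {self.name}")  # ...and pops it here

    async def __aenter__(self):
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_value, traceback):
        self.__exit__(exc_type, exc_value, traceback)

    def __call__(self, func: Callable):
        # Used as @span("name"): wrap the whole function in the span.
        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with self:
                    return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return sync_wrapper


def span(name: str, attributes: Optional[Dict[str, Any]] = None) -> Span:
    return Span(name, attributes)


@span("turn")                # decorator form, as used on create_and_execute_turn
async def run_turn() -> str:
    with span("inference"):  # context-manager form, as used around inference
        return "ok"


print(asyncio.run(run_turn()))

Run directly, the spans nest: "turn" is pushed before "inference" and popped after it.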


@@ -26,6 +26,7 @@ from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
 from .rag.context_retriever import generate_rag_query
@@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
+    @tracing.span("create_and_execute_turn")
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
@@ -266,6 +268,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         yield final_response
 
+    @tracing.span("run_shields")
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
@@ -348,6 +351,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         # TODO: find older context from the session and either replace it
         # or append with a sliding window. this is really a very simplistic implementation
-        rag_context, bank_ids = await self._retrieve_context(
-            session_id, input_messages, attachments
-        )
+        with tracing.span("retrieve_rag_context"):
+            rag_context, bank_ids = await self._retrieve_context(
+                session_id, input_messages, attachments
+            )
@@ -403,6 +407,8 @@ class ChatAgent(ShieldRunnerMixin):
         tool_calls = []
         content = ""
         stop_reason = None
-        async for chunk in self.inference_api.chat_completion(
-            self.agent_config.model,
-            input_messages,
+
+        with tracing.span("inference"):
+            async for chunk in self.inference_api.chat_completion(
+                self.agent_config.model,
+                input_messages,
@@ -528,6 +534,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
 
-            result_messages = await execute_tool_call_maybe(
-                self.tools_dict,
-                [message],
+            with tracing.span("tool_execution"):
+                result_messages = await execute_tool_call_maybe(
+                    self.tools_dict,
+                    [message],
@@ -669,6 +676,7 @@ class ChatAgent(ShieldRunnerMixin):
                 )
                 for a in attachments
             ]
-            await self.memory_api.insert_documents(bank_id, documents)
+            with tracing.span("insert_documents"):
+                await self.memory_api.insert_documents(bank_id, documents)
         else:
             session_info = await self.storage.get_session_info(session_id)
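Every call site in this file follows the same split: turn-level entry points take the decorator form, while awaited sub-steps (RAG retrieval, inference, tool execution, document insertion) get a with tracing.span(...) block, so each step is timed on its own even inside the streaming loop. A minimal runnable illustration of the block form, with contextlib standing in for the real tracing module (an assumption for brevity; the real span object is a class, shown in the next file):

import asyncio
from contextlib import contextmanager


@contextmanager
def span(name: str):
    # Stand-in for tracing.span; the real one pushes/pops a span on the trace context.
    print(f"begin {name}")
    try:
        yield
    finally:
        print(f"end {name}")


async def chat_completion():
    # Fake streaming API standing in for inference_api.chat_completion.
    for chunk in ("hel", "lo"):
        yield chunk


async def run() -> str:
    content = ""
    with span("inference"):  # mirrors the block added around the chat_completion loop
        async for chunk in chat_completion():
            content += chunk
    return content


print(asyncio.run(run()))  # begin inference / end inference / hello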


@@ -12,7 +12,7 @@ import threading
 import uuid
 from datetime import datetime
 from functools import wraps
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List
 
 from llama_stack.apis.telemetry import *  # noqa: F403
@@ -196,33 +196,42 @@ class TelemetryHandler(logging.Handler):
         pass
 
 
-def span(name: str, attributes: Dict[str, Any] = None):
-    def decorator(func):
-        @wraps(func)
-        def sync_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-                result = func(*args, **kwargs)
-            finally:
-                context.pop_span()
-            return result
+class SpanContextManager:
+    def __init__(self, name: str, attributes: Dict[str, Any] = None):
+        self.name = name
+        self.attributes = attributes
+
+    def __enter__(self):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.push_span(self.name, self.attributes)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.pop_span()
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit__(exc_type, exc_value, traceback)
+
+    def __call__(self, func: Callable):
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            print("sync wrapper")
+            with self:
+                return func(*args, **kwargs)
 
         @wraps(func)
         async def async_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-                result = await func(*args, **kwargs)
-            finally:
-                context.pop_span()
-            return result
+            print("async wrapper")
+            async with self:
+                return await func(*args, **kwargs)
 
         @wraps(func)
         def wrapper(*args, **kwargs):
@@ -233,4 +242,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
 
         return wrapper
 
-    return decorator
+
+def span(name: str, attributes: Dict[str, Any] = None):
+    return SpanContextManager(name, attributes)
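The @@ -233,4 +242,6 @@ hunk keeps the body of wrapper as unchanged context, so it does not appear above. Presumably it picks between sync_wrapper and async_wrapper at call time; a hedged sketch of such a dispatch, assuming an asyncio.iscoroutinefunction check that this diff does not actually show (make_dispatching_wrapper is a hypothetical helper name):

import asyncio
from functools import wraps


def make_dispatching_wrapper(func, sync_wrapper, async_wrapper):
    # Hypothetical reconstruction: route coroutine functions to the async
    # wrapper and plain functions to the sync one, so callers see a single
    # callable either way.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if asyncio.iscoroutinefunction(func):
            return async_wrapper(*args, **kwargs)
        return sync_wrapper(*args, **kwargs)

    return wrapper

Whatever the exact check, the module-level span(name, attributes) at the end means existing call sites need no change: @tracing.span("create_and_execute_turn") and with tracing.span("inference") both go through the same SpanContextManager.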