enhance the tracing span utility to make it a context manager

Ashwin Bharambe 2024-09-22 20:55:15 -07:00
parent 484dc2e4f5
commit 5d75c2437b
2 changed files with 96 additions and 77 deletions
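
With this change, `tracing.span(name, attributes)` returns an object that can be used either as a decorator or as a sync/async context manager, as the diff below shows. A minimal usage sketch; the `fetch_profile` and `load_from_cache` names here are hypothetical and not part of the repository:

    from llama_stack.providers.utils.telemetry import tracing


    # Decorator form: the whole coroutine executes inside a "fetch_profile" span.
    @tracing.span("fetch_profile")
    async def fetch_profile(user_id: str) -> dict:
        ...


    async def handle_request(user_id: str) -> dict:
        # Context-manager form: only this block is traced, so a sub-step of a
        # larger method can get its own span without being split out.
        async with tracing.span("load_from_cache", {"user_id": user_id}):
            profile = await fetch_profile(user_id)
        return profile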


@@ -26,6 +26,7 @@ from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.telemetry import tracing
 from .persistence import AgentPersistence
 from .rag.context_retriever import generate_rag_query
@@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)

+    @tracing.span("create_and_execute_turn")
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
@@ -266,6 +268,7 @@ class ChatAgent(ShieldRunnerMixin):
         yield final_response

+    @tracing.span("run_shields")
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
@@ -348,6 +351,7 @@ class ChatAgent(ShieldRunnerMixin):
         # TODO: find older context from the session and either replace it
         # or append with a sliding window. this is really a very simplistic implementation
+        with tracing.span("retrieve_rag_context"):
             rag_context, bank_ids = await self._retrieve_context(
                 session_id, input_messages, attachments
             )
@@ -403,6 +407,8 @@ class ChatAgent(ShieldRunnerMixin):
         tool_calls = []
         content = ""
         stop_reason = None
+        with tracing.span("inference"):
             async for chunk in self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
@@ -528,6 +534,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )

+            with tracing.span("tool_execution"):
                 result_messages = await execute_tool_call_maybe(
                     self.tools_dict,
                     [message],
@@ -669,6 +676,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                     for a in attachments
                 ]
+                with tracing.span("insert_documents"):
                     await self.memory_api.insert_documents(bank_id, documents)
             else:
                 session_info = await self.storage.get_session_info(session_id)
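
Because spans can now be opened with a plain `with` block, individual phases of a single agent turn (RAG retrieval, inference, tool execution, document insertion) each get their own span without extracting every phase into a separately decorated method; the decorator form is kept for whole-method spans such as `create_and_execute_turn` and `run_shields`.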


@@ -12,7 +12,7 @@ import threading
 import uuid
 from datetime import datetime
 from functools import wraps
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List

 from llama_stack.apis.telemetry import *  # noqa: F403
@@ -196,33 +196,42 @@ class TelemetryHandler(logging.Handler):
         pass


-def span(name: str, attributes: Dict[str, Any] = None):
-    def decorator(func):
-        @wraps(func)
-        def sync_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-                result = func(*args, **kwargs)
-            finally:
-                context.pop_span()
-            return result
-
-        @wraps(func)
-        async def async_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-                result = await func(*args, **kwargs)
-            finally:
-                context.pop_span()
-            return result
-
+class SpanContextManager:
+    def __init__(self, name: str, attributes: Dict[str, Any] = None):
+        self.name = name
+        self.attributes = attributes
+
+    def __enter__(self):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.push_span(self.name, self.attributes)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.pop_span()
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit__(exc_type, exc_value, traceback)
+
+    def __call__(self, func: Callable):
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            print("sync wrapper")
+            with self:
+                return func(*args, **kwargs)
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            print("async wrapper")
+            async with self:
+                return await func(*args, **kwargs)
+
         @wraps(func)
         def wrapper(*args, **kwargs):
@@ -233,4 +242,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
         return wrapper

-    return decorator
+
+def span(name: str, attributes: Dict[str, Any] = None):
+    return SpanContextManager(name, attributes)
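
For reference, a self-contained sketch of the same pattern outside llama_stack: one class acting as a sync context manager, an async context manager, and a decorator. The dispatch on `asyncio.iscoroutinefunction` is an assumption about what the unchanged `wrapper` body (elided in the diff above) does, and `_STACK` is only a stand-in for the real trace context:

    import asyncio
    from functools import wraps
    from typing import Any, Callable, Dict, List, Optional

    # Stand-in for the real CURRENT_TRACE_CONTEXT: a stack of active span names.
    _STACK: List[str] = []


    class Span:
        def __init__(self, name: str, attributes: Optional[Dict[str, Any]] = None):
            self.name = name
            self.attributes = attributes or {}

        def __enter__(self):
            _STACK.append(self.name)              # plays the role of push_span()
            print("enter:", " > ".join(_STACK))
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            print("exit: ", " > ".join(_STACK))
            _STACK.pop()                          # plays the role of pop_span()

        # Entering/leaving a span is synchronous, so the async protocol can
        # simply delegate to the sync one -- the same choice the diff makes.
        async def __aenter__(self):
            return self.__enter__()

        async def __aexit__(self, exc_type, exc_value, traceback):
            self.__exit__(exc_type, exc_value, traceback)

        def __call__(self, func: Callable):
            if asyncio.iscoroutinefunction(func):
                @wraps(func)
                async def async_wrapper(*args, **kwargs):
                    async with self:
                        return await func(*args, **kwargs)
                return async_wrapper

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)
            return sync_wrapper


    @Span("turn")
    async def run_turn():
        with Span("inference"):
            await asyncio.sleep(0)  # stand-in for a model call


    asyncio.run(run_turn())  # prints the nested enter/exit span stack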