add back span

This commit is contained in:
Dinesh Yeduguru 2024-12-11 08:37:56 -08:00
parent a4a29ea3a3
commit 62bb230ab0
3 changed files with 13 additions and 15 deletions

View file

@ -142,7 +142,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn( async def create_and_execute_turn(
self, request: AgentTurnCreateRequest self, request: AgentTurnCreateRequest
) -> AsyncGenerator: ) -> AsyncGenerator:
with tracing.SpanContextManager("create_and_execute_turn") as span: with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
@ -279,7 +279,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str], shields: List[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:
with tracing.SpanContextManager("run_shields") as span: with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages]) span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0: if len(shields) == 0:
span.set_attribute("output", "no shields") span.set_attribute("output", "no shields")
@ -359,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it # TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation # or append with a sliding window. this is really a very simplistic implementation
with tracing.SpanContextManager("retrieve_rag_context") as span: with tracing.span("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context( rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments session_id, input_messages, attachments
) )
@ -419,7 +419,7 @@ class ChatAgent(ShieldRunnerMixin):
content = "" content = ""
stop_reason = None stop_reason = None
with tracing.SpanContextManager("inference") as span: with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion( async for chunk in await self.inference_api.chat_completion(
self.agent_config.model, self.agent_config.model,
input_messages, input_messages,
@ -558,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
with tracing.SpanContextManager( with tracing.span(
"tool_execution", "tool_execution",
{ {
"tool_name": tool_call.tool_name, "tool_name": tool_call.tool_name,
@ -707,7 +707,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
for a in attachments for a in attachments
] ]
with tracing.SpanContextManager("insert_documents"): with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents) await self.memory_api.insert_documents(bank_id, documents)
else: else:
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)

View file

@ -81,9 +81,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs self, *args, **kwargs
) )
with tracing.SpanContextManager( with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
f"{class_name}.{method_name}", span_attributes
) as span:
try: try:
count = 0 count = 0
async for item in method(self, *args, **kwargs): async for item in method(self, *args, **kwargs):
@ -98,9 +96,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs self, *args, **kwargs
) )
with tracing.SpanContextManager( with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
f"{class_name}.{method_name}", span_attributes
) as span:
try: try:
result = await method(self, *args, **kwargs) result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result)) span.set_attribute("output", serialize_value(result))
@ -115,9 +111,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs self, *args, **kwargs
) )
with tracing.SpanContextManager( with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
f"{class_name}.{method_name}", span_attributes
) as span:
try: try:
result = method(self, *args, **kwargs) result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result)) span.set_attribute("output", serialize_value(result))

View file

@ -259,6 +259,10 @@ class SpanContextManager:
return wrapper return wrapper
def span(name: str, attributes: Dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Optional[Span]: def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT