add back span

This commit is contained in:
Dinesh Yeduguru 2024-12-11 08:37:56 -08:00
parent a4a29ea3a3
commit 62bb230ab0
3 changed files with 13 additions and 15 deletions

View file

@ -142,7 +142,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
with tracing.SpanContextManager("create_and_execute_turn") as span:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
@ -279,7 +279,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
with tracing.SpanContextManager("run_shields") as span:
with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
@ -359,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
with tracing.SpanContextManager("retrieve_rag_context") as span:
with tracing.span("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
@ -419,7 +419,7 @@ class ChatAgent(ShieldRunnerMixin):
content = ""
stop_reason = None
with tracing.SpanContextManager("inference") as span:
with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
@ -558,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
with tracing.SpanContextManager(
with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
@ -707,7 +707,7 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in attachments
]
with tracing.SpanContextManager("insert_documents"):
with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self.storage.get_session_info(session_id)

View file

@ -81,9 +81,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs
)
with tracing.SpanContextManager(
f"{class_name}.{method_name}", span_attributes
) as span:
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
count = 0
async for item in method(self, *args, **kwargs):
@ -98,9 +96,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs
)
with tracing.SpanContextManager(
f"{class_name}.{method_name}", span_attributes
) as span:
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
@ -115,9 +111,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
self, *args, **kwargs
)
with tracing.SpanContextManager(
f"{class_name}.{method_name}", span_attributes
) as span:
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))

View file

@ -259,6 +259,10 @@ class SpanContextManager:
return wrapper
def span(name: str, attributes: Dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT