explicit span management using with

This commit is contained in:
Dinesh Yeduguru 2024-11-25 13:07:43 -08:00
parent 6411007024
commit b3021ea2da
3 changed files with 114 additions and 96 deletions

View file

@ -29,6 +29,11 @@ class Span(BaseModel):
end_time: Optional[datetime] = None
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
def set_attribute(self, key: str, value: Any):
if self.attributes is None:
self.attributes = {}
self.attributes[key] = value
@json_schema_type
class Trace(BaseModel):

View file

@ -144,10 +144,13 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
@tracing.span("create_and_execute_turn")
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
@ -273,7 +276,6 @@ class ChatAgent(ShieldRunnerMixin):
yield final_response
@tracing.span("run_shields")
async def run_multiple_shields_wrapper(
self,
turn_id: str,
@ -281,6 +283,9 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
with tracing.span("run_shields") as span:
span.set_attribute("turn_id", turn_id)
span.set_attribute("messages", [m.model_dump_json() for m in messages])
if len(shields) == 0:
return

View file

@ -253,3 +253,11 @@ class SpanContextManager:
def span(name: str, attributes: Dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
return context.get_current_span()
return None