mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
explicit span management using with
This commit is contained in:
parent
6411007024
commit
b3021ea2da
3 changed files with 114 additions and 96 deletions
|
@ -29,6 +29,11 @@ class Span(BaseModel):
|
|||
end_time: Optional[datetime] = None
|
||||
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.attributes is None:
|
||||
self.attributes = {}
|
||||
self.attributes[key] = value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Trace(BaseModel):
|
||||
|
|
|
@ -144,87 +144,90 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
@tracing.span("create_and_execute_turn")
|
||||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
|
||||
messages.extend(request.messages)
|
||||
messages.extend(request.messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
start_time = datetime.now()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
turn_id = str(uuid.uuid4())
|
||||
start_time = datetime.now()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
assert isinstance(
|
||||
chunk, AgentTurnResponseStreamChunk
|
||||
), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if (
|
||||
event.payload.event_type
|
||||
== AgentTurnResponseEventType.step_complete.value
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
steps.append(event.payload.step_details)
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
assert isinstance(
|
||||
chunk, AgentTurnResponseStreamChunk
|
||||
), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if (
|
||||
event.payload.event_type
|
||||
== AgentTurnResponseEventType.step_complete.value
|
||||
):
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
assert output_message is not None
|
||||
yield chunk
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
assert output_message is not None
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
yield chunk
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
@ -273,7 +276,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
yield final_response
|
||||
|
||||
@tracing.span("run_shields")
|
||||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
|
@ -281,23 +283,45 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
if len(shields) == 0:
|
||||
return
|
||||
with tracing.span("run_shields") as span:
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
span.set_attribute("messages", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
step_id = str(uuid.uuid4())
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
await self.run_multiple_shields(messages, shields)
|
||||
await self.run_multiple_shields(messages, shields)
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
|
@ -305,31 +329,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
|
@ -253,3 +253,11 @@ class SpanContextManager:
|
|||
|
||||
def span(name: str, attributes: Dict[str, Any] = None):
|
||||
return SpanContextManager(name, attributes)
|
||||
|
||||
|
||||
def get_current_span() -> Optional[Span]:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
return context.get_current_span()
|
||||
return None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue