Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)

explicit span management using with

Commit b3021ea2da (parent 6411007024)
3 changed files with 114 additions and 96 deletions
@@ -29,6 +29,11 @@ class Span(BaseModel):
     end_time: Optional[datetime] = None
     attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)

+    def set_attribute(self, key: str, value: Any):
+        if self.attributes is None:
+            self.attributes = {}
+        self.attributes[key] = value
+

 @json_schema_type
 class Trace(BaseModel):
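Span.set_attribute gives callers a safe way to attach data after a span object exists: it recreates the attributes dict if it has been set to None before writing the key. A minimal sketch of that behavior, assuming a trimmed-down Span model; the name and start_time fields are illustrative, only end_time and attributes appear in this hunk:

    from datetime import datetime
    from typing import Any, Dict, Optional

    from pydantic import BaseModel, Field

    class Span(BaseModel):
        # Trimmed-down sketch; the real model carries more fields.
        name: str
        start_time: Optional[datetime] = None
        end_time: Optional[datetime] = None
        attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)

        def set_attribute(self, key: str, value: Any) -> None:
            # attributes may have been explicitly set to None by a caller,
            # so re-create the dict before writing.
            if self.attributes is None:
                self.attributes = {}
            self.attributes[key] = value

    span = Span(name="create_and_execute_turn")
    span.set_attribute("session_id", "1234")
    assert span.attributes == {"session_id": "1234"}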
@@ -144,87 +144,90 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)

-    @tracing.span("create_and_execute_turn")
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
-        assert request.stream is True, "Non-streaming not supported"
-
-        session_info = await self.storage.get_session_info(request.session_id)
-        if session_info is None:
-            raise ValueError(f"Session {request.session_id} not found")
-
-        turns = await self.storage.get_session_turns(request.session_id)
-
-        messages = []
-        if self.agent_config.instructions != "":
-            messages.append(SystemMessage(content=self.agent_config.instructions))
-
-        for i, turn in enumerate(turns):
-            messages.extend(self.turn_to_messages(turn))
-
-        messages.extend(request.messages)
-
-        turn_id = str(uuid.uuid4())
-        start_time = datetime.now()
-        yield AgentTurnResponseStreamChunk(
-            event=AgentTurnResponseEvent(
-                payload=AgentTurnResponseTurnStartPayload(
-                    turn_id=turn_id,
-                )
-            )
-        )
-
-        steps = []
-        output_message = None
-        async for chunk in self.run(
-            session_id=request.session_id,
-            turn_id=turn_id,
-            input_messages=messages,
-            attachments=request.attachments or [],
-            sampling_params=self.agent_config.sampling_params,
-            stream=request.stream,
-        ):
-            if isinstance(chunk, CompletionMessage):
-                log.info(
-                    f"{chunk.role.capitalize()}: {chunk.content}",
-                )
-                output_message = chunk
-                continue
-
-            assert isinstance(
-                chunk, AgentTurnResponseStreamChunk
-            ), f"Unexpected type {type(chunk)}"
-            event = chunk.event
-            if (
-                event.payload.event_type
-                == AgentTurnResponseEventType.step_complete.value
-            ):
-                steps.append(event.payload.step_details)
-
-            yield chunk
-
-        assert output_message is not None
-
-        turn = Turn(
-            turn_id=turn_id,
-            session_id=request.session_id,
-            input_messages=request.messages,
-            output_message=output_message,
-            started_at=start_time,
-            completed_at=datetime.now(),
-            steps=steps,
-        )
-        await self.storage.add_turn_to_session(request.session_id, turn)
-
-        chunk = AgentTurnResponseStreamChunk(
-            event=AgentTurnResponseEvent(
-                payload=AgentTurnResponseTurnCompletePayload(
-                    turn=turn,
-                )
-            )
-        )
-        yield chunk
+        with tracing.span("create_and_execute_turn") as span:
+            span.set_attribute("session_id", request.session_id)
+            span.set_attribute("agent_id", self.agent_id)
+            span.set_attribute("request", request.model_dump_json())
+            assert request.stream is True, "Non-streaming not supported"
+
+            session_info = await self.storage.get_session_info(request.session_id)
+            if session_info is None:
+                raise ValueError(f"Session {request.session_id} not found")
+
+            turns = await self.storage.get_session_turns(request.session_id)
+
+            messages = []
+            if self.agent_config.instructions != "":
+                messages.append(SystemMessage(content=self.agent_config.instructions))
+
+            for i, turn in enumerate(turns):
+                messages.extend(self.turn_to_messages(turn))
+
+            messages.extend(request.messages)
+
+            turn_id = str(uuid.uuid4())
+            start_time = datetime.now()
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseTurnStartPayload(
+                        turn_id=turn_id,
+                    )
+                )
+            )
+
+            steps = []
+            output_message = None
+            async for chunk in self.run(
+                session_id=request.session_id,
+                turn_id=turn_id,
+                input_messages=messages,
+                attachments=request.attachments or [],
+                sampling_params=self.agent_config.sampling_params,
+                stream=request.stream,
+            ):
+                if isinstance(chunk, CompletionMessage):
+                    log.info(
+                        f"{chunk.role.capitalize()}: {chunk.content}",
+                    )
+                    output_message = chunk
+                    continue
+
+                assert isinstance(
+                    chunk, AgentTurnResponseStreamChunk
+                ), f"Unexpected type {type(chunk)}"
+                event = chunk.event
+                if (
+                    event.payload.event_type
+                    == AgentTurnResponseEventType.step_complete.value
+                ):
+                    steps.append(event.payload.step_details)
+
+                yield chunk
+
+            assert output_message is not None
+
+            turn = Turn(
+                turn_id=turn_id,
+                session_id=request.session_id,
+                input_messages=request.messages,
+                output_message=output_message,
+                started_at=start_time,
+                completed_at=datetime.now(),
+                steps=steps,
+            )
+            await self.storage.add_turn_to_session(request.session_id, turn)
+
+            chunk = AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseTurnCompletePayload(
+                        turn=turn,
+                    )
+                )
+            )
+            yield chunk

     async def run(
         self,
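Replacing the @tracing.span decorator with an explicit with tracing.span(...) as span: block hands the span object to the generator body, which is what allows session_id, agent_id, and the serialized request to be recorded as attributes, and it makes the span's start and end explicit inside an async generator. A minimal, self-contained sketch of that pattern; ToySpan and toy_span are stand-ins for illustration, not the repository's tracing implementation:

    import asyncio
    import contextlib
    import time

    class ToySpan:
        def __init__(self, name: str):
            self.name = name
            self.attributes = {}

        def set_attribute(self, key, value):
            self.attributes[key] = value

    @contextlib.contextmanager
    def toy_span(name: str):
        # Stand-in for tracing.span(): yields a span and reports duration on exit.
        span = ToySpan(name)
        start = time.perf_counter()
        try:
            yield span
        finally:
            print(f"{name} {span.attributes} took {time.perf_counter() - start:.3f}s")

    async def stream_turn(session_id: str):
        # The span opens and closes inside the generator body, so it covers the
        # full streaming loop and exposes the span object for set_attribute().
        with toy_span("create_and_execute_turn") as span:
            span.set_attribute("session_id", session_id)
            for i in range(3):
                await asyncio.sleep(0.01)
                yield i

    async def main():
        async for _chunk in stream_turn("1234"):
            pass

    asyncio.run(main())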
@@ -273,7 +276,6 @@ class ChatAgent(ShieldRunnerMixin):
         yield final_response

-    @tracing.span("run_shields")
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
@@ -281,23 +283,45 @@ class ChatAgent(ShieldRunnerMixin):
         shields: List[str],
         touchpoint: str,
     ) -> AsyncGenerator:
-        if len(shields) == 0:
-            return
-
-        step_id = str(uuid.uuid4())
-        try:
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.shield_call.value,
-                        step_id=step_id,
-                        metadata=dict(touchpoint=touchpoint),
-                    )
-                )
-            )
-            await self.run_multiple_shields(messages, shields)
-
-        except SafetyException as e:
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepCompletePayload(
+        with tracing.span("run_shields") as span:
+            span.set_attribute("turn_id", turn_id)
+            span.set_attribute("messages", [m.model_dump_json() for m in messages])
+            if len(shields) == 0:
+                return
+
+            step_id = str(uuid.uuid4())
+            try:
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepStartPayload(
+                            step_type=StepType.shield_call.value,
+                            step_id=step_id,
+                            metadata=dict(touchpoint=touchpoint),
+                        )
+                    )
+                )
+                await self.run_multiple_shields(messages, shields)
+
+            except SafetyException as e:
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepCompletePayload(
+                            step_type=StepType.shield_call.value,
+                            step_details=ShieldCallStep(
+                                step_id=step_id,
+                                turn_id=turn_id,
+                                violation=e.violation,
+                            ),
+                        )
+                    )
+                )
+
+                yield CompletionMessage(
+                    content=str(e),
+                    stop_reason=StopReason.end_of_turn,
+                )
+                yield False
+
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
@@ -305,31 +329,12 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            violation=e.violation,
+                            violation=None,
                         ),
                     )
                 )
             )
-
-            yield CompletionMessage(
-                content=str(e),
-                stop_reason=StopReason.end_of_turn,
-            )
-            yield False
-
-        yield AgentTurnResponseStreamChunk(
-            event=AgentTurnResponseEvent(
-                payload=AgentTurnResponseStepCompletePayload(
-                    step_type=StepType.shield_call.value,
-                    step_details=ShieldCallStep(
-                        step_id=step_id,
-                        turn_id=turn_id,
-                        violation=None,
-                    ),
-                )
-            )
-        )

     async def _run(
         self,
         session_id: str,
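The run_shields span records the turn_id plus every input message serialized to JSON, one string per message, via model_dump_json(). A small illustration of that serialization choice, using a hypothetical UserMessage stand-in rather than the repository's message types:

    from typing import List
    from pydantic import BaseModel

    class UserMessage(BaseModel):
        # Illustrative stand-in for the real message types.
        role: str = "user"
        content: str

    def serialize_for_span(messages: List[BaseModel]) -> List[str]:
        # Mirrors the attribute written above: one JSON string per message,
        # so the span attribute stays a flat list of strings.
        return [m.model_dump_json() for m in messages]

    msgs = [UserMessage(content="hello"), UserMessage(content="run the shields")]
    print(serialize_for_span(msgs))

Note also that the early return for the no-shield case now sits inside the with block, so, assuming the context manager ends the span in __exit__, even a no-op shield check leaves a span in the trace.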
@@ -253,3 +253,11 @@ class SpanContextManager:

 def span(name: str, attributes: Dict[str, Any] = None):
     return SpanContextManager(name, attributes)
+
+
+def get_current_span() -> Optional[Span]:
+    global CURRENT_TRACE_CONTEXT
+    context = CURRENT_TRACE_CONTEXT
+    if context:
+        return context.get_current_span()
+    return None
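This hunk adds a module-level get_current_span() accessor next to the existing span() factory, but the bodies of SpanContextManager and CURRENT_TRACE_CONTEXT are not shown here. The sketch below is only an assumption about how such a pair could fit together; TraceContext, SketchSpan, push_span, and pop_span are invented for illustration, and the real implementation in llama_stack may differ:

    from typing import Any, Dict, List, Optional

    class SketchSpan:
        def __init__(self, name: str, attributes: Optional[Dict[str, Any]] = None):
            self.name = name
            self.attributes = attributes or {}

        def set_attribute(self, key: str, value: Any) -> None:
            self.attributes[key] = value

    class TraceContext:
        # Assumed shape of the trace context: a simple stack of open spans.
        def __init__(self) -> None:
            self.spans: List[SketchSpan] = []

        def push_span(self, span: SketchSpan) -> None:
            self.spans.append(span)

        def pop_span(self) -> None:
            self.spans.pop()

        def get_current_span(self) -> Optional[SketchSpan]:
            return self.spans[-1] if self.spans else None

    CURRENT_TRACE_CONTEXT: Optional[TraceContext] = TraceContext()

    class SpanContextManager:
        # Sketch only: opens a span on __enter__, closes it on __exit__.
        def __init__(self, name: str, attributes: Optional[Dict[str, Any]] = None):
            self.span = SketchSpan(name, attributes)

        def __enter__(self) -> SketchSpan:
            CURRENT_TRACE_CONTEXT.push_span(self.span)
            return self.span

        def __exit__(self, exc_type, exc, tb) -> None:
            CURRENT_TRACE_CONTEXT.pop_span()

    def span(name: str, attributes: Dict[str, Any] = None):
        return SpanContextManager(name, attributes)

    def get_current_span() -> Optional[SketchSpan]:
        context = CURRENT_TRACE_CONTEXT
        if context:
            return context.get_current_span()
        return None

    with span("create_and_execute_turn") as s:
        s.set_attribute("session_id", "1234")
        # Nested code can reach the active span without threading it through.
        assert get_current_span() is s

The point of get_current_span() is exactly that last line: code running inside a with tracing.span(...) block can attach attributes to the active span without the span object being passed down explicitly.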