wip refactor

This commit is contained in:
Xi Yan 2025-03-04 13:38:04 -08:00
parent e9a37bad63
commit e8f7fa1ce1

View file

@ -12,7 +12,7 @@ import secrets
import string import string
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@ -31,7 +31,6 @@ from llama_stack.apis.agents import (
AgentTurnResponseStreamChunk, AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest, AgentTurnResumeRequest,
Attachment, Attachment,
Document, Document,
@ -184,26 +183,34 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
assert request.stream is True, "Non-streaming not supported" assert request.stream is True, "Non-streaming not supported"
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def _run_turn(
self,
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
turn_id: Optional[str] = None,
) -> AsyncGenerator:
is_resume = isinstance(request, AgentTurnResumeRequest)
session_info = await self.storage.get_session_info(request.session_id) session_info = await self.storage.get_session_info(request.session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {request.session_id} not found") raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id) turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns) if is_resume and len(turns) == 0:
messages.extend(request.messages) raise ValueError("No turns found for session")
turn_id = str(uuid.uuid4()) messages = await self.get_messages_from_turns(turns)
span.set_attribute("turn_id", turn_id) if is_resume:
messages.extend(request.tool_responses)
turn_id = request.turn_id
start_time = turns[-1].started_at
else:
messages.extend(request.messages)
start_time = datetime.now().astimezone().isoformat() start_time = datetime.now().astimezone().isoformat()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
steps = [] steps = []
output_message = None output_message = None
@ -213,14 +220,10 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=messages, input_messages=messages,
sampling_params=self.agent_config.sampling_params, sampling_params=self.agent_config.sampling_params,
stream=request.stream, stream=request.stream,
documents=request.documents, documents=request.documents if not is_resume else None,
toolgroups_for_turn=request.toolgroups, toolgroups_for_turn=request.toolgroups if not is_resume else None,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
logcat.info(
"agents",
f"returning result from the agent turn: {chunk}",
)
output_message = chunk output_message = chunk
continue continue