mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-13 05:17:26 +00:00
refactor
This commit is contained in:
parent
99bc54b033
commit
2c06704d63
1 changed files with 12 additions and 16 deletions
|
@ -157,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
|
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
||||||
|
messages = []
|
||||||
|
if self.agent_config.instructions != "":
|
||||||
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
|
||||||
|
for i, turn in enumerate(turns):
|
||||||
|
messages.extend(self.turn_to_messages(turn))
|
||||||
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||||
with tracing.span("create_and_execute_turn") as span:
|
with tracing.span("create_and_execute_turn") as span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
@ -169,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
messages = await self.get_messages_from_turns(turns)
|
||||||
messages = []
|
|
||||||
if self.agent_config.instructions != "":
|
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
|
||||||
|
|
||||||
for i, turn in enumerate(turns):
|
|
||||||
messages.extend(self.turn_to_messages(turn))
|
|
||||||
|
|
||||||
messages.extend(request.messages)
|
messages.extend(request.messages)
|
||||||
|
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
|
@ -260,15 +262,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
messages = await self.get_messages_from_turns(turns)
|
||||||
messages = []
|
|
||||||
if self.agent_config.instructions != "":
|
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
|
||||||
|
|
||||||
for i, turn in enumerate(turns):
|
|
||||||
messages.extend(self.turn_to_messages(turn))
|
|
||||||
|
|
||||||
messages.extend(request.tool_responses)
|
messages.extend(request.tool_responses)
|
||||||
|
|
||||||
last_turn_messages = [
|
last_turn_messages = [
|
||||||
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue