chore: deprecate ToolResponseMessage in agent.resume API (#1566)

# Summary:
closes #1431 

# Test Plan:
LLAMA_STACK_CONFIG=fireworks pytest -s -v
tests/integration/agents/test_agents.py --safety-shield
meta-llama/Llama-Guard-3-8B --text-model
meta-llama/Llama-3.1-8B-Instruct
This commit is contained in:
ehhuang 2025-03-12 12:10:21 -07:00 committed by GitHub
parent 58d08d100e
commit b7a9c45477
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 17 additions and 41 deletions

View file

@ -9490,21 +9490,11 @@
"type": "object", "type": "object",
"properties": { "properties": {
"tool_responses": { "tool_responses": {
"oneOf": [
{
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/ToolResponse" "$ref": "#/components/schemas/ToolResponse"
}
}, },
{ "description": "The tool call responses to resume the turn with."
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolResponseMessage"
}
}
],
"description": "The tool call responses to resume the turn with. NOTE: ToolResponseMessage will be deprecated. Use ToolResponse."
}, },
"stream": { "stream": {
"type": "boolean", "type": "boolean",

View file

@ -6405,16 +6405,11 @@ components:
type: object type: object
properties: properties:
tool_responses: tool_responses:
oneOf: type: array
- type: array
items: items:
$ref: '#/components/schemas/ToolResponse' $ref: '#/components/schemas/ToolResponse'
- type: array
items:
$ref: '#/components/schemas/ToolResponseMessage'
description: >- description: >-
The tool call responses to resume the turn with. NOTE: ToolResponseMessage The tool call responses to resume the turn with.
will be deprecated. Use ToolResponse.
stream: stream:
type: boolean type: boolean
description: Whether to stream the response. description: Whether to stream the response.

View file

@ -370,7 +370,7 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str agent_id: str
session_id: str session_id: str
turn_id: str turn_id: str
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]] tool_responses: List[ToolResponse]
stream: Optional[bool] = False stream: Optional[bool] = False
@ -449,7 +449,7 @@ class Agents(Protocol):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], tool_responses: List[ToolResponse],
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses. """Resume an agent turn with executed tool call responses.
@ -460,7 +460,6 @@ class Agents(Protocol):
:param session_id: The ID of the session to resume. :param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume. :param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with. :param tool_responses: The tool call responses to resume the turn with.
NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
:param stream: Whether to stream the response. :param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
""" """

View file

@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin):
steps = [] steps = []
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
if is_resume: if is_resume:
if isinstance(request.tool_responses[0], ToolResponseMessage):
tool_response_messages = request.tool_responses
tool_responses = [
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
else:
tool_response_messages = [ tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content) ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses for x in request.tool_responses
] ]
tool_responses = request.tool_responses
messages.extend(tool_response_messages) messages.extend(tool_response_messages)
last_turn = turns[-1] last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn) last_turn_messages = self.turn_to_messages(last_turn)
@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id, turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=tool_responses, tool_responses=request.tool_responses,
completed_at=now, completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
) )

View file

@ -172,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]], tool_responses: List[ToolResponse],
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = AgentTurnResumeRequest( request = AgentTurnResumeRequest(