Xi Yan 2025-02-20 12:57:18 -08:00
parent 6b6feebc72
commit 7dae81cb68
3 changed files with 30 additions and 7 deletions


@@ -344,15 +344,15 @@ class Agents(Protocol):
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

     @webmethod(
-        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/submit_tool_response_messages",
+        route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/tool_responses",
         method="POST",
     )
-    async def submit_tool_response_messages(
+    async def submit_tool_responses(
         self,
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_response_messages: List[ToolResponseMessage],
+        tool_responses: Dict[str, ToolResponseMessage],
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

     @webmethod(
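The hunk above only changes the protocol surface: the route drops the verb-style suffix, the method becomes submit_tool_responses, and the payload becomes Dict[str, ToolResponseMessage] instead of a list. A minimal caller sketch, not part of the commit, assuming the dict is keyed by the tool call id being answered (that keying is an assumption; the diff only shows the type change) and that agents_impl is some object implementing this protocol:

# Illustrative sketch only, not part of the commit. Assumes `agents_impl` is an
# object implementing the Agents protocol above, and that the dict is keyed by
# the tool call id being answered (the keying convention is an assumption).
from typing import Dict

from llama_stack.apis.inference import ToolResponseMessage  # assumed import location


async def resume_turn(
    agents_impl,
    agent_id: str,
    session_id: str,
    turn_id: str,
    tool_responses: Dict[str, ToolResponseMessage],
):
    # Old shape: tool_response_messages=[...]; new shape: tool_responses={call_id: message}.
    return await agents_impl.submit_tool_responses(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn_id,
        tool_responses=tool_responses,
    )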


@@ -31,6 +31,7 @@ from llama_stack.apis.agents import (
     AgentTurnResponseStepStartPayload,
     AgentTurnResponseStreamChunk,
     AgentTurnResponseTurnCompletePayload,
+    AgentTurnResponseTurnPendingPayload,
     AgentTurnResponseTurnStartPayload,
     Attachment,
     Document,
@@ -62,7 +63,11 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
-from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
+from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
+    ToolCall,
+    ToolParamDefinition,
+)
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
@@ -222,6 +227,15 @@ class ChatAgent(ShieldRunnerMixin):
             )
             await self.storage.add_turn_to_session(request.session_id, turn)

-            chunk = AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseTurnCompletePayload(
+            if output_message.tool_calls:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnPendingPayload(
+                            turn=turn,
+                        )
+                    )
+                )
+            else:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnCompletePayload(
@@ -229,7 +243,8 @@ class ChatAgent(ShieldRunnerMixin):
-                        turn=turn,
-                    )
-                )
-            )
+                            turn=turn,
+                        )
+                    )
+                )
+
             yield chunk

     async def run(
         self,
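With this change the turn no longer unconditionally ends with a turn-complete event: when the output message carries tool calls, the stream ends with AgentTurnResponseTurnPendingPayload instead, signalling that the client still owes tool responses. A sketch, not part of the commit, of how a consumer of the stream might dispatch on the two payloads, assuming isinstance checks on the payload classes are an acceptable way to do so:

# Illustrative sketch only, not part of the commit. Assumes `stream` is the
# AsyncIterator[AgentTurnResponseStreamChunk] produced by the agent and that
# dispatching on the payload classes via isinstance is acceptable.
from llama_stack.apis.agents import (
    AgentTurnResponseTurnCompletePayload,
    AgentTurnResponseTurnPendingPayload,
)


async def wait_for_turn(stream):
    async for chunk in stream:
        payload = chunk.event.payload
        if isinstance(payload, AgentTurnResponseTurnPendingPayload):
            # New in this commit: the turn produced tool calls and is waiting
            # for the client to submit tool responses before it can finish.
            return "pending", payload.turn
        if isinstance(payload, AgentTurnResponseTurnCompletePayload):
            return "complete", payload.turn
    return None, None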


@@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 from llama_stack_client.types.tool_def_param import Parameter

-from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
-from llama_stack.apis.agents.agents import ToolChoice
+from llama_stack.apis.agents.agents import (
+    AgentConfig as Server__AgentConfig,
+)
+from llama_stack.apis.agents.agents import (
+    ToolChoice,
+)


 class TestClientTool(ClientTool):
@@ -314,6 +318,10 @@ def test_custom_tool(llama_stack_client, agent_config):
         ],
         session_id=session_id,
     )
+    from rich.pretty import pprint
+
+    for x in response:
+        pprint(x)

     logs = [str(log) for log in EventLogger().log(response) if log is not None]
     logs_str = "".join(logs)
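The added loop iterates response before EventLogger().log(response) runs on the same object. A variant sketch, not part of the commit and assuming response is a one-shot iterator of stream chunks, that collects the chunks once and reuses them for both the debug print and the event logger:

# Illustrative sketch only, not part of the commit. Assumes `response` is a
# one-shot iterator of stream chunks (as returned by agent.create_turn with
# streaming) and that EventLogger accepts any iterable of those chunks.
from rich.pretty import pprint

from llama_stack_client.lib.agents.event_logger import EventLogger


def debug_and_log(response) -> str:
    chunks = list(response)  # materialize once so both consumers see every chunk
    for chunk in chunks:
        pprint(chunk)
    logs = [str(log) for log in EventLogger().log(iter(chunks)) if log is not None]
    return "".join(logs)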