From 0fe071764f8902aa41487e77df5faf292d47ba5f Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 21 Feb 2025 11:48:27 -0800 Subject: [PATCH] feat(1/n): api: unify agents for handling server & client tools (#1178) # Problem Our current Agent framework has discrepancies in definition on how we handle server side and client side tools. 1. Server Tools: a single Turn is returned including `ToolExecutionStep` in agents 2. Client Tools: `create_agent_turn` is called in a loop with client agent lib yielding the agent chunk https://github.com/meta-llama/llama-stack-client-python/blob/ad6ffc63df658674f275267b1befc2b7046dbf33/src/llama_stack_client/lib/agents/agent.py#L186-L211 This makes it inconsistent to work with server & client tools. It also complicates the logs to telemetry to get information about agent turns / history for observability. #### Principle The same `turn_id` should be used to represent the steps required to complete a user message including client tools. ## Solution 1. `AgentTurnResponseEventType.turn_awaiting_input` status to indicate that the current turn is not completed, and awaiting tool input 2. `resume_agent_turn` endpoint to update agent turn with client's tool response. # What does this PR do? - Skeleton API as example ## Test Plan [Describe the tests you ran to verify your changes with result summaries. 
*Provide clear instructions so the plan can be easily re-executed.*] - Just API update, no functionality change ``` llama stack run + client-sdk test ``` image [//]: # (## Documentation) --- docs/_static/llama-stack-spec.html | 113 ++++++++++- docs/_static/llama-stack-spec.yaml | 82 ++++++++ llama_stack/apis/agents/agents.py | 48 +++++ .../agents/meta_reference/agent_instance.py | 179 ++++++++++++++++-- .../inline/agents/meta_reference/agents.py | 31 +++ .../agents/meta_reference/persistence.py | 14 +- tests/client-sdk/agents/test_agents.py | 8 +- 7 files changed, 454 insertions(+), 21 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index fab7c802e..ce08e041f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -2315,6 +2315,70 @@ } } }, + "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": { + "post": { + "responses": { + "200": { + "description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Turn" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/AgentTurnResponseStreamChunk" + } + } + } + } + }, + "tags": [ + "Agents" + ], + "description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.", + "parameters": [ + { + "name": "agent_id", + "in": "path", + "description": "The ID of the agent to resume.", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "session_id", + "in": "path", + "description": "The ID of the session to resume.", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "turn_id", + "in": "path", + "description": "The ID of the turn to 
resume.", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ResumeAgentTurnRequest" + } + } + }, + "required": true + } + } + }, "/v1/eval/benchmarks/{benchmark_id}/jobs": { "post": { "responses": { @@ -4226,6 +4290,9 @@ }, "tool_config": { "$ref": "#/components/schemas/ToolConfig" + }, + "allow_turn_resume": { + "type": "boolean" } }, "additionalProperties": false, @@ -4612,6 +4679,9 @@ }, { "$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload" + }, + { + "$ref": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload" } ], "discriminator": { @@ -4621,7 +4691,8 @@ "step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload", "step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload", "turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload", - "turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload" + "turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload", + "turn_awaiting_input": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload" } } }, @@ -4784,6 +4855,25 @@ "title": "AgentTurnResponseStreamChunk", "description": "streamed agent turn completion response." 
}, + "AgentTurnResponseTurnAwaitingInputPayload": { + "type": "object", + "properties": { + "event_type": { + "type": "string", + "const": "turn_awaiting_input", + "default": "turn_awaiting_input" + }, + "turn": { + "$ref": "#/components/schemas/Turn" + } + }, + "additionalProperties": false, + "required": [ + "event_type", + "turn" + ], + "title": "AgentTurnResponseTurnAwaitingInputPayload" + }, "AgentTurnResponseTurnCompletePayload": { "type": "object", "properties": { @@ -8046,6 +8136,27 @@ ], "title": "RegisterVectorDbRequest" }, + "ResumeAgentTurnRequest": { + "type": "object", + "properties": { + "tool_responses": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolResponseMessage" + }, + "description": "The tool call responses to resume the turn with." + }, + "stream": { + "type": "boolean", + "description": "Whether to stream the response." + } + }, + "additionalProperties": false, + "required": [ + "tool_responses" + ], + "title": "ResumeAgentTurnRequest" + }, "RunEvalRequest": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index fc57bf258..0e4955a5c 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1401,6 +1401,53 @@ paths: schema: $ref: '#/components/schemas/QueryTracesRequest' required: true + /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume: + post: + responses: + '200': + description: >- + A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk + objects. + content: + application/json: + schema: + $ref: '#/components/schemas/Turn' + text/event-stream: + schema: + $ref: '#/components/schemas/AgentTurnResponseStreamChunk' + tags: + - Agents + description: >- + Resume an agent turn with executed tool call responses. 
+ + When a Turn has the status `awaiting_input` due to pending input from client + side tool calls, this endpoint can be used to submit the outputs from the + tool calls once they are ready. + parameters: + - name: agent_id + in: path + description: The ID of the agent to resume. + required: true + schema: + type: string + - name: session_id + in: path + description: The ID of the session to resume. + required: true + schema: + type: string + - name: turn_id + in: path + description: The ID of the turn to resume. + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ResumeAgentTurnRequest' + required: true /v1/eval/benchmarks/{benchmark_id}/jobs: post: responses: @@ -2740,6 +2787,8 @@ components: $ref: '#/components/schemas/AgentTool' tool_config: $ref: '#/components/schemas/ToolConfig' + allow_turn_resume: + type: boolean additionalProperties: false required: - messages @@ -2992,6 +3041,7 @@ components: - $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload' - $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload' - $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload' + - $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload' discriminator: propertyName: event_type mapping: @@ -3000,6 +3050,7 @@ components: step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload' turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload' turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload' + turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload' AgentTurnResponseStepCompletePayload: type: object properties: @@ -3106,6 +3157,21 @@ components: - event title: AgentTurnResponseStreamChunk description: streamed agent turn completion response. 
+ "AgentTurnResponseTurnAwaitingInputPayload": + type: object + properties: + event_type: + type: string + const: turn_awaiting_input + default: turn_awaiting_input + turn: + $ref: '#/components/schemas/Turn' + additionalProperties: false + required: + - event_type + - turn + title: >- + AgentTurnResponseTurnAwaitingInputPayload AgentTurnResponseTurnCompletePayload: type: object properties: @@ -5205,6 +5271,22 @@ components: - vector_db_id - embedding_model title: RegisterVectorDbRequest + ResumeAgentTurnRequest: + type: object + properties: + tool_responses: + type: array + items: + $ref: '#/components/schemas/ToolResponseMessage' + description: >- + The tool call responses to resume the turn with. + stream: + type: boolean + description: Whether to stream the response. + additionalProperties: false + required: + - tool_responses + title: ResumeAgentTurnRequest RunEvalRequest: type: object properties: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 367648ded..c904fdbef 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum): turn_start = "turn_start" turn_complete = "turn_complete" + turn_awaiting_input = "turn_awaiting_input" @json_schema_type @@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel): turn: Turn +@json_schema_type +class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): + event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = ( + AgentTurnResponseEventType.turn_awaiting_input.value + ) + turn: Turn + + AgentTurnResponseEventPayload = register_schema( Annotated[ Union[ @@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema( AgentTurnResponseStepCompletePayload, AgentTurnResponseTurnStartPayload, AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnAwaitingInputPayload, ], Field(discriminator="event_type"), ], @@ -286,6 +296,18 @@ class 
AgentTurnCreateRequest(AgentConfigOverridablePerTurn): stream: Optional[bool] = False tool_config: Optional[ToolConfig] = None + # TODO (xiyan): temporary flag, will remove for 0.1.5 + allow_turn_resume: Optional[bool] = False + + +@json_schema_type +class AgentTurnResumeRequest(BaseModel): + agent_id: str + session_id: str + turn_id: str + tool_responses: List[ToolResponseMessage] + stream: Optional[bool] = False + @json_schema_type class AgentTurnResponseStreamChunk(BaseModel): @@ -333,8 +355,34 @@ class Agents(Protocol): documents: Optional[List[Document]] = None, toolgroups: Optional[List[AgentToolGroup]] = None, tool_config: Optional[ToolConfig] = None, + allow_turn_resume: Optional[bool] = False, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... + @webmethod( + route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", + method="POST", + ) + async def resume_agent_turn( + self, + agent_id: str, + session_id: str, + turn_id: str, + tool_responses: List[ToolResponseMessage], + stream: Optional[bool] = False, + ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: + """Resume an agent turn with executed tool call responses. + + When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready. + + :param agent_id: The ID of the agent to resume. + :param session_id: The ID of the session to resume. + :param turn_id: The ID of the turn to resume. + :param tool_responses: The tool call responses to resume the turn with. + :param stream: Whether to stream the response. + :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. + """ + ... 
+ @webmethod( route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET", diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1c21df57f..edd253356 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -30,8 +30,10 @@ from llama_stack.apis.agents import ( AgentTurnResponseStepProgressPayload, AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, + AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, + AgentTurnResumeRequest, Attachment, Document, InferenceStep, @@ -62,7 +64,11 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO -from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + ToolCall, + ToolParamDefinition, +) from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing @@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: return await self.storage.create_session(name) + async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]: + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) + + for turn in turns: + messages.extend(self.turn_to_messages(turn)) + return messages + async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: with tracing.span("create_and_execute_turn") as 
span: span.set_attribute("session_id", request.session_id) @@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin): raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) - - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) - - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) - + messages = await self.get_messages_from_turns(turns) messages.extend(request.messages) turn_id = str(uuid.uuid4()) @@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin): ) await self.storage.add_turn_to_session(request.session_id, turn) - chunk = AgentTurnResponseStreamChunk( + if output_message.tool_calls and request.allow_turn_resume: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + + yield chunk + + async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: + with tracing.span("resume_turn") as span: + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("session_id", request.session_id) + span.set_attribute("turn_id", request.turn_id) + span.set_attribute("request", request.model_dump_json()) + assert request.stream is True, "Non-streaming not supported" + + session_info = await self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") + + turns = await self.storage.get_session_turns(request.session_id) + messages = await self.get_messages_from_turns(turns) + messages.extend(request.tool_responses) + + last_turn_messages = [ + x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + ] + + # get the steps 
from the turn id + steps = [] + if len(turns) > 0: + steps = turns[-1].steps + + # mark tool execution step as complete + # if there's no tool execution in progress step (due to storage, or tool call parsing on client), + # we'll create a new tool execution step with current time + in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( + request.session_id, request.turn_id + ) + now = datetime.now() + tool_execution_step = ToolExecutionStep( + step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), + turn_id=request.turn_id, + tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), + tool_responses=[ + ToolResponse( + call_id=x.call_id, + tool_name=x.tool_name, + content=x.content, + ) + for x in request.tool_responses + ], + completed_at=now, + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + ) + steps.append(tool_execution_step) + yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=tool_execution_step.step_id, + step_details=tool_execution_step, ) ) ) + + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=request.turn_id, + input_messages=messages, + sampling_params=self.agent_config.sampling_params, + stream=request.stream, + ): + if isinstance(chunk, CompletionMessage): + output_message = chunk + continue + + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + event = chunk.event + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + steps.append(event.payload.step_details) + + yield chunk + + assert output_message is not None + + last_turn_start_time = datetime.now() + if len(turns) > 0: + last_turn_start_time = turns[-1].started_at + 
+ turn = Turn( + turn_id=request.turn_id, + session_id=request.session_id, + input_messages=last_turn_messages, + output_message=output_message, + started_at=last_turn_start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + yield chunk async def run( @@ -611,11 +742,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages = input_messages + [message] else: log.info(f"{str(message)}") - tool_call = message.tool_calls[0] - if tool_call.tool_name in client_tools: - yield message - return - + # 1. Start the tool execution step and progress step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -625,6 +752,7 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) + tool_call = message.tool_calls[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -639,6 +767,23 @@ class ChatAgent(ShieldRunnerMixin): ) ) + # If tool is a client tool, yield CompletionMessage and return + if tool_call.tool_name in client_tools: + await self.storage.set_in_progress_tool_call_step( + session_id, + turn_id, + ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[tool_call], + tool_responses=[], + started_at=datetime.now(), + ), + ) + yield message + return + + # If tool is a builtin server tool, execute it tool_name = tool_call.tool_name if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index e3c18d112..8a4d91238 100644 --- 
a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -21,6 +21,7 @@ from llama_stack.apis.agents import ( AgentStepResponse, AgentToolGroup, AgentTurnCreateRequest, + AgentTurnResumeRequest, Document, Session, Turn, @@ -146,6 +147,7 @@ class MetaReferenceAgentsImpl(Agents): documents: Optional[List[Document]] = None, stream: Optional[bool] = False, tool_config: Optional[ToolConfig] = None, + allow_turn_resume: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, @@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents): toolgroups=toolgroups, documents=documents, tool_config=tool_config, + allow_turn_resume=allow_turn_resume, ) if stream: return self._create_agent_turn_streaming(request) @@ -169,6 +172,34 @@ class MetaReferenceAgentsImpl(Agents): async for event in agent.create_and_execute_turn(request): yield event + async def resume_agent_turn( + self, + agent_id: str, + session_id: str, + turn_id: str, + tool_responses: List[ToolResponseMessage], + stream: Optional[bool] = False, + ) -> AsyncGenerator: + request = AgentTurnResumeRequest( + agent_id=agent_id, + session_id=session_id, + turn_id=turn_id, + tool_responses=tool_responses, + stream=stream, + ) + if stream: + return self._continue_agent_turn_streaming(request) + else: + raise NotImplementedError("Non-streaming agent turns not yet implemented") + + async def _continue_agent_turn_streaming( + self, + request: AgentTurnResumeRequest, + ) -> AsyncGenerator: + agent = await self.get_agent(request.agent_id) + async for event in agent.resume_turn(request): + yield event + async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") turn = json.loads(turn) diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py 
b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 4b8ad6d4a..3c3866873 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -12,7 +12,7 @@ from typing import List, Optional from pydantic import BaseModel -from llama_stack.apis.agents import Turn +from llama_stack.apis.agents import ToolExecutionStep, Turn from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -84,3 +84,15 @@ class AgentPersistence: continue turns.sort(key=lambda x: (x.completed_at or datetime.min)) return turns + + async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): + await self.kvstore.set( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + value=step.model_dump_json(), + ) + + async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: + value = await self.kvstore.get( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + ) + return ToolExecutionStep(**json.loads(value)) if value else None diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index e5380d357..781095d2b 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.tool_def_param import Parameter -from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig -from llama_stack.apis.agents.agents import ToolChoice +from llama_stack.apis.agents.agents import ( + AgentConfig as Server__AgentConfig, +) +from llama_stack.apis.agents.agents import ( + ToolChoice, +) class TestClientTool(ClientTool):