mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
fix(mypy): resolve agent_instance type issues (part 1 of 2)
Fixed 35 errors, 46 remaining: - Add isinstance() checks for union type discrimination - Fix list type annotations for Message types - Convert strings to datetime/StepType where needed - Use assert to narrow AgentTurnCreateRequest vs AgentTurnResumeRequest - Add explicit type annotations to avoid inference issues Still to fix: - Remaining str to datetime/StepType conversions - Optional list handling for shields - Type annotations for tool maps - List variance issues for input_messages - Fix turn_id variable redefinition
This commit is contained in:
parent
3a437d80af
commit
3cf36e665b
1 changed files with 32 additions and 16 deletions
|
|
@ -11,6 +11,7 @@ import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
||||||
messages = []
|
messages: list[Message] = []
|
||||||
|
|
||||||
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||||
tool_call_ids = set()
|
tool_call_ids = set()
|
||||||
for step in turn.steps:
|
for step in turn.steps:
|
||||||
if step.step_type == StepType.tool_execution.value:
|
if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
|
||||||
for response in step.tool_responses:
|
for response in step.tool_responses:
|
||||||
tool_call_ids.add(response.call_id)
|
tool_call_ids.add(response.call_id)
|
||||||
|
|
||||||
|
|
@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
for step in turn.steps:
|
for step in turn.steps:
|
||||||
if step.step_type == StepType.inference.value:
|
if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
|
||||||
messages.append(step.model_response)
|
messages.append(step.model_response)
|
||||||
elif step.step_type == StepType.tool_execution.value:
|
elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
|
||||||
for response in step.tool_responses:
|
for response in step.tool_responses:
|
||||||
messages.append(
|
messages.append(
|
||||||
ToolResponseMessage(
|
ToolResponseMessage(
|
||||||
|
|
@ -159,7 +160,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
content=response.content,
|
content=response.content,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif step.step_type == StepType.shield_call.value:
|
elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
|
||||||
if step.violation:
|
if step.violation:
|
||||||
# CompletionMessage itself in the ShieldResponse
|
# CompletionMessage itself in the ShieldResponse
|
||||||
messages.append(
|
messages.append(
|
||||||
|
|
@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||||
messages = []
|
messages: list[Message] = []
|
||||||
if self.agent_config.instructions != "":
|
if self.agent_config.instructions != "":
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
|
||||||
|
|
@ -231,7 +232,13 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
steps = []
|
steps = []
|
||||||
messages = await self.get_messages_from_turns(turns)
|
messages = await self.get_messages_from_turns(turns)
|
||||||
|
|
||||||
|
turn_id: str
|
||||||
|
start_time: datetime
|
||||||
|
input_messages: list[Message]
|
||||||
|
|
||||||
if is_resume:
|
if is_resume:
|
||||||
|
assert isinstance(request, AgentTurnResumeRequest), "Expected AgentTurnResumeRequest for resume"
|
||||||
tool_response_messages = [
|
tool_response_messages = [
|
||||||
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
||||||
]
|
]
|
||||||
|
|
@ -252,20 +259,21 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||||
request.session_id, request.turn_id
|
request.session_id, request.turn_id
|
||||||
)
|
)
|
||||||
now = datetime.now(UTC).isoformat()
|
now_iso = datetime.now(UTC).isoformat()
|
||||||
|
now_dt = datetime.now(UTC)
|
||||||
tool_execution_step = ToolExecutionStep(
|
tool_execution_step = ToolExecutionStep(
|
||||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||||
turn_id=request.turn_id,
|
turn_id=request.turn_id,
|
||||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||||
tool_responses=request.tool_responses,
|
tool_responses=request.tool_responses,
|
||||||
completed_at=now,
|
completed_at=now_dt,
|
||||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
|
||||||
)
|
)
|
||||||
steps.append(tool_execution_step)
|
steps.append(tool_execution_step)
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.tool_execution.value,
|
step_type=StepType.tool_execution,
|
||||||
step_id=tool_execution_step.step_id,
|
step_id=tool_execution_step.step_id,
|
||||||
step_details=tool_execution_step,
|
step_details=tool_execution_step,
|
||||||
)
|
)
|
||||||
|
|
@ -276,18 +284,22 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
turn_id = request.turn_id
|
turn_id = request.turn_id
|
||||||
start_time = last_turn.started_at
|
start_time = last_turn.started_at
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(request, AgentTurnCreateRequest), "Expected AgentTurnCreateRequest for create"
|
||||||
messages.extend(request.messages)
|
messages.extend(request.messages)
|
||||||
start_time = datetime.now(UTC).isoformat()
|
start_time = datetime.now(UTC)
|
||||||
input_messages = request.messages
|
input_messages = request.messages
|
||||||
|
|
||||||
output_message = None
|
output_message = None
|
||||||
|
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
|
||||||
|
req_sampling = self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
|
||||||
|
|
||||||
async for chunk in self.run(
|
async for chunk in self.run(
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
input_messages=messages,
|
input_messages=messages,
|
||||||
sampling_params=self.agent_config.sampling_params,
|
sampling_params=req_sampling,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
documents=request.documents if not is_resume else None,
|
documents=req_documents,
|
||||||
):
|
):
|
||||||
if isinstance(chunk, CompletionMessage):
|
if isinstance(chunk, CompletionMessage):
|
||||||
output_message = chunk
|
output_message = chunk
|
||||||
|
|
@ -295,8 +307,12 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||||
event = chunk.event
|
event = chunk.event
|
||||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
if (
|
||||||
steps.append(event.payload.step_details)
|
event.payload.event_type == AgentTurnResponseEventType.step_complete.value
|
||||||
|
and hasattr(event.payload, "step_details")
|
||||||
|
):
|
||||||
|
step_details = getattr(event.payload, "step_details")
|
||||||
|
steps.append(step_details)
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
@ -308,7 +324,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages=input_messages,
|
input_messages=input_messages,
|
||||||
output_message=output_message,
|
output_message=output_message,
|
||||||
started_at=start_time,
|
started_at=start_time,
|
||||||
completed_at=datetime.now(UTC).isoformat(),
|
completed_at=datetime.now(UTC),
|
||||||
steps=steps,
|
steps=steps,
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue