mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
fix: resolve mypy errors in meta_reference agents
- Fix type annotations for Safety API message parameters - Add Action enum usage in access control checks - Correct method signatures to match API supertype - Handle optional return types with proper None checks - Remove meta_reference exclusions from mypy config
This commit is contained in:
parent
b90c6a2c8b
commit
f0d365622a
5 changed files with 48 additions and 26 deletions
|
|
@ -284,12 +284,6 @@ exclude = [
|
|||
"^src/llama_stack/models/llama/llama3/interface\\.py$",
|
||||
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||
"^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
"^src/llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||
"^src/llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||
"^src/llama_stack/providers/inline/agents/meta_reference/config\\.py$",
|
||||
"^src/llama_stack/providers/inline/agents/meta_reference/persistence\\.py$",
|
||||
"^src/llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
|
||||
"^src/llama_stack/providers/inline/agents/meta_reference/__init__\\.py$",
|
||||
"^src/llama_stack/providers/inline/datasetio/localfs/",
|
||||
"^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||
"^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import uuid
|
|||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -363,7 +363,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if self.input_shields:
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, input_messages, self.input_shields, "user-input"
|
||||
turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -392,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if self.output_shields:
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, messages, self.output_shields, "assistant-output"
|
||||
turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -404,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
shields: list[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
|||
Document,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
|
|
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
persistence_store=(
|
||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at,
|
||||
created_at=agent_info.created_at.isoformat(),
|
||||
policy=self.policy,
|
||||
telemetry_enabled=self.telemetry_enabled,
|
||||
)
|
||||
|
|
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
documents: list[Document] | None = None,
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
|
|
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
if turn is None:
|
||||
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
|
||||
return turn
|
||||
|
||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||
|
|
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def get_agents_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
turns = await agent.storage.get_session_turns(session_id)
|
||||
if turn_ids:
|
||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||
|
|
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
started_at=session_info.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
# Delete turns first, then the session
|
||||
|
|
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=chat_agent.agent_config,
|
||||
created_at=chat_agent.created_at,
|
||||
created_at=datetime.fromisoformat(chat_agent.created_at),
|
||||
)
|
||||
return agent
|
||||
|
||||
|
|
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.get_openai_response(response_id)
|
||||
|
||||
async def create_openai_response(
|
||||
|
|
@ -342,7 +348,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrail] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from typing import cast as typing_cast
|
||||
# Cast guardrails to the more specific type expected by the implementation
|
||||
guardrails_spec = typing_cast(list[ResponseGuardrailSpec] | None, guardrails)
|
||||
result = await self.openai_responses_impl.create_openai_response(
|
||||
input,
|
||||
model,
|
||||
prompt,
|
||||
|
|
@ -356,8 +367,9 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tools,
|
||||
include,
|
||||
max_infer_iters,
|
||||
guardrails,
|
||||
guardrails_spec,
|
||||
)
|
||||
return typing_cast(OpenAIResponseObject, result)
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
|
|
@ -366,6 +378,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
|
|
@ -377,9 +390,11 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.list_openai_response_input_items(
|
||||
response_id, after, before, include, limit, order
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> None:
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.delete_openai_response(response_id)
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@
|
|||
import json
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import cast
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Action
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -53,8 +54,12 @@ class AgentPersistence:
|
|||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
if not is_action_allowed(self.policy, "create", session_info, user):
|
||||
raise AccessDeniedError("create", session_info, user)
|
||||
# Both identifier and owner are set above, safe to use for access control
|
||||
assert session_info.identifier is not None and session_info.owner is not None
|
||||
from llama_stack.core.access_control.conditions import ProtectedResource
|
||||
resource = cast(ProtectedResource, session_info)
|
||||
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
|
||||
raise AccessDeniedError(Action.CREATE, resource, user)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
|
@ -62,7 +67,7 @@ class AgentPersistence:
|
|||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
|
|
@ -83,7 +88,15 @@ class AgentPersistence:
|
|||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
return True
|
||||
|
||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||
# Access control requires identifier and owner to be set
|
||||
if session_info.identifier is None or session_info.owner is None:
|
||||
return True
|
||||
|
||||
# At this point, both identifier and owner are guaranteed to be non-None
|
||||
assert session_info.identifier is not None and session_info.owner is not None
|
||||
from llama_stack.core.access_control.conditions import ProtectedResource
|
||||
resource = cast(ProtectedResource, session_info)
|
||||
return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
|
|||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue