mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore(mypy): part-04 resolve mypy errors in meta_reference agents (#3969)
## Summary Fixes all mypy type errors in `providers/inline/agents/meta_reference/` and removes exclusions from pyproject.toml. ## Changes - Fix type annotations for Safety API message parameters (OpenAIMessageParam) - Add Action enum usage in access control checks - Correct method signatures to match API supertype (parameter ordering) - Handle optional return types with proper None checks - Remove 3 meta_reference exclusions from mypy config **Files fixed:** 25 errors across 3 files (safety.py, persistence.py, agents.py)
This commit is contained in:
parent
e6b27db30a
commit
c9d4b6c54f
7 changed files with 69 additions and 31 deletions
|
|
@ -284,12 +284,6 @@ exclude = [
|
||||||
"^src/llama_stack/models/llama/llama3/interface\\.py$",
|
"^src/llama_stack/models/llama/llama3/interface\\.py$",
|
||||||
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
|
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||||
"^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
|
"^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||||
"^src/llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
|
||||||
"^src/llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
|
||||||
"^src/llama_stack/providers/inline/agents/meta_reference/config\\.py$",
|
|
||||||
"^src/llama_stack/providers/inline/agents/meta_reference/persistence\\.py$",
|
|
||||||
"^src/llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
|
|
||||||
"^src/llama_stack/providers/inline/agents/meta_reference/__init__\\.py$",
|
|
||||||
"^src/llama_stack/providers/inline/datasetio/localfs/",
|
"^src/llama_stack/providers/inline/datasetio/localfs/",
|
||||||
"^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
"^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||||
"^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
"^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
@ -363,7 +363,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
if self.input_shields:
|
if self.input_shields:
|
||||||
async for res in self.run_multiple_shields_wrapper(
|
async for res in self.run_multiple_shields_wrapper(
|
||||||
turn_id, input_messages, self.input_shields, "user-input"
|
turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
|
||||||
):
|
):
|
||||||
if isinstance(res, bool):
|
if isinstance(res, bool):
|
||||||
return
|
return
|
||||||
|
|
@ -392,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
if self.output_shields:
|
if self.output_shields:
|
||||||
async for res in self.run_multiple_shields_wrapper(
|
async for res in self.run_multiple_shields_wrapper(
|
||||||
turn_id, messages, self.output_shields, "assistant-output"
|
turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
|
||||||
):
|
):
|
||||||
if isinstance(res, bool):
|
if isinstance(res, bool):
|
||||||
return
|
return
|
||||||
|
|
@ -404,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async def run_multiple_shields_wrapper(
|
async def run_multiple_shields_wrapper(
|
||||||
self,
|
self,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
messages: list[Message],
|
messages: list[OpenAIMessageParam],
|
||||||
shields: list[str],
|
shields: list[str],
|
||||||
touchpoint: str,
|
touchpoint: str,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
||||||
Document,
|
Document,
|
||||||
ListOpenAIResponseInputItem,
|
ListOpenAIResponseInputItem,
|
||||||
ListOpenAIResponseObject,
|
ListOpenAIResponseObject,
|
||||||
|
OpenAIDeleteResponseObject,
|
||||||
OpenAIResponseInput,
|
OpenAIResponseInput,
|
||||||
OpenAIResponseInputTool,
|
OpenAIResponseInputTool,
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
|
|
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
persistence_store=(
|
persistence_store=(
|
||||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||||
),
|
),
|
||||||
created_at=agent_info.created_at,
|
created_at=agent_info.created_at.isoformat(),
|
||||||
policy=self.policy,
|
policy=self.policy,
|
||||||
telemetry_enabled=self.telemetry_enabled,
|
telemetry_enabled=self.telemetry_enabled,
|
||||||
)
|
)
|
||||||
|
|
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
messages: list[UserMessage | ToolResponseMessage],
|
messages: list[UserMessage | ToolResponseMessage],
|
||||||
toolgroups: list[AgentToolGroup] | None = None,
|
|
||||||
documents: list[Document] | None = None,
|
|
||||||
stream: bool | None = False,
|
stream: bool | None = False,
|
||||||
|
documents: list[Document] | None = None,
|
||||||
|
toolgroups: list[AgentToolGroup] | None = None,
|
||||||
tool_config: ToolConfig | None = None,
|
tool_config: ToolConfig | None = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
|
|
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||||
agent = await self._get_agent_impl(agent_id)
|
agent = await self._get_agent_impl(agent_id)
|
||||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||||
|
if turn is None:
|
||||||
|
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
|
||||||
return turn
|
return turn
|
||||||
|
|
||||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||||
|
|
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
|
|
||||||
async def get_agents_session(
|
async def get_agents_session(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
agent_id: str,
|
||||||
turn_ids: list[str] | None = None,
|
turn_ids: list[str] | None = None,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
agent = await self._get_agent_impl(agent_id)
|
agent = await self._get_agent_impl(agent_id)
|
||||||
|
|
||||||
session_info = await agent.storage.get_session_info(session_id)
|
session_info = await agent.storage.get_session_info(session_id)
|
||||||
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {session_id} not found")
|
||||||
turns = await agent.storage.get_session_turns(session_id)
|
turns = await agent.storage.get_session_turns(session_id)
|
||||||
if turn_ids:
|
if turn_ids:
|
||||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||||
|
|
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
started_at=session_info.started_at,
|
started_at=session_info.started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
|
||||||
agent = await self._get_agent_impl(agent_id)
|
agent = await self._get_agent_impl(agent_id)
|
||||||
|
|
||||||
# Delete turns first, then the session
|
# Delete turns first, then the session
|
||||||
|
|
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
agent_config=chat_agent.agent_config,
|
agent_config=chat_agent.agent_config,
|
||||||
created_at=chat_agent.created_at,
|
created_at=datetime.fromisoformat(chat_agent.created_at),
|
||||||
)
|
)
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
self,
|
self,
|
||||||
response_id: str,
|
response_id: str,
|
||||||
) -> OpenAIResponseObject:
|
) -> OpenAIResponseObject:
|
||||||
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
return await self.openai_responses_impl.get_openai_response(response_id)
|
return await self.openai_responses_impl.get_openai_response(response_id)
|
||||||
|
|
||||||
async def create_openai_response(
|
async def create_openai_response(
|
||||||
|
|
@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
max_infer_iters: int | None = 10,
|
max_infer_iters: int | None = 10,
|
||||||
guardrails: list[ResponseGuardrail] | None = None,
|
guardrails: list[ResponseGuardrail] | None = None,
|
||||||
) -> OpenAIResponseObject:
|
) -> OpenAIResponseObject:
|
||||||
return await self.openai_responses_impl.create_openai_response(
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
|
result = await self.openai_responses_impl.create_openai_response(
|
||||||
input,
|
input,
|
||||||
model,
|
model,
|
||||||
prompt,
|
prompt,
|
||||||
|
|
@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
max_infer_iters,
|
max_infer_iters,
|
||||||
guardrails,
|
guardrails,
|
||||||
)
|
)
|
||||||
|
return result # type: ignore[no-any-return]
|
||||||
|
|
||||||
async def list_openai_responses(
|
async def list_openai_responses(
|
||||||
self,
|
self,
|
||||||
|
|
@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
order: Order | None = Order.desc,
|
order: Order | None = Order.desc,
|
||||||
) -> ListOpenAIResponseObject:
|
) -> ListOpenAIResponseObject:
|
||||||
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
||||||
|
|
||||||
async def list_openai_response_input_items(
|
async def list_openai_response_input_items(
|
||||||
|
|
@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
limit: int | None = 20,
|
limit: int | None = 20,
|
||||||
order: Order | None = Order.desc,
|
order: Order | None = Order.desc,
|
||||||
) -> ListOpenAIResponseInputItem:
|
) -> ListOpenAIResponseInputItem:
|
||||||
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
return await self.openai_responses_impl.list_openai_response_input_items(
|
return await self.openai_responses_impl.list_openai_response_input_items(
|
||||||
response_id, after, before, include, limit, order
|
response_id, after, before, include, limit, order
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_openai_response(self, response_id: str) -> None:
|
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||||
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
return await self.openai_responses_impl.delete_openai_response(response_id)
|
return await self.openai_responses_impl.delete_openai_response(response_id)
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,14 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
from llama_stack.core.access_control.datatypes import AccessRule
|
from llama_stack.core.access_control.conditions import User as ProtocolUser
|
||||||
|
from llama_stack.core.access_control.datatypes import AccessRule, Action
|
||||||
from llama_stack.core.datatypes import User
|
from llama_stack.core.datatypes import User
|
||||||
from llama_stack.core.request_headers import get_authenticated_user
|
from llama_stack.core.request_headers import get_authenticated_user
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionResource:
|
||||||
|
"""Concrete implementation of ProtectedResource for session access control."""
|
||||||
|
|
||||||
|
type: str
|
||||||
|
identifier: str
|
||||||
|
owner: ProtocolUser # Use the protocol type for structural compatibility
|
||||||
|
|
||||||
|
|
||||||
class AgentPersistence:
|
class AgentPersistence:
|
||||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
|
|
@ -53,8 +64,15 @@ class AgentPersistence:
|
||||||
turns=[],
|
turns=[],
|
||||||
identifier=name, # should this be qualified in any way?
|
identifier=name, # should this be qualified in any way?
|
||||||
)
|
)
|
||||||
if not is_action_allowed(self.policy, "create", session_info, user):
|
# Only perform access control if we have an authenticated user
|
||||||
raise AccessDeniedError("create", session_info, user)
|
if user is not None and session_info.identifier is not None:
|
||||||
|
resource = SessionResource(
|
||||||
|
type=session_info.type,
|
||||||
|
identifier=session_info.identifier,
|
||||||
|
owner=user,
|
||||||
|
)
|
||||||
|
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
|
||||||
|
raise AccessDeniedError(Action.CREATE, resource, user)
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
key=f"session:{self.agent_id}:{session_id}",
|
||||||
|
|
@ -62,7 +80,7 @@ class AgentPersistence:
|
||||||
)
|
)
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
|
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||||
value = await self.kvstore.get(
|
value = await self.kvstore.get(
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
key=f"session:{self.agent_id}:{session_id}",
|
||||||
)
|
)
|
||||||
|
|
@ -83,7 +101,22 @@ class AgentPersistence:
|
||||||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
# Get current user - if None, skip access control (e.g., in tests)
|
||||||
|
user = get_authenticated_user()
|
||||||
|
if user is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Access control requires identifier and owner to be set
|
||||||
|
if session_info.identifier is None or session_info.owner is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# At this point, both identifier and owner are guaranteed to be non-None
|
||||||
|
resource = SessionResource(
|
||||||
|
type=session_info.type,
|
||||||
|
identifier=session_info.identifier,
|
||||||
|
owner=session_info.owner,
|
||||||
|
)
|
||||||
|
return is_action_allowed(self.policy, Action.READ, resource, user)
|
||||||
|
|
||||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||||
|
|
|
||||||
|
|
@ -254,7 +254,7 @@ class OpenAIResponsesImpl:
|
||||||
tools: list[OpenAIResponseInputTool] | None = None,
|
tools: list[OpenAIResponseInputTool] | None = None,
|
||||||
include: list[str] | None = None,
|
include: list[str] | None = None,
|
||||||
max_infer_iters: int | None = 10,
|
max_infer_iters: int | None = 10,
|
||||||
guardrails: list[ResponseGuardrailSpec] | None = None,
|
guardrails: list[str | ResponseGuardrailSpec] | None = None,
|
||||||
):
|
):
|
||||||
stream = bool(stream)
|
stream = bool(stream)
|
||||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import OpenAIMessageParam
|
||||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||||
from llama_stack.core.telemetry import tracing
|
from llama_stack.core.telemetry import tracing
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
|
||||||
self.input_shields = input_shields
|
self.input_shields = input_shields
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
|
||||||
async def run_shield_with_span(identifier: str):
|
async def run_shield_with_span(identifier: str):
|
||||||
async with tracing.span(f"run_shield_{identifier}"):
|
async with tracing.span(f"run_shield_{identifier}"):
|
||||||
return await self.safety_api.run_shield(
|
return await self.safety_api.run_shield(
|
||||||
|
|
|
||||||
|
|
@ -192,18 +192,18 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
|
||||||
assert session_response.session_id is not None
|
assert session_response.session_id is not None
|
||||||
|
|
||||||
# Verify the session was stored
|
# Verify the session was stored
|
||||||
session = await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
session = await agents_impl.get_agents_session(session_response.session_id, agent_id)
|
||||||
assert session.session_name == "test_session"
|
assert session.session_name == "test_session"
|
||||||
assert session.session_id == session_response.session_id
|
assert session.session_id == session_response.session_id
|
||||||
assert session.started_at is not None
|
assert session.started_at is not None
|
||||||
assert session.turns == []
|
assert session.turns == []
|
||||||
|
|
||||||
# Delete the session
|
# Delete the session
|
||||||
await agents_impl.delete_agents_session(agent_id, session_response.session_id)
|
await agents_impl.delete_agents_session(session_response.session_id, agent_id)
|
||||||
|
|
||||||
# Verify the session was deleted
|
# Verify the session was deleted
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
await agents_impl.get_agents_session(session_response.session_id, agent_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||||
|
|
@ -226,11 +226,11 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
|
||||||
assert session2.session_id in session_ids
|
assert session2.session_id in session_ids
|
||||||
|
|
||||||
# Delete one session
|
# Delete one session
|
||||||
await agents_impl.delete_agents_session(agent_id, session1.session_id)
|
await agents_impl.delete_agents_session(session1.session_id, agent_id)
|
||||||
|
|
||||||
# Verify the session was deleted
|
# Verify the session was deleted
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await agents_impl.get_agents_session(agent_id, session1.session_id)
|
await agents_impl.get_agents_session(session1.session_id, agent_id)
|
||||||
|
|
||||||
# List sessions again
|
# List sessions again
|
||||||
sessions = await agents_impl.list_agent_sessions(agent_id)
|
sessions = await agents_impl.list_agent_sessions(agent_id)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue