chore(api): add mypy coverage to meta_reference

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-08 20:55:17 +02:00
parent 1d8c00635c
commit f617a28164
3 changed files with 55 additions and 29 deletions

View file

@ -91,7 +91,7 @@ async def _convert_response_content_to_chat_content(
if isinstance(content, str): if isinstance(content, str):
return content return content
converted_parts = [] converted_parts: list[OpenAIChatCompletionContentPartParam] = []
for content_part in content: for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText): if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
@ -136,7 +136,7 @@ async def _convert_response_input_to_chat_messages(
), ),
) )
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else: elif isinstance(input_item, OpenAIResponseMessage):
content = await _convert_response_content_to_chat_content(input_item.content) content = await _convert_response_content_to_chat_content(input_item.content)
message_type = await _get_message_type_by_role(input_item.role) message_type = await _get_message_type_by_role(input_item.role)
if message_type is None: if message_type is None:
@ -144,6 +144,11 @@ async def _convert_response_input_to_chat_messages(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
) )
messages.append(message_type(content=content)) messages.append(message_type(content=content))
else:
# Handle other tool call types that don't have content/role attributes
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support input item type '{type(input_item)}' in this context"
)
else: else:
messages.append(OpenAIUserMessageParam(content=input)) messages.append(OpenAIUserMessageParam(content=input))
return messages return messages
@ -175,13 +180,17 @@ async def _convert_response_text_to_chat_response_format(text: OpenAIResponseTex
""" """
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
""" """
if not text.format or text.format["type"] == "text": if not text.format or text.format.get("type") == "text":
return OpenAIResponseFormatText(type="text") return OpenAIResponseFormatText(type="text")
if text.format["type"] == "json_object": if text.format.get("type") == "json_object":
return OpenAIResponseFormatJSONObject() return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema": if text.format.get("type") == "json_schema":
name = text.format.get("name")
schema = text.format.get("schema")
if name is None or schema is None:
raise ValueError(f"json_schema format requires both name and schema fields")
return OpenAIResponseFormatJSONSchema( return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) json_schema=OpenAIJSONSchema(name=name, schema=schema)
) )
raise ValueError(f"Unsupported text format: {text.format}") raise ValueError(f"Unsupported text format: {text.format}")
@ -472,11 +481,12 @@ class OpenAIResponsesImpl:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None) response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call: if response_tool_call:
# Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions # Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions
if tool_call.function.arguments: if tool_call.function and tool_call.function.arguments:
# Guard against an initial None argument before we concatenate # Guard against an initial None argument before we concatenate
response_tool_call.function.arguments = ( if response_tool_call.function:
response_tool_call.function.arguments or "" response_tool_call.function.arguments = (
) + tool_call.function.arguments response_tool_call.function.arguments or ""
) + tool_call.function.arguments
else: else:
tool_call_dict: dict[str, Any] = tool_call.model_dump() tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None) tool_call_dict.pop("type", None)
@ -530,15 +540,16 @@ class OpenAIResponsesImpl:
next_turn_messages.append(tool_response_message) next_turn_messages.append(tool_response_message)
for tool_call in function_tool_calls: for tool_call in function_tool_calls:
output_messages.append( if tool_call.function:
OpenAIResponseOutputMessageFunctionToolCall( output_messages.append(
arguments=tool_call.function.arguments or "", OpenAIResponseOutputMessageFunctionToolCall(
call_id=tool_call.id, arguments=tool_call.function.arguments or "",
name=tool_call.function.name or "", call_id=tool_call.id,
id=f"fc_{uuid.uuid4()}", name=tool_call.function.name or "",
status="completed", id=f"fc_{uuid.uuid4()}",
status="completed",
)
) )
)
if not function_tool_calls and not non_function_tool_calls: if not function_tool_calls and not non_function_tool_calls:
break break
@ -602,7 +613,7 @@ class OpenAIResponsesImpl:
required=param.required, required=param.required,
default=param.default, default=param.default,
) )
for param in tool.parameters for param in (tool.parameters or [])
}, },
) )
return convert_tooldef_to_openai_tool(tool_def) return convert_tooldef_to_openai_tool(tool_def)

View file

@ -11,7 +11,8 @@ from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import AccessRule from llama_stack.distribution.access_control.conditions import User as ProtectedResourceUser
from llama_stack.distribution.access_control.datatypes import AccessRule, Action
from llama_stack.distribution.datatypes import User from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
@ -23,8 +24,8 @@ class AgentSessionInfo(Session):
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: str | None = None vector_db_id: str | None = None
started_at: datetime started_at: datetime
owner: User | None = None owner: ProtectedResourceUser
identifier: str | None = None identifier: str
type: str = "session" type: str = "session"
@ -43,17 +44,21 @@ class AgentPersistence:
# Get current user's auth attributes for new sessions # Get current user's auth attributes for new sessions
user = get_authenticated_user() user = get_authenticated_user()
# If no user is authenticated, create a default user for backward compatibility
if user is None:
user = User(principal="anonymous", attributes=None)
session_info = AgentSessionInfo( session_info = AgentSessionInfo(
session_id=session_id, session_id=session_id,
session_name=name, session_name=name,
started_at=datetime.now(UTC), started_at=datetime.now(UTC),
owner=user, owner=user, # User from datatypes is compatible with ProtectedResourceUser protocol
turns=[], turns=[],
identifier=name, # should this be qualified in any way? identifier=name, # should this be qualified in any way?
) )
if not is_action_allowed(self.policy, "create", session_info, user): if not is_action_allowed(self.policy, Action.CREATE, session_info, user):
raise AccessDeniedError("create", session_info, user) raise AccessDeniedError(Action.CREATE, session_info, user)
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
@ -68,7 +73,17 @@ class AgentPersistence:
if not value: if not value:
return None return None
session_info = AgentSessionInfo(**json.loads(value)) session_data = json.loads(value)
# Handle backward compatibility for sessions without owner field
if "owner" not in session_data or session_data["owner"] is None:
session_data["owner"] = User(principal="anonymous", attributes=None)
# Handle backward compatibility for sessions without identifier field
if "identifier" not in session_data or session_data["identifier"] is None:
session_data["identifier"] = session_data.get("session_name", "unknown")
session_info = AgentSessionInfo(**session_data)
# Check access to session # Check access to session
if not self._check_session_access(session_info): if not self._check_session_access(session_info):
@ -79,10 +94,10 @@ class AgentPersistence:
def _check_session_access(self, session_info: AgentSessionInfo) -> bool: def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session.""" """Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control # Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): if not hasattr(session_info, "access_attributes"):
return True return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) return is_action_allowed(self.policy, Action.READ, session_info, get_authenticated_user())
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods.""" """Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -242,7 +242,7 @@ exclude = [
"^llama_stack/models/llama/llama3/interface\\.py$", "^llama_stack/models/llama/llama3/interface\\.py$",
"^llama_stack/models/llama/llama3/tokenizer\\.py$", "^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$", "^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/", "^llama_stack/models/llama/llama3_3/prompts\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/", "^llama_stack/providers/inline/datasetio/localfs/",