From f617a28164b59b64c94c6efdb38952278c132d9b Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Tue, 8 Jul 2025 20:55:17 +0200 Subject: [PATCH] chore(api): add mypy coverage to meta_reference Signed-off-by: Mustafa Elbehery --- .../agents/meta_reference/openai_responses.py | 49 ++++++++++++------- .../agents/meta_reference/persistence.py | 33 +++++++++---- pyproject.toml | 2 +- 3 files changed, 55 insertions(+), 29 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 7eb2b3897..34669cb53 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -91,7 +91,7 @@ async def _convert_response_content_to_chat_content( if isinstance(content, str): return content - converted_parts = [] + converted_parts: list[OpenAIChatCompletionContentPartParam] = [] for content_part in content: if isinstance(content_part, OpenAIResponseInputMessageContentText): converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) @@ -136,7 +136,7 @@ async def _convert_response_input_to_chat_messages( ), ) messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) - else: + elif isinstance(input_item, OpenAIResponseMessage): content = await _convert_response_content_to_chat_content(input_item.content) message_type = await _get_message_type_by_role(input_item.role) if message_type is None: @@ -144,6 +144,11 @@ async def _convert_response_input_to_chat_messages( f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" ) messages.append(message_type(content=content)) + else: + # Handle other tool call types that don't have content/role attributes + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support input item type '{type(input_item)}' in this context" + ) else: messages.append(OpenAIUserMessageParam(content=input)) return messages @@ -175,13 +180,17 @@ async def _convert_response_text_to_chat_response_format(text: OpenAIResponseTex """ Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. """ - if not text.format or text.format["type"] == "text": + if not text.format or text.format.get("type") == "text": return OpenAIResponseFormatText(type="text") - if text.format["type"] == "json_object": + if text.format.get("type") == "json_object": return OpenAIResponseFormatJSONObject() - if text.format["type"] == "json_schema": + if text.format.get("type") == "json_schema": + name = text.format.get("name") + schema = text.format.get("schema") + if name is None or schema is None: + raise ValueError(f"json_schema format requires both name and schema fields") return OpenAIResponseFormatJSONSchema( - json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) + json_schema=OpenAIJSONSchema(name=name, schema=schema) ) raise ValueError(f"Unsupported text format: {text.format}") @@ -472,11 +481,12 @@ class OpenAIResponsesImpl: response_tool_call = chat_response_tool_calls.get(tool_call.index, None) if response_tool_call: # Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions - if tool_call.function.arguments: + if tool_call.function and tool_call.function.arguments: # Guard against an initial None argument before we concatenate - response_tool_call.function.arguments = ( - response_tool_call.function.arguments or "" - ) + tool_call.function.arguments + if response_tool_call.function: + response_tool_call.function.arguments = ( + response_tool_call.function.arguments or "" + ) + tool_call.function.arguments else: tool_call_dict: dict[str, Any] = tool_call.model_dump() tool_call_dict.pop("type", None) @@ -530,15 +540,16 @@ class OpenAIResponsesImpl: next_turn_messages.append(tool_response_message) for tool_call in function_tool_calls: - output_messages.append( - OpenAIResponseOutputMessageFunctionToolCall( - arguments=tool_call.function.arguments or "", - call_id=tool_call.id, - name=tool_call.function.name or "", - id=f"fc_{uuid.uuid4()}", - status="completed", + if tool_call.function: + output_messages.append( + OpenAIResponseOutputMessageFunctionToolCall( + arguments=tool_call.function.arguments or "", + call_id=tool_call.id, + name=tool_call.function.name or "", + id=f"fc_{uuid.uuid4()}", + status="completed", + ) ) - ) if not function_tool_calls and not non_function_tool_calls: break @@ -602,7 +613,7 @@ class OpenAIResponsesImpl: required=param.required, default=param.default, ) - for param in tool.parameters + for param in (tool.parameters or []) }, ) return convert_tooldef_to_openai_tool(tool_def) diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index cda535937..beb0aa748 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -11,7 +11,8 @@ from datetime import UTC, datetime from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed -from llama_stack.distribution.access_control.datatypes import AccessRule +from llama_stack.distribution.access_control.conditions import User as ProtectedResourceUser +from llama_stack.distribution.access_control.datatypes import AccessRule, Action from llama_stack.distribution.datatypes import User from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.providers.utils.kvstore import KVStore @@ -23,8 +24,8 @@ class AgentSessionInfo(Session): # TODO: is this used anywhere? vector_db_id: str | None = None started_at: datetime - owner: User | None = None - identifier: str | None = None + owner: ProtectedResourceUser + identifier: str type: str = "session" @@ -43,17 +44,21 @@ class AgentPersistence: # Get current user's auth attributes for new sessions user = get_authenticated_user() + + # If no user is authenticated, create a default user for backward compatibility + if user is None: + user = User(principal="anonymous", attributes=None) session_info = AgentSessionInfo( session_id=session_id, session_name=name, started_at=datetime.now(UTC), - owner=user, + owner=user, # User from datatypes is compatible with ProtectedResourceUser protocol turns=[], identifier=name, # should this be qualified in any way? ) - if not is_action_allowed(self.policy, "create", session_info, user): - raise AccessDeniedError("create", session_info, user) + if not is_action_allowed(self.policy, Action.CREATE, session_info, user): + raise AccessDeniedError(Action.CREATE, session_info, user) await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", @@ -68,7 +73,17 @@ class AgentPersistence: if not value: return None - session_info = AgentSessionInfo(**json.loads(value)) + session_data = json.loads(value) + + # Handle backward compatibility for sessions without owner field + if "owner" not in session_data or session_data["owner"] is None: + session_data["owner"] = User(principal="anonymous", attributes=None) + + # Handle backward compatibility for sessions without identifier field + if "identifier" not in session_data or session_data["identifier"] is None: + session_data["identifier"] = session_data.get("session_name", "unknown") + + session_info = AgentSessionInfo(**session_data) # Check access to session if not self._check_session_access(session_info): @@ -79,10 +94,10 @@ class AgentPersistence: def _check_session_access(self, session_info: AgentSessionInfo) -> bool: """Check if current user has access to the session.""" # Handle backward compatibility for old sessions without access control - if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): + if not hasattr(session_info, "access_attributes"): return True - return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) + return is_action_allowed(self.policy, Action.READ, session_info, get_authenticated_user()) async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: """Get session info if the user has access to it. For internal use by sub-session methods.""" diff --git a/pyproject.toml b/pyproject.toml index 04a6a685e..a8667b5e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,7 +242,7 @@ exclude = [ "^llama_stack/models/llama/llama3/interface\\.py$", "^llama_stack/models/llama/llama3/tokenizer\\.py$", "^llama_stack/models/llama/llama3/tool_utils\\.py$", - "^llama_stack/providers/inline/agents/meta_reference/", + "^llama_stack/models/llama/llama3_3/prompts\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", "^llama_stack/providers/inline/datasetio/localfs/",