mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
chore(api): add mypy coverage to meta_reference
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
1d8c00635c
commit
f617a28164
3 changed files with 55 additions and 29 deletions
|
@ -91,7 +91,7 @@ async def _convert_response_content_to_chat_content(
|
|||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
|
@ -136,7 +136,7 @@ async def _convert_response_input_to_chat_messages(
|
|||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
elif isinstance(input_item, OpenAIResponseMessage):
|
||||
content = await _convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await _get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
|
@ -144,6 +144,11 @@ async def _convert_response_input_to_chat_messages(
|
|||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
# Handle other tool call types that don't have content/role attributes
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support input item type '{type(input_item)}' in this context"
|
||||
)
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
@ -175,13 +180,17 @@ async def _convert_response_text_to_chat_response_format(text: OpenAIResponseTex
|
|||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
if not text.format or text.format.get("type") == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
if text.format.get("type") == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
if text.format.get("type") == "json_schema":
|
||||
name = text.format.get("name")
|
||||
schema = text.format.get("schema")
|
||||
if name is None or schema is None:
|
||||
raise ValueError(f"json_schema format requires both name and schema fields")
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
json_schema=OpenAIJSONSchema(name=name, schema=schema)
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
@ -472,8 +481,9 @@ class OpenAIResponsesImpl:
|
|||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
# Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions
|
||||
if tool_call.function.arguments:
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# Guard against an initial None argument before we concatenate
|
||||
if response_tool_call.function:
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
@ -530,6 +540,7 @@ class OpenAIResponsesImpl:
|
|||
next_turn_messages.append(tool_response_message)
|
||||
|
||||
for tool_call in function_tool_calls:
|
||||
if tool_call.function:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
|
@ -602,7 +613,7 @@ class OpenAIResponsesImpl:
|
|||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
for param in (tool.parameters or [])
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
|
|
@ -11,7 +11,8 @@ from datetime import UTC, datetime
|
|||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import AccessRule
|
||||
from llama_stack.distribution.access_control.conditions import User as ProtectedResourceUser
|
||||
from llama_stack.distribution.access_control.datatypes import AccessRule, Action
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
@ -23,8 +24,8 @@ class AgentSessionInfo(Session):
|
|||
# TODO: is this used anywhere?
|
||||
vector_db_id: str | None = None
|
||||
started_at: datetime
|
||||
owner: User | None = None
|
||||
identifier: str | None = None
|
||||
owner: ProtectedResourceUser
|
||||
identifier: str
|
||||
type: str = "session"
|
||||
|
||||
|
||||
|
@ -44,16 +45,20 @@ class AgentPersistence:
|
|||
# Get current user's auth attributes for new sessions
|
||||
user = get_authenticated_user()
|
||||
|
||||
# If no user is authenticated, create a default user for backward compatibility
|
||||
if user is None:
|
||||
user = User(principal="anonymous", attributes=None)
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(UTC),
|
||||
owner=user,
|
||||
owner=user, # User from datatypes is compatible with ProtectedResourceUser protocol
|
||||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
if not is_action_allowed(self.policy, "create", session_info, user):
|
||||
raise AccessDeniedError("create", session_info, user)
|
||||
if not is_action_allowed(self.policy, Action.CREATE, session_info, user):
|
||||
raise AccessDeniedError(Action.CREATE, session_info, user)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
@ -68,7 +73,17 @@ class AgentPersistence:
|
|||
if not value:
|
||||
return None
|
||||
|
||||
session_info = AgentSessionInfo(**json.loads(value))
|
||||
session_data = json.loads(value)
|
||||
|
||||
# Handle backward compatibility for sessions without owner field
|
||||
if "owner" not in session_data or session_data["owner"] is None:
|
||||
session_data["owner"] = User(principal="anonymous", attributes=None)
|
||||
|
||||
# Handle backward compatibility for sessions without identifier field
|
||||
if "identifier" not in session_data or session_data["identifier"] is None:
|
||||
session_data["identifier"] = session_data.get("session_name", "unknown")
|
||||
|
||||
session_info = AgentSessionInfo(**session_data)
|
||||
|
||||
# Check access to session
|
||||
if not self._check_session_access(session_info):
|
||||
|
@ -79,10 +94,10 @@ class AgentPersistence:
|
|||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||
"""Check if current user has access to the session."""
|
||||
# Handle backward compatibility for old sessions without access control
|
||||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
if not hasattr(session_info, "access_attributes"):
|
||||
return True
|
||||
|
||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||
return is_action_allowed(self.policy, Action.READ, session_info, get_authenticated_user())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
|
|
|
@ -242,7 +242,7 @@ exclude = [
|
|||
"^llama_stack/models/llama/llama3/interface\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||
"^llama_stack/models/llama/llama3_3/prompts\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||
"^llama_stack/providers/inline/datasetio/localfs/",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue