Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-02 20:40:36 +00:00)
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle `UnionType` and `collections.abc.AsyncIterator`. (Both are preferred for recent Python releases.)

Note to reviewers: almost all changes here were automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
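For context on the reflection adjustment: PEP 604 unions written `X | Y` are instances of `types.UnionType` at runtime, which is a distinct object from `typing.Union`, and `collections.abc.AsyncIterator` is the preferred, subscriptable replacement for the deprecated `typing.AsyncIterator`. The sketch below illustrates the kind of handling involved; `unwrap_optional` and `strip_async_iterator` are hypothetical helpers for illustration, not the actual code in `llama_stack/strong_typing/schema.py`.

```python
# Minimal sketch (assumes Python 3.10+): reflection code must treat
# typing.Union and the PEP 604 types.UnionType as equivalent origins,
# and recognize collections.abc.AsyncIterator as a generic origin.
import types
import typing
from collections.abc import AsyncIterator


def unwrap_optional(hint: object) -> object:
    """Return the non-None member of an Optional-style hint, else the hint."""
    origin = typing.get_origin(hint)
    # typing.Optional[X] reports typing.Union; X | None reports types.UnionType.
    if origin is typing.Union or origin is types.UnionType:
        args = [a for a in typing.get_args(hint) if a is not type(None)]
        if len(args) == 1:
            return args[0]
    return hint


def strip_async_iterator(hint: object) -> object:
    """Return T for collections.abc.AsyncIterator[T], else the hint unchanged."""
    if typing.get_origin(hint) is AsyncIterator:
        return typing.get_args(hint)[0]
    return hint


assert unwrap_optional(typing.Optional[int]) is int  # old spelling
assert unwrap_optional(int | None) is int            # PEP 604 spelling
assert strip_async_iterator(AsyncIterator[str]) is str
```

Both spellings funnel through `typing.get_origin`/`typing.get_args`, so a single check against the two origin objects covers code written before and after the pyupgrade rewrite.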
This commit is contained in: parent ffe3d0b2cd, commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
```diff
@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict
+from typing import Any
 
 from llama_stack.distribution.datatypes import Api
 
 from .config import MetaReferenceAgentsImplConfig
 
 
-async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
+async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
     from .agents import MetaReferenceAgentsImpl
 
     impl = MetaReferenceAgentsImpl(
```
```diff
@@ -10,8 +10,8 @@ import re
 import secrets
 import string
 import uuid
+from collections.abc import AsyncGenerator
 from datetime import datetime, timezone
-from typing import AsyncGenerator, List, Optional, Union
 
 import httpx
@@ -112,7 +112,7 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )
 
-    def turn_to_messages(self, turn: Turn) -> List[Message]:
+    def turn_to_messages(self, turn: Turn) -> list[Message]:
         messages = []
 
         # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
@@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
-    async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
+    async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
         messages = []
         if self.agent_config.instructions != "":
             messages.append(SystemMessage(content=self.agent_config.instructions))
@@ -201,8 +201,8 @@ class ChatAgent(ShieldRunnerMixin):
 
     async def _run_turn(
         self,
-        request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
-        turn_id: Optional[str] = None,
+        request: AgentTurnCreateRequest | AgentTurnResumeRequest,
+        turn_id: str | None = None,
     ) -> AsyncGenerator:
         assert request.stream is True, "Non-streaming not supported"
@@ -321,10 +321,10 @@ class ChatAgent(ShieldRunnerMixin):
         self,
         session_id: str,
         turn_id: str,
-        input_messages: List[Message],
+        input_messages: list[Message],
         sampling_params: SamplingParams,
         stream: bool = False,
-        documents: Optional[List[Document]] = None,
+        documents: list[Document] | None = None,
     ) -> AsyncGenerator:
         # Doing async generators makes downstream code much simpler and everything amenable to
         # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -374,8 +374,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
-        messages: List[Message],
-        shields: List[str],
+        messages: list[Message],
+        shields: list[str],
         touchpoint: str,
     ) -> AsyncGenerator:
         async with tracing.span("run_shields") as span:
@@ -443,10 +443,10 @@ class ChatAgent(ShieldRunnerMixin):
         self,
         session_id: str,
         turn_id: str,
-        input_messages: List[Message],
+        input_messages: list[Message],
         sampling_params: SamplingParams,
         stream: bool = False,
-        documents: Optional[List[Document]] = None,
+        documents: list[Document] | None = None,
     ) -> AsyncGenerator:
         # if document is passed in a turn, we parse the raw text of the document
         # and sent it as a user message
@@ -760,7 +760,7 @@ class ChatAgent(ShieldRunnerMixin):
 
     async def _initialize_tools(
         self,
-        toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
+        toolgroups_for_turn: list[AgentToolGroup] | None = None,
    ) -> None:
         toolgroup_to_args = {}
         for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
@@ -847,7 +847,7 @@ class ChatAgent(ShieldRunnerMixin):
             tool_name_to_args,
         )
 
-    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
+    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
         """Parse a toolgroup name into its components.
 
         Args:
@@ -921,7 +921,7 @@ async def get_raw_document_text(document: Document) -> str:
 
 def _interpret_content_as_attachment(
     content: str,
-) -> Optional[Attachment]:
+) -> Attachment | None:
     match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
     if match:
         snippet = match.group(1)
```
```diff
@@ -8,7 +8,7 @@ import json
 import logging
 import shutil
 import uuid
-from typing import AsyncGenerator, List, Optional, Union
+from collections.abc import AsyncGenerator
 
 from llama_stack.apis.agents import (
     Agent,
@@ -142,16 +142,11 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         agent_id: str,
         session_id: str,
-        messages: List[
-            Union[
-                UserMessage,
-                ToolResponseMessage,
-            ]
-        ],
-        toolgroups: Optional[List[AgentToolGroup]] = None,
-        documents: Optional[List[Document]] = None,
-        stream: Optional[bool] = False,
-        tool_config: Optional[ToolConfig] = None,
+        messages: list[UserMessage | ToolResponseMessage],
+        toolgroups: list[AgentToolGroup] | None = None,
+        documents: list[Document] | None = None,
+        stream: bool | None = False,
+        tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
@@ -180,8 +175,8 @@ class MetaReferenceAgentsImpl(Agents):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponse],
-        stream: Optional[bool] = False,
+        tool_responses: list[ToolResponse],
+        stream: bool | None = False,
     ) -> AsyncGenerator:
         request = AgentTurnResumeRequest(
             agent_id=agent_id,
@@ -219,7 +214,7 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         agent_id: str,
         session_id: str,
-        turn_ids: Optional[List[str]] = None,
+        turn_ids: list[str] | None = None,
     ) -> Session:
         agent = await self._get_agent_impl(agent_id)
         session_info = await agent.storage.get_session_info(session_id)
@@ -265,13 +260,13 @@ class MetaReferenceAgentsImpl(Agents):
 
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        temperature: Optional[float] = None,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
     ) -> OpenAIResponseObject:
         return await self.openai_responses_impl.create_openai_response(
             input, model, previous_response_id, store, stream, temperature, tools
```
```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict
+from typing import Any
 
 from pydantic import BaseModel
@@ -16,7 +16,7 @@ class MetaReferenceAgentsImplConfig(BaseModel):
     persistence_store: KVStoreConfig
 
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
             "persistence_store": SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=__distro_dir__,
```
```diff
@@ -6,7 +6,8 @@
 
 import json
 import uuid
-from typing import AsyncIterator, List, Optional, Union, cast
+from collections.abc import AsyncIterator
+from typing import cast
 
 from openai.types.chat import ChatCompletionToolParam
@@ -49,15 +50,15 @@ logger = get_logger(name=__name__, category="openai_responses")
 OPENAI_RESPONSES_PREFIX = "openai_responses:"
 
 
-async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
-    messages: List[OpenAIMessageParam] = []
+async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
+    messages: list[OpenAIMessageParam] = []
     for output_message in previous_response.output:
         if isinstance(output_message, OpenAIResponseOutputMessage):
             messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
     return messages
 
 
-async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
+async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
     output_messages = []
     for choice in choices:
         output_content = ""
@@ -101,22 +102,22 @@ class OpenAIResponsesImpl:
 
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        temperature: Optional[float] = None,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
     ):
         stream = False if stream is None else stream
 
-        messages: List[OpenAIMessageParam] = []
+        messages: list[OpenAIMessageParam] = []
         if previous_response_id:
             previous_response = await self.get_openai_response(previous_response_id)
             messages.extend(await _previous_response_to_messages(previous_response))
         # TODO: refactor this user_content parsing out into a separate method
-        user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
+        user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
         if isinstance(input, list):
             user_content = []
             for user_input in input:
@@ -179,7 +180,7 @@ class OpenAIResponsesImpl:
         # dump and reload to map to our pydantic types
         chat_response = OpenAIChatCompletion(**chat_response.model_dump())
 
-        output_messages: List[OpenAIResponseOutput] = []
+        output_messages: list[OpenAIResponseOutput] = []
         if chat_response.choices[0].message.tool_calls:
             output_messages.extend(
                 await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
@@ -215,9 +216,9 @@ class OpenAIResponsesImpl:
         return response
 
     async def _convert_response_tools_to_chat_tools(
-        self, tools: List[OpenAIResponseInputTool]
-    ) -> List[ChatCompletionToolParam]:
-        chat_tools: List[ChatCompletionToolParam] = []
+        self, tools: list[OpenAIResponseInputTool]
+    ) -> list[ChatCompletionToolParam]:
+        chat_tools: list[ChatCompletionToolParam] = []
     for input_tool in tools:
             # TODO: Handle other tool types
             if input_tool.type == "web_search":
@@ -247,10 +248,10 @@ class OpenAIResponsesImpl:
         model_id: str,
         stream: bool,
         chat_response: OpenAIChatCompletion,
-        messages: List[OpenAIMessageParam],
+        messages: list[OpenAIMessageParam],
         temperature: float,
-    ) -> List[OpenAIResponseOutput]:
-        output_messages: List[OpenAIResponseOutput] = []
+    ) -> list[OpenAIResponseOutput]:
+        output_messages: list[OpenAIResponseOutput] = []
         choice = chat_response.choices[0]
 
         # If the choice is not an assistant message, we don't need to execute any tools
@@ -314,7 +315,7 @@ class OpenAIResponsesImpl:
     async def _execute_tool_call(
         self,
         function: OpenAIChatCompletionToolCallFunction,
-    ) -> Optional[ToolInvocationResult]:
+    ) -> ToolInvocationResult | None:
         if not function.name:
             return None
         function_args = json.loads(function.arguments) if function.arguments else {}
```
```diff
@@ -8,7 +8,6 @@ import json
 import logging
 import uuid
 from datetime import datetime, timezone
-from typing import List, Optional
 
 from pydantic import BaseModel
@@ -25,9 +24,9 @@ class AgentSessionInfo(BaseModel):
     session_id: str
     session_name: str
     # TODO: is this used anywhere?
-    vector_db_id: Optional[str] = None
+    vector_db_id: str | None = None
     started_at: datetime
-    access_attributes: Optional[AccessAttributes] = None
+    access_attributes: AccessAttributes | None = None
 
 
 class AgentPersistence:
@@ -55,7 +54,7 @@ class AgentPersistence:
         )
         return session_id
 
-    async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
+    async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
         value = await self.kvstore.get(
             key=f"session:{self.agent_id}:{session_id}",
         )
@@ -78,7 +77,7 @@ class AgentPersistence:
 
         return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
 
-    async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
+    async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
         """Get session info if the user has access to it. For internal use by sub-session methods."""
         session_info = await self.get_session_info(session_id)
         if not session_info:
@@ -106,7 +105,7 @@ class AgentPersistence:
             value=turn.model_dump_json(),
         )
 
-    async def get_session_turns(self, session_id: str) -> List[Turn]:
+    async def get_session_turns(self, session_id: str) -> list[Turn]:
         if not await self.get_session_if_accessible(session_id):
             raise ValueError(f"Session {session_id} not found or access denied")
 
@@ -125,7 +124,7 @@ class AgentPersistence:
         turns.sort(key=lambda x: (x.completed_at or datetime.min))
         return turns
 
-    async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
+    async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
         if not await self.get_session_if_accessible(session_id):
             raise ValueError(f"Session {session_id} not found or access denied")
 
@@ -145,7 +144,7 @@ class AgentPersistence:
             value=step.model_dump_json(),
         )
 
-    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
+    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
         if not await self.get_session_if_accessible(session_id):
             return None
 
@@ -163,7 +162,7 @@ class AgentPersistence:
             value=str(num_infer_iters),
         )
 
-    async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
+    async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
         if not await self.get_session_if_accessible(session_id):
             return None
```
```diff
@@ -6,7 +6,6 @@
 
 import asyncio
 import logging
-from typing import List
 
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
@@ -25,14 +24,14 @@ class ShieldRunnerMixin:
     def __init__(
         self,
         safety_api: Safety,
-        input_shields: List[str] = None,
-        output_shields: List[str] = None,
+        input_shields: list[str] = None,
+        output_shields: list[str] = None,
     ):
         self.safety_api = safety_api
         self.input_shields = input_shields
         self.output_shields = output_shields
 
-    async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
+    async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
         async def run_shield_with_span(identifier: str):
             async with tracing.span(f"run_shield_{identifier}"):
                 return await self.safety_api.run_shield(
```