diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py
index 6d5e48a46..cc594d8d7 100644
--- a/docs/openapi_generator/pyopenapi/generator.py
+++ b/docs/openapi_generator/pyopenapi/generator.py
@@ -6,6 +6,7 @@
 
 import hashlib
 import ipaddress
+import types
 import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union
@@ -189,7 +190,7 @@ class ContentBuilder:
         else:
             return "application/json"
 
-        if typing.get_origin(payload_type) is typing.Union:
+        if typing.get_origin(payload_type) in (typing.Union, types.UnionType):
             media_types = []
             item_types = []
             for x in typing.get_args(payload_type):
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 929e7e4bc..30b37e98f 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -4,20 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncIterator
 from datetime import datetime
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    AsyncIterator,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Union,
-    runtime_checkable,
-)
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, ConfigDict, Field
@@ -79,8 +69,8 @@ class StepCommon(BaseModel):
 
     turn_id: str
     step_id: str
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None
 
 
 class StepType(Enum):
@@ -120,8 +110,8 @@ class ToolExecutionStep(StepCommon):
     """
 
     step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
-    tool_calls: List[ToolCall]
-    tool_responses: List[ToolResponse]
+    tool_calls: list[ToolCall]
+    tool_responses: list[ToolResponse]
 
 
 @json_schema_type
@@ -132,7 +122,7 @@ class ShieldCallStep(StepCommon):
     """
 
     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    violation: Optional[SafetyViolation]
+    violation: SafetyViolation | None
 
 
 @json_schema_type
@@ -150,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):
 
 
 Step = Annotated[
-    Union[
-        InferenceStep,
-        ToolExecutionStep,
-        ShieldCallStep,
-        MemoryRetrievalStep,
-    ],
+    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
     Field(discriminator="step_type"),
 ]
@@ -166,18 +151,13 @@ class Turn(BaseModel):
 
     turn_id: str
     session_id: str
-    input_messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
-    steps: List[Step]
+    input_messages: list[UserMessage | ToolResponseMessage]
+    steps: list[Step]
     output_message: CompletionMessage
-    output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
+    output_attachments: list[Attachment] | None = Field(default_factory=list)
 
     started_at: datetime
-    completed_at: Optional[datetime] = None
+    completed_at: datetime | None = None
@@ -186,34 +166,31 @@ class Session(BaseModel):
 
     session_id: str
     session_name: str
-    turns: List[Turn]
+    turns: list[Turn]
     started_at: datetime
 
 
 class AgentToolGroupWithArgs(BaseModel):
     name: str
-    args: Dict[str, Any]
+    args: dict[str, Any]
 
 
-AgentToolGroup = Union[
-    str,
-    AgentToolGroupWithArgs,
-]
+AgentToolGroup = str | AgentToolGroupWithArgs
 register_schema(AgentToolGroup, name="AgentTool")
 
 
 class AgentConfigCommon(BaseModel):
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
 
-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
-    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
-    tool_config: Optional[ToolConfig] = Field(default=None)
+    input_shields: list[str] | None = Field(default_factory=list)
+    output_shields: list[str] | None = Field(default_factory=list)
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
+    client_tools: list[ToolDef] | None = Field(default_factory=list)
+    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
+    tool_config: ToolConfig | None = Field(default=None)
 
-    max_infer_iters: Optional[int] = 10
+    max_infer_iters: int | None = 10
 
     def model_post_init(self, __context):
         if self.tool_config:
@@ -243,9 +220,9 @@ class AgentConfig(AgentConfigCommon):
 
     model: str
     instructions: str
-    name: Optional[str] = None
-    enable_session_persistence: Optional[bool] = False
-    response_format: Optional[ResponseFormat] = None
+    name: str | None = None
+    enable_session_persistence: bool | None = False
+    response_format: ResponseFormat | None = None
@@ -257,16 +234,16 @@ class Agent(BaseModel):
 
 
 @json_schema_type
 class ListAgentsResponse(BaseModel):
-    data: List[Agent]
+    data: list[Agent]
 
 
 @json_schema_type
 class ListAgentSessionsResponse(BaseModel):
-    data: List[Session]
+    data: list[Session]
 
 
 class AgentConfigOverridablePerTurn(AgentConfigCommon):
-    instructions: Optional[str] = None
+    instructions: str | None = None
 
 
 class AgentTurnResponseEventType(Enum):
@@ -284,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
     event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
     step_type: StepType
     step_id: str
-    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    metadata: dict[str, Any] | None = Field(default_factory=dict)
 
 
 @json_schema_type
@@ -327,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
 
 
 AgentTurnResponseEventPayload = Annotated[
-    Union[
-        AgentTurnResponseStepStartPayload,
-        AgentTurnResponseStepProgressPayload,
-        AgentTurnResponseStepCompletePayload,
-        AgentTurnResponseTurnStartPayload,
-        AgentTurnResponseTurnCompletePayload,
-        AgentTurnResponseTurnAwaitingInputPayload,
-    ],
+    AgentTurnResponseStepStartPayload
+    | AgentTurnResponseStepProgressPayload
+    | AgentTurnResponseStepCompletePayload
+    | AgentTurnResponseTurnStartPayload
+    | AgentTurnResponseTurnCompletePayload
+    | AgentTurnResponseTurnAwaitingInputPayload,
     Field(discriminator="event_type"),
 ]
 register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@@ -363,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     # TODO: figure out how we can simplify this and make why
     # ToolResponseMessage needs to be here (it is function call
     # execution from outside the system)
-    messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
+    messages: list[UserMessage | ToolResponseMessage]
 
-    documents: Optional[List[Document]] = None
-    toolgroups: Optional[List[AgentToolGroup]] = None
+    documents: list[Document] | None = None
+    toolgroups: list[AgentToolGroup] | None = None
 
-    stream: Optional[bool] = False
-    tool_config: Optional[ToolConfig] = None
+    stream: bool | None = False
+    tool_config: ToolConfig | None = None
@@ -382,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: List[ToolResponse]
-    stream: Optional[bool] = False
+    tool_responses: list[ToolResponse]
+    stream: bool | None = False
 
 
 @json_schema_type
@@ -429,17 +399,12 @@ class Agents(Protocol):
         self,
         agent_id: str,
         session_id: str,
-        messages: List[
-            Union[
-                UserMessage,
-                ToolResponseMessage,
-            ]
-        ],
-        stream: Optional[bool] = False,
-        documents: Optional[List[Document]] = None,
-        toolgroups: Optional[List[AgentToolGroup]] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        messages: list[UserMessage | ToolResponseMessage],
+        stream: bool | None = False,
+        documents: list[Document] | None = None,
+        toolgroups: list[AgentToolGroup] | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Create a new turn for an agent.
 
         :param agent_id: The ID of the agent to create the turn for.
@@ -463,9 +428,9 @@ class Agents(Protocol):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponse],
-        stream: Optional[bool] = False,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        tool_responses: list[ToolResponse],
+        stream: bool | None = False,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Resume an agent turn with executed tool call responses.
 
         When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint
         can be used to submit the outputs from the tool calls once they are ready.
@@ -538,7 +503,7 @@ class Agents(Protocol):
         self,
         session_id: str,
         agent_id: str,
-        turn_ids: Optional[List[str]] = None,
+        turn_ids: list[str] | None = None,
     ) -> Session:
         """Retrieve an agent session by its ID.
@@ -623,14 +588,14 @@ class Agents(Protocol):
     @webmethod(route="/openai/v1/responses", method="POST")
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        temperature: Optional[float] = None,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
-    ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
+    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
         """Create a new OpenAI response.
 
         :param input: Input message(s) to create the response.
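Reviewer note: the generator.py hunk at the top is the one functional change in this diff — everything else is mechanical syntax modernization. Once annotations use PEP 604 `X | Y` syntax, `typing.get_origin` reports `types.UnionType` rather than `typing.Union`, so the content builder must accept both. A minimal sketch of the distinction (Python 3.10+, standard library only):

```python
import types
import typing

legacy = typing.Optional[int]  # typing.Union[int, None]
modern = int | None            # a types.UnionType instance (PEP 604)

print(typing.get_origin(legacy))  # typing.Union
print(typing.get_origin(modern))  # <class 'types.UnionType'>

for tp in (legacy, modern):
    # Same argument tuple either way, which is why checking both origins suffices.
    assert typing.get_args(tp) == (int, type(None))
```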
diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py
index 72f16e224..8e11b2123 100644
--- a/llama_stack/apis/agents/openai_responses.py
+++ b/llama_stack/apis/agents/openai_responses.py
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Literal, Optional, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.schema_utils import json_schema_type, register_schema
 
@@ -25,7 +24,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
 
 
 OpenAIResponseOutputMessageContent = Annotated[
-    Union[OpenAIResponseOutputMessageContentOutputText,],
+    OpenAIResponseOutputMessageContentOutputText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
@@ -34,7 +33,7 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
 @json_schema_type
 class OpenAIResponseOutputMessage(BaseModel):
     id: str
-    content: List[OpenAIResponseOutputMessageContent]
+    content: list[OpenAIResponseOutputMessageContent]
     role: Literal["assistant"] = "assistant"
     status: str
     type: Literal["message"] = "message"
@@ -48,10 +47,7 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
 
 
 OpenAIResponseOutput = Annotated[
-    Union[
-        OpenAIResponseOutputMessage,
-        OpenAIResponseOutputMessageWebSearchToolCall,
-    ],
+    OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@@ -60,18 +56,18 @@ register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
 @json_schema_type
 class OpenAIResponseObject(BaseModel):
     created_at: int
-    error: Optional[OpenAIResponseError] = None
+    error: OpenAIResponseError | None = None
     id: str
     model: str
     object: Literal["response"] = "response"
-    output: List[OpenAIResponseOutput]
+    output: list[OpenAIResponseOutput]
     parallel_tool_calls: bool = False
-    previous_response_id: Optional[str] = None
+    previous_response_id: str | None = None
     status: str
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    truncation: Optional[str] = None
-    user: Optional[str] = None
+    temperature: float | None = None
+    top_p: float | None = None
+    truncation: str | None = None
+    user: str | None = None
 
 
 @json_schema_type
@@ -87,10 +83,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
 
 
 OpenAIResponseObjectStream = Annotated[
-    Union[
-        OpenAIResponseObjectStreamResponseCreated,
-        OpenAIResponseObjectStreamResponseCompleted,
-    ],
+    OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@@ -107,12 +100,12 @@ class OpenAIResponseInputMessageContentImage(BaseModel):
     detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
     type: Literal["input_image"] = "input_image"
     # TODO: handle file_id
-    image_url: Optional[str] = None
+    image_url: str | None = None
 
 
 # TODO: handle file content types
 OpenAIResponseInputMessageContent = Annotated[
-    Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
+    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@@ -120,21 +113,21 @@ register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMess
 
 @json_schema_type
 class OpenAIResponseInputMessage(BaseModel):
-    content: Union[str, List[OpenAIResponseInputMessageContent]]
+    content: str | list[OpenAIResponseInputMessageContent]
     role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
-    type: Optional[Literal["message"]] = "message"
+    type: Literal["message"] | None = "message"
 
 
 @json_schema_type
 class OpenAIResponseInputToolWebSearch(BaseModel):
     type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
     # TODO: add user_location
 
 
 OpenAIResponseInputTool = Annotated[
-    Union[OpenAIResponseInputToolWebSearch,],
+    OpenAIResponseInputToolWebSearch,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py
index 7a324128d..79bc73e4c 100644
--- a/llama_stack/apis/batch_inference/batch_inference.py
+++ b/llama_stack/apis/batch_inference/batch_inference.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import (
@@ -34,22 +34,22 @@ class BatchInference(Protocol):
     async def completion(
         self,
         model: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...
 
     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
         self,
         model: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
         # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...
diff --git a/llama_stack/apis/benchmarks/benchmarks.py b/llama_stack/apis/benchmarks/benchmarks.py
index 809af8868..1bba42d20 100644
--- a/llama_stack/apis/benchmarks/benchmarks.py
+++ b/llama_stack/apis/benchmarks/benchmarks.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, Field
 
@@ -13,8 +13,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 
 class CommonBenchmarkFields(BaseModel):
     dataset_id: str
-    scoring_functions: List[str]
-    metadata: Dict[str, Any] = Field(
+    scoring_functions: list[str]
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Metadata for this evaluation task",
     )
@@ -35,12 +35,12 @@ class Benchmark(CommonBenchmarkFields, Resource):
 
 class BenchmarkInput(CommonBenchmarkFields, BaseModel):
     benchmark_id: str
-    provider_id: Optional[str] = None
-    provider_benchmark_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_benchmark_id: str | None = None
 
 
 class ListBenchmarksResponse(BaseModel):
-    data: List[Benchmark]
+    data: list[Benchmark]
 
 
 @runtime_checkable
@@ -59,8 +59,8 @@ class Benchmarks(Protocol):
         self,
         benchmark_id: str,
         dataset_id: str,
-        scoring_functions: List[str],
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        scoring_functions: list[str],
+        provider_benchmark_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
     ) -> None: ...
diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py
index 9d4e21308..b9ef033dd 100644
--- a/llama_stack/apis/common/content_types.py
+++ b/llama_stack/apis/common/content_types.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Annotated, List, Literal, Optional, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field, model_validator
 
@@ -26,9 +26,9 @@ class _URLOrData(BaseModel):
     :param data: base64 encoded image data as string
     """
 
-    url: Optional[URL] = None
+    url: URL | None = None
     # data is a base64 encoded string, hint with contentEncoding=base64
-    data: Optional[str] = Field(contentEncoding="base64", default=None)
+    data: str | None = Field(contentEncoding="base64", default=None)
 
     @model_validator(mode="before")
     @classmethod
@@ -64,13 +64,13 @@ class TextContentItem(BaseModel):
 
 # other modalities can be added here
 InterleavedContentItem = Annotated[
-    Union[ImageContentItem, TextContentItem],
+    ImageContentItem | TextContentItem,
     Field(discriminator="type"),
 ]
 register_schema(InterleavedContentItem, name="InterleavedContentItem")
 
 # accept a single "str" as a special case since it is common
-InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
+InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
 register_schema(InterleavedContent, name="InterleavedContent")
 
 
@@ -100,13 +100,13 @@ class ToolCallDelta(BaseModel):
     # you either send an in-progress tool call so the client can stream a long
     # code generation or you send the final parsed tool call at the end of the
     # stream
-    tool_call: Union[str, ToolCall]
+    tool_call: str | ToolCall
     parse_status: ToolCallParseStatus
 
 
 # streaming completions send a stream of ContentDeltas
 ContentDelta = Annotated[
-    Union[TextDelta, ImageDelta, ToolCallDelta],
+    TextDelta | ImageDelta | ToolCallDelta,
     Field(discriminator="type"),
 ]
 register_schema(ContentDelta, name="ContentDelta")
diff --git a/llama_stack/apis/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py
index 83eea28a2..4d01d7ad1 100644
--- a/llama_stack/apis/common/deployment_types.py
+++ b/llama_stack/apis/common/deployment_types.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -25,6 +25,6 @@ class RestAPIMethod(Enum):
 class RestAPIExecutionConfig(BaseModel):
     url: URL
     method: RestAPIMethod
-    params: Optional[Dict[str, Any]] = None
-    headers: Optional[Dict[str, Any]] = None
-    body: Optional[Dict[str, Any]] = None
+    params: dict[str, Any] | None = None
+    headers: dict[str, Any] | None = None
+    body: dict[str, Any] | None = None
diff --git a/llama_stack/apis/common/responses.py b/llama_stack/apis/common/responses.py
index f9e9a4c31..b3bb5cb6b 100644
--- a/llama_stack/apis/common/responses.py
+++ b/llama_stack/apis/common/responses.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -19,5 +19,5 @@ class PaginatedResponse(BaseModel):
     :param has_more: Whether there are more items available after this set
     """
 
-    data: List[Dict[str, Any]]
+    data: list[dict[str, Any]]
     has_more: bool
diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index d6c6c6919..46cd101af 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from datetime import datetime
-from typing import Optional
 
 from pydantic import BaseModel
 
@@ -27,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric] = None
+    training_metrics: PostTrainingMetric | None = None
diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py
index 5d9f000be..db4aab4c5 100644
--- a/llama_stack/apis/common/type_system.py
+++ b/llama_stack/apis/common/type_system.py
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Literal, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.schema_utils import json_schema_type, register_schema
 
@@ -73,18 +72,16 @@ class DialogType(BaseModel):
 
 
 ParamType = Annotated[
-    Union[
-        StringType,
-        NumberType,
-        BooleanType,
-        ArrayType,
-        ObjectType,
-        JsonType,
-        UnionType,
-        ChatCompletionInputType,
-        CompletionInputType,
-        AgentTurnInputType,
-    ],
+    StringType
+    | NumberType
+    | BooleanType
+    | ArrayType
+    | ObjectType
+    | JsonType
+    | UnionType
+    | ChatCompletionInputType
+    | CompletionInputType
+    | AgentTurnInputType,
     Field(discriminator="type"),
 ]
 register_schema(ParamType, name="ParamType")
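Reviewer note: `InterleavedContent` above is now a bare `str | InterleavedContentItem | list[InterleavedContentItem]` alias. A hedged sketch of what that union accepts, using a simplified stand-in for `TextContentItem` (validating through pydantic's `TypeAdapter` is an assumption for illustration; the real alias is consumed via `register_schema`):

```python
from typing import Literal

from pydantic import BaseModel, TypeAdapter

class TextContentItem(BaseModel):  # simplified stand-in
    type: Literal["text"] = "text"
    text: str

InterleavedContent = str | TextContentItem | list[TextContentItem]
adapter = TypeAdapter(InterleavedContent)

assert adapter.validate_python("plain string") == "plain string"
assert adapter.validate_python({"type": "text", "text": "hi"}).text == "hi"
assert adapter.validate_python([{"type": "text", "text": "hi"}])[0].text == "hi"
```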
diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py
index 6331882fb..6d160a043 100644
--- a/llama_stack/apis/datasetio/datasetio.py
+++ b/llama_stack/apis/datasetio/datasetio.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable
 
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasets import Dataset
@@ -24,8 +24,8 @@ class DatasetIO(Protocol):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
-        limit: Optional[int] = None,
+        start_index: int | None = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         """Get a paginated list of rows from a dataset.
 
@@ -44,4 +44,4 @@ class DatasetIO(Protocol):
         ...
 
     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index 32ccde144..796217557 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol
 
 from pydantic import BaseModel, Field
 
@@ -81,11 +81,11 @@ class RowsDataSource(BaseModel):
     """
 
     type: Literal["rows"] = "rows"
-    rows: List[Dict[str, Any]]
+    rows: list[dict[str, Any]]
 
 
 DataSource = Annotated[
-    Union[URIDataSource, RowsDataSource],
+    URIDataSource | RowsDataSource,
     Field(discriminator="type"),
 ]
 register_schema(DataSource, name="DataSource")
@@ -98,7 +98,7 @@ class CommonDatasetFields(BaseModel):
 
     purpose: DatasetPurpose
     source: DataSource
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this dataset",
     )
@@ -122,7 +122,7 @@ class DatasetInput(CommonDatasetFields, BaseModel):
 
 
 class ListDatasetsResponse(BaseModel):
-    data: List[Dataset]
+    data: list[Dataset]
 
 
 class Datasets(Protocol):
@@ -131,8 +131,8 @@ class Datasets(Protocol):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
-        dataset_id: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
     ) -> Dataset:
         """
         Register a new dataset.
diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py
index 25f3ab1ab..63a764725 100644
--- a/llama_stack/apis/datatypes.py
+++ b/llama_stack/apis/datatypes.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Optional
 
 from pydantic import BaseModel
 
@@ -54,4 +53,4 @@ class Error(BaseModel):
     status: int
     title: str
     detail: str
-    instance: Optional[str] = None
+    instance: str | None = None
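Reviewer note: the `iterrows` signature above keeps the `start_index`/`limit` pagination contract; only the spelling of the optional types changed. A sketch of the paging loop a caller would write — `client` is a hypothetical object implementing `DatasetIO`:

```python
async def fetch_all_rows(client, dataset_id: str, page_size: int = 100) -> list[dict]:
    rows: list[dict] = []
    start_index = 0
    while True:
        page = await client.iterrows(dataset_id, start_index=start_index, limit=page_size)
        rows.extend(page.data)         # PaginatedResponse.data
        if not page.has_more:          # PaginatedResponse.has_more
            return rows
        start_index += len(page.data)
```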
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index 0e5959c37..23ca89a94 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job
@@ -29,7 +28,7 @@ class ModelCandidate(BaseModel):
     type: Literal["model"] = "model"
     model: str
     sampling_params: SamplingParams
-    system_message: Optional[SystemMessage] = None
+    system_message: SystemMessage | None = None
 
 
 @json_schema_type
@@ -43,7 +42,7 @@ class AgentCandidate(BaseModel):
     config: AgentConfig
 
 
-EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
 register_schema(EvalCandidate, name="EvalCandidate")
 
 
@@ -57,11 +56,11 @@ class BenchmarkConfig(BaseModel):
     """
 
     eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
+    scoring_params: dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
-    num_examples: Optional[int] = Field(
+    num_examples: int | None = Field(
         description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
         default=None,
     )
@@ -76,9 +75,9 @@ class EvaluateResponse(BaseModel):
     :param scores: The scores from the evaluation.
     """
 
-    generations: List[Dict[str, Any]]
+    generations: list[dict[str, Any]]
     # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
+    scores: dict[str, ScoringResult]
 
 
 class Eval(Protocol):
@@ -101,8 +100,8 @@ class Eval(Protocol):
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         """Evaluate a list of rows on a benchmark.
diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py
index ef8b65829..4a9b49978 100644
--- a/llama_stack/apis/files/files.py
+++ b/llama_stack/apis/files/files.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 from pydantic import BaseModel
 
@@ -42,7 +42,7 @@ class ListBucketResponse(BaseModel):
     :param data: List of FileResponse entries
     """
 
-    data: List[BucketResponse]
+    data: list[BucketResponse]
 
 
 @json_schema_type
@@ -74,7 +74,7 @@ class ListFileResponse(BaseModel):
     :param data: List of FileResponse entries
     """
 
-    data: List[FileResponse]
+    data: list[FileResponse]
 
 
 @runtime_checkable
@@ -102,7 +102,7 @@ class Files(Protocol):
     async def upload_content_to_session(
         self,
         upload_id: str,
-    ) -> Optional[FileResponse]:
+    ) -> FileResponse | None:
         """
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.
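Reviewer note: the next file moves `AsyncIterator` from `typing` to `collections.abc`, its canonical home; the `typing` aliases for ABCs have been deprecated since Python 3.9. Both spellings resolve to the same runtime origin, so the change is annotation-compatible:

```python
import collections.abc
import typing

assert typing.get_origin(typing.AsyncIterator[int]) is collections.abc.AsyncIterator
assert typing.get_origin(collections.abc.AsyncIterator[int]) is collections.abc.AsyncIterator
```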
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 309171f20..dbcd1d019 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -4,21 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncIterator
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    AsyncIterator,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )
 
 from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated, TypedDict
+from typing_extensions import TypedDict
 
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
@@ -47,8 +44,8 @@ class GreedySamplingStrategy(BaseModel):
 @json_schema_type
 class TopPSamplingStrategy(BaseModel):
     type: Literal["top_p"] = "top_p"
-    temperature: Optional[float] = Field(..., gt=0.0)
-    top_p: Optional[float] = 0.95
+    temperature: float | None = Field(..., gt=0.0)
+    top_p: float | None = 0.95
 
 
 @json_schema_type
@@ -58,7 +55,7 @@ class TopKSamplingStrategy(BaseModel):
 
 
 SamplingStrategy = Annotated[
-    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
     Field(discriminator="type"),
 ]
 register_schema(SamplingStrategy, name="SamplingStrategy")
@@ -79,9 +76,9 @@ class SamplingParams(BaseModel):
 
     strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
 
-    max_tokens: Optional[int] = 0
-    repetition_penalty: Optional[float] = 1.0
-    stop: Optional[List[str]] = None
+    max_tokens: int | None = 0
+    repetition_penalty: float | None = 1.0
+    stop: list[str] | None = None
 
 
 class LogProbConfig(BaseModel):
@@ -90,7 +87,7 @@ class LogProbConfig(BaseModel):
     :param top_k: How many tokens (for each position) to return log probabilities for.
     """
 
-    top_k: Optional[int] = 0
+    top_k: int | None = 0
 
 
 class QuantizationType(Enum):
@@ -125,11 +122,11 @@ class Int4QuantizationConfig(BaseModel):
     """
 
     type: Literal["int4_mixed"] = "int4_mixed"
-    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
+    scheme: str | None = "int4_weight_int8_dynamic_activation"
 
 
 QuantizationConfig = Annotated[
-    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
+    Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
     Field(discriminator="type"),
 ]
 
@@ -145,7 +142,7 @@ class UserMessage(BaseModel):
 
     role: Literal["user"] = "user"
     content: InterleavedContent
-    context: Optional[InterleavedContent] = None
+    context: InterleavedContent | None = None
 
 
 @json_schema_type
@@ -190,16 +187,11 @@ class CompletionMessage(BaseModel):
     role: Literal["assistant"] = "assistant"
     content: InterleavedContent
     stop_reason: StopReason
-    tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
+    tool_calls: list[ToolCall] | None = Field(default_factory=list)
 
 
 Message = Annotated[
-    Union[
-        UserMessage,
-        SystemMessage,
-        ToolResponseMessage,
-        CompletionMessage,
-    ],
+    UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
     Field(discriminator="role"),
 ]
 register_schema(Message, name="Message")
@@ -208,9 +200,9 @@ register_schema(Message, name="Message")
 @json_schema_type
 class ToolResponse(BaseModel):
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: BuiltinTool | str
     content: InterleavedContent
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: dict[str, Any] | None = None
 
     @field_validator("tool_name", mode="before")
     @classmethod
@@ -243,7 +235,7 @@ class TokenLogProbs(BaseModel):
     :param logprobs_by_token: Dictionary mapping tokens to their log probabilities
     """
 
-    logprobs_by_token: Dict[str, float]
+    logprobs_by_token: dict[str, float]
 
 
 class ChatCompletionResponseEventType(Enum):
@@ -271,8 +263,8 @@ class ChatCompletionResponseEvent(BaseModel):
 
     event_type: ChatCompletionResponseEventType
     delta: ContentDelta
-    logprobs: Optional[List[TokenLogProbs]] = None
-    stop_reason: Optional[StopReason] = None
+    logprobs: list[TokenLogProbs] | None = None
+    stop_reason: StopReason | None = None
 
 
 class ResponseFormatType(Enum):
@@ -295,7 +287,7 @@ class JsonSchemaResponseFormat(BaseModel):
     """
 
     type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
-    json_schema: Dict[str, Any]
+    json_schema: dict[str, Any]
 
 
 @json_schema_type
@@ -307,11 +299,11 @@ class GrammarResponseFormat(BaseModel):
     """
 
     type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
-    bnf: Dict[str, Any]
+    bnf: dict[str, Any]
 
 
 ResponseFormat = Annotated[
-    Union[JsonSchemaResponseFormat, GrammarResponseFormat],
+    JsonSchemaResponseFormat | GrammarResponseFormat,
     Field(discriminator="type"),
 ]
 register_schema(ResponseFormat, name="ResponseFormat")
@@ -321,10 +313,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedContent
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None
 
 
 @json_schema_type
@@ -338,7 +330,7 @@ class CompletionResponse(MetricResponseMixin):
 
     content: str
     stop_reason: StopReason
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None
 
 
 @json_schema_type
@@ -351,8 +343,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
     """
 
     delta: str
-    stop_reason: Optional[StopReason] = None
-    logprobs: Optional[List[TokenLogProbs]] = None
+    stop_reason: StopReason | None = None
+    logprobs: list[TokenLogProbs] | None = None
 
 
 class SystemMessageBehavior(Enum):
@@ -383,9 +375,9 @@ class ToolConfig(BaseModel):
         '{{function_definitions}}' to indicate where the function definitions should be inserted.
     """
 
-    tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
+    tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
+    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
 
     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
@@ -399,15 +391,15 @@ class ToolConfig(BaseModel):
 @json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Message]
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    messages: list[Message]
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
 
-    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
+    tools: list[ToolDefinition] | None = Field(default_factory=list)
+    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
 
-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None
 
 
 @json_schema_type
@@ -429,7 +421,7 @@ class ChatCompletionResponse(MetricResponseMixin):
     """
 
     completion_message: CompletionMessage
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None
 
 
 @json_schema_type
@@ -439,7 +431,7 @@ class EmbeddingsResponse(BaseModel):
     """
 
     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats.
         The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
     """
 
-    embeddings: List[List[float]]
+    embeddings: list[list[float]]
 
 
 @json_schema_type
@@ -451,7 +443,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):
 
 
 @json_schema_type
 class OpenAIImageURL(BaseModel):
     url: str
-    detail: Optional[str] = None
+    detail: str | None = None
 
 
 @json_schema_type
@@ -461,16 +453,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
 
 
 OpenAIChatCompletionContentPartParam = Annotated[
-    Union[
-        OpenAIChatCompletionContentPartTextParam,
-        OpenAIChatCompletionContentPartImageParam,
-    ],
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
 
 
-OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
+OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
 
 
 @json_schema_type
@@ -484,7 +473,7 @@ class OpenAIUserMessageParam(BaseModel):
 
     role: Literal["user"] = "user"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None
 
 
 @json_schema_type
@@ -498,21 +487,21 @@ class OpenAISystemMessageParam(BaseModel):
 
     role: Literal["system"] = "system"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None
 
 
 @json_schema_type
 class OpenAIChatCompletionToolCallFunction(BaseModel):
-    name: Optional[str] = None
-    arguments: Optional[str] = None
+    name: str | None = None
+    arguments: str | None = None
 
 
 @json_schema_type
 class OpenAIChatCompletionToolCall(BaseModel):
-    index: Optional[int] = None
-    id: Optional[str] = None
+    index: int | None = None
+    id: str | None = None
     type: Literal["function"] = "function"
-    function: Optional[OpenAIChatCompletionToolCallFunction] = None
+    function: OpenAIChatCompletionToolCallFunction | None = None
 
 
 @json_schema_type
@@ -526,9 +515,9 @@ class OpenAIAssistantMessageParam(BaseModel):
     """
 
     role: Literal["assistant"] = "assistant"
-    content: Optional[OpenAIChatCompletionMessageContent] = None
-    name: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: OpenAIChatCompletionMessageContent | None = None
+    name: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None
 
 
 @json_schema_type
@@ -556,17 +545,15 @@ class OpenAIDeveloperMessageParam(BaseModel):
 
     role: Literal["developer"] = "developer"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None
 
 
 OpenAIMessageParam = Annotated[
-    Union[
-        OpenAIUserMessageParam,
-        OpenAISystemMessageParam,
-        OpenAIAssistantMessageParam,
-        OpenAIToolMessageParam,
-        OpenAIDeveloperMessageParam,
-    ],
+    OpenAIUserMessageParam
+    | OpenAISystemMessageParam
+    | OpenAIAssistantMessageParam
+    | OpenAIToolMessageParam
+    | OpenAIDeveloperMessageParam,
     Field(discriminator="role"),
 ]
 register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@@ -580,14 +567,14 @@ class OpenAIResponseFormatText(BaseModel):
 
 
 @json_schema_type
 class OpenAIJSONSchema(TypedDict, total=False):
     name: str
-    description: Optional[str] = None
-    strict: Optional[bool] = None
+    description: str | None = None
+    strict: bool | None = None
 
     # Pydantic BaseModel cannot be used with a schema param, since it already
    # has one. And, we don't want to alias here because then have to handle
     # that alias when converting to OpenAI params. So, to support schema,
     # we use a TypedDict.
-    schema: Optional[Dict[str, Any]] = None
+    schema: dict[str, Any] | None = None
 
 
 @json_schema_type
@@ -602,11 +589,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):
 
 
 OpenAIResponseFormatParam = Annotated[
-    Union[
-        OpenAIResponseFormatText,
-        OpenAIResponseFormatJSONSchema,
-        OpenAIResponseFormatJSONObject,
-    ],
+    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -622,7 +605,7 @@ class OpenAITopLogProb(BaseModel):
     """
 
     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float
 
 
@@ -637,9 +620,9 @@ class OpenAITokenLogProb(BaseModel):
     """
 
     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float
-    top_logprobs: List[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb]
 
 
 @json_schema_type
@@ -650,8 +633,8 @@ class OpenAIChoiceLogprobs(BaseModel):
     :param refusal: (Optional) The log probabilities for the tokens in the message
     """
 
-    content: Optional[List[OpenAITokenLogProb]] = None
-    refusal: Optional[List[OpenAITokenLogProb]] = None
+    content: list[OpenAITokenLogProb] | None = None
+    refusal: list[OpenAITokenLogProb] | None = None
 
 
 @json_schema_type
@@ -664,10 +647,10 @@ class OpenAIChoiceDelta(BaseModel):
     :param tool_calls: (Optional) The tool calls of the delta
     """
 
-    content: Optional[str] = None
-    refusal: Optional[str] = None
-    role: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: str | None = None
+    refusal: str | None = None
+    role: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None
 
 
 @json_schema_type
@@ -683,7 +666,7 @@ class OpenAIChunkChoice(BaseModel):
     delta: OpenAIChoiceDelta
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None
 
 
 @json_schema_type
@@ -699,7 +682,7 @@ class OpenAIChoice(BaseModel):
     message: OpenAIMessageParam
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None
 
 
 @json_schema_type
@@ -714,7 +697,7 @@ class OpenAIChatCompletion(BaseModel):
     """
 
     id: str
-    choices: List[OpenAIChoice]
+    choices: list[OpenAIChoice]
     object: Literal["chat.completion"] = "chat.completion"
     created: int
     model: str
@@ -732,7 +715,7 @@ class OpenAIChatCompletionChunk(BaseModel):
     """
 
     id: str
-    choices: List[OpenAIChunkChoice]
+    choices: list[OpenAIChunkChoice]
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int
     model: str
@@ -748,10 +731,10 @@ class OpenAICompletionLogprobs(BaseModel):
     :top_logprobs: (Optional) The top log probabilities for the tokens
     """
 
-    text_offset: Optional[List[int]] = None
-    token_logprobs: Optional[List[float]] = None
-    tokens: Optional[List[str]] = None
-    top_logprobs: Optional[List[Dict[str, float]]] = None
+    text_offset: list[int] | None = None
+    token_logprobs: list[float] | None = None
+    tokens: list[str] | None = None
+    top_logprobs: list[dict[str, float]] | None = None
 
 
 @json_schema_type
@@ -767,7 +750,7 @@ class OpenAICompletionChoice(BaseModel):
 
     finish_reason: str
     text: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None
 
 
 @json_schema_type
@@ -782,7 +765,7 @@ class OpenAICompletion(BaseModel):
     """
 
     id: str
-    choices: List[OpenAICompletionChoice]
+    choices: list[OpenAICompletionChoice]
     created: int
     model: str
     object: Literal["text_completion"] = "text_completion"
@@ -818,12 +801,12 @@ class EmbeddingTaskType(Enum):
 
 @json_schema_type
 class BatchCompletionResponse(BaseModel):
-    batch: List[CompletionResponse]
+    batch: list[CompletionResponse]
 
 
 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    batch: List[ChatCompletionResponse]
+    batch: list[ChatCompletionResponse]
 
 
 @runtime_checkable
@@ -843,11 +826,11 @@ class Inference(Protocol):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+    ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
         """Generate a completion for the given content using the specified model.
 
         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
@@ -865,10 +848,10 @@ class Inference(Protocol):
     async def batch_completion(
         self,
         model_id: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
         raise NotImplementedError("Batch completion is not implemented")
 
@@ -876,16 +859,16 @@ class Inference(Protocol):
     async def chat_completion(
         self,
         model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
         """Generate a chat completion for the given messages using the specified model.
 
         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
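Reviewer note: `chat_completion` now returns `ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]` instead of the `Union[...]` spelling; callers still branch the same way. A hedged sketch — `inference` is a hypothetical implementation of this protocol, and the chunk attributes follow the models in this file:

```python
from collections.abc import AsyncIterator

async def print_reply(inference, model_id: str, messages) -> None:
    result = await inference.chat_completion(model_id=model_id, messages=messages, stream=True)
    if isinstance(result, AsyncIterator):      # streaming arm of the union
        async for chunk in result:
            print(chunk.event.delta)
    else:                                      # plain ChatCompletionResponse
        print(result.completion_message.content)
```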
@@ -916,12 +899,12 @@ class Inference(Protocol):
     async def batch_chat_completion(
         self,
         model_id: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_config: Optional[ToolConfig] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_config: ToolConfig | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
         raise NotImplementedError("Batch chat completion is not implemented")
 
@@ -929,10 +912,10 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.
 
@@ -950,25 +933,25 @@ class Inference(Protocol):
         self,
         # Standard OpenAI completion parameters
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
         # vLLM-specific parameters
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.
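Reviewer note: `openai_completion` above exposes the vLLM-only `guided_choice`/`prompt_logprobs` extensions alongside the standard OpenAI parameters. An illustrative call — `inference` is a hypothetical implementation and the model id is made up:

```python
async def yes_or_no(inference, question: str) -> str:
    completion = await inference.openai_completion(
        model="meta-llama/Llama-3.1-8B-Instruct",
        prompt=f"{question} Answer yes or no: ",
        max_tokens=1,
        temperature=0.0,
        guided_choice=["yes", "no"],  # vLLM-specific: constrain the output
    )
    return completion.choices[0].text
```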
@@ -996,29 +979,29 @@ class Inference(Protocol):
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIMessageParam],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """Generate an OpenAI-compatible chat completion for the given messages using the specified model.
 
         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py
index 863f90e14..fb3167635 100644
--- a/llama_stack/apis/inspect/inspect.py
+++ b/llama_stack/apis/inspect/inspect.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 from pydantic import BaseModel
 
@@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 class RouteInfo(BaseModel):
     route: str
     method: str
-    provider_types: List[str]
+    provider_types: list[str]
 
 
 @json_schema_type
@@ -30,7 +30,7 @@ class VersionInfo(BaseModel):
 
 
 class ListRoutesResponse(BaseModel):
-    data: List[RouteInfo]
+    data: list[RouteInfo]
 
 
 @runtime_checkable
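Reviewer note: a side benefit of the PEP 604 style used throughout, e.g. in `stop: str | list[str] | None` above: the same union objects are valid `isinstance()` targets on Python 3.10+, so runtime guards can mirror the annotations exactly. A small sketch:

```python
def normalize_stop(stop: str | list[str] | None) -> list[str]:
    if stop is None:
        return []
    if isinstance(stop, str):
        return [stop]
    return stop

assert isinstance("end", str | bytes)  # union syntax accepted at runtime
```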
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 97398ce75..5d7b5aac6 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, ConfigDict, Field
 
@@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 
 class CommonModelFields(BaseModel):
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
     )
@@ -46,14 +46,14 @@ class Model(CommonModelFields, Resource):
 
 
 class ModelInput(CommonModelFields):
     model_id: str
-    provider_id: Optional[str] = None
-    provider_model_id: Optional[str] = None
-    model_type: Optional[ModelType] = ModelType.llm
+    provider_id: str | None = None
+    provider_model_id: str | None = None
+    model_type: ModelType | None = ModelType.llm
     model_config = ConfigDict(protected_namespaces=())
 
 
 class ListModelsResponse(BaseModel):
-    data: List[Model]
+    data: list[Model]
 
 
 @json_schema_type
@@ -73,7 +73,7 @@ class OpenAIModel(BaseModel):
 
 
 class OpenAIListModelsResponse(BaseModel):
-    data: List[OpenAIModel]
+    data: list[OpenAIModel]
 
 
 @runtime_checkable
@@ -95,10 +95,10 @@ class Models(Protocol):
     async def register_model(
         self,
         model_id: str,
-        provider_model_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
+        provider_model_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
+        model_type: ModelType | None = None,
     ) -> Model: ...
 
     @webmethod(route="/models/{model_id:path}", method="DELETE")
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index e5f1bcb65..016f79fce 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -6,10 +6,9 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.job_types import JobStatus
@@ -36,9 +35,9 @@ class DataConfig(BaseModel):
     batch_size: int
     shuffle: bool
     data_format: DatasetFormat
-    validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
-    train_on_input: Optional[bool] = False
+    validation_dataset_id: str | None = None
+    packed: bool | None = False
+    train_on_input: bool | None = False
 
 
 @json_schema_type
@@ -51,10 +50,10 @@ class OptimizerConfig(BaseModel):
 
 
 @json_schema_type
 class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
+    enable_activation_checkpointing: bool | None = False
+    enable_activation_offloading: bool | None = False
+    memory_efficient_fsdp_wrap: bool | None = False
+    fsdp_cpu_offload: bool | None = False
 
 
 @json_schema_type
@@ -62,23 +61,23 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int = 1
     gradient_accumulation_steps: int = 1
-    max_validation_steps: Optional[int] = 1
-    data_config: Optional[DataConfig] = None
-    optimizer_config: Optional[OptimizerConfig] = None
-    efficiency_config: Optional[EfficiencyConfig] = None
-    dtype: Optional[str] = "bf16"
+    max_validation_steps: int | None = 1
+    data_config: DataConfig | None = None
+    optimizer_config: OptimizerConfig | None = None
+    efficiency_config: EfficiencyConfig | None = None
+    dtype: str | None = "bf16"
 
 
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
+    lora_attn_modules: list[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool
     rank: int
     alpha: int
-    use_dora: Optional[bool] = False
-    quantize_base: Optional[bool] = False
+    use_dora: bool | None = False
+    quantize_base: bool | None = False
 
 
 @json_schema_type
@@ -88,7 +87,7 @@ class QATFinetuningConfig(BaseModel):
     group_size: int
 
 
-AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
+AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
 register_schema(AlgorithmConfig, name="AlgorithmConfig")
 
 
@@ -97,7 +96,7 @@ class PostTrainingJobLogStream(BaseModel):
     """Stream of logs from a finetuning job."""
 
     job_uuid: str
-    log_lines: List[str]
+    log_lines: list[str]
 
 
 @json_schema_type
@@ -131,8 +130,8 @@ class PostTrainingRLHFRequest(BaseModel):
     training_config: TrainingConfig
 
     # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
+    hyperparam_search_config: dict[str, Any]
+    logger_config: dict[str, Any]
 
 
 class PostTrainingJob(BaseModel):
@@ -146,17 +145,17 @@ class PostTrainingJobStatusResponse(BaseModel):
 
     job_uuid: str
     status: JobStatus
 
-    scheduled_at: Optional[datetime] = None
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
+    scheduled_at: datetime | None = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None
 
-    resources_allocated: Optional[Dict[str, Any]] = None
+    resources_allocated: dict[str, Any] | None = None
 
-    checkpoints: List[Checkpoint] = Field(default_factory=list)
+    checkpoints: list[Checkpoint] = Field(default_factory=list)
 
 
 class ListPostTrainingJobsResponse(BaseModel):
-    data: List[PostTrainingJob]
+    data: list[PostTrainingJob]
 
 
 @json_schema_type
@@ -164,7 +163,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
     """Artifacts of a finetuning job."""
 
     job_uuid: str
-    checkpoints: List[Checkpoint] = Field(default_factory=list)
+    checkpoints: list[Checkpoint] = Field(default_factory=list)
 
     # TODO(ashwin): metrics, evals
 
@@ -175,14 +174,14 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
-        model: Optional[str] = Field(
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+        model: str | None = Field(
             default=None,
             description="Model descriptor for training if not in provider config`",
         ),
-        checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        checkpoint_dir: str | None = None,
+        algorithm_config: AlgorithmConfig | None = None,
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
         finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/jobs", method="GET")
    @webmethod(route="/post-training/jobs", method="GET")
diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py
index ea5f968ec..751c9263b 100644
--- a/llama_stack/apis/providers/providers.py
+++ b/llama_stack/apis/providers/providers.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -17,12 +17,12 @@ class ProviderInfo(BaseModel):
     api: str
     provider_id: str
     provider_type: str
-    config: Dict[str, Any]
+    config: dict[str, Any]
     health: HealthResponse


 class ListProvidersResponse(BaseModel):
-    data: List[ProviderInfo]
+    data: list[ProviderInfo]


 @runtime_checkable
diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py
index fd2f0292c..e139f5ffc 100644
--- a/llama_stack/apis/safety/safety.py
+++ b/llama_stack/apis/safety/safety.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel, Field

@@ -27,16 +27,16 @@ class SafetyViolation(BaseModel):
     violation_level: ViolationLevel

     # what message should you convey to the user
-    user_message: Optional[str] = None
+    user_message: str | None = None

     # additional metadata (including specific violation codes) more for
     # debugging, telemetry
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
 class RunShieldResponse(BaseModel):
-    violation: Optional[SafetyViolation] = None
+    violation: SafetyViolation | None = None


 class ShieldStore(Protocol):
@@ -52,6 +52,6 @@ class Safety(Protocol):
     async def run_shield(
         self,
         shield_id: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        messages: list[Message],
+        params: dict[str, Any] | None = None,
     ) -> RunShieldResponse: ...
diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
index 54a9ac2aa..414f3d5e2 100644
--- a/llama_stack/apis/scoring/scoring.py
+++ b/llama_stack/apis/scoring/scoring.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel @@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.schema_utils import json_schema_type, webmethod # mapping of metric to value -ScoringResultRow = Dict[str, Any] +ScoringResultRow = dict[str, Any] @json_schema_type @@ -24,15 +24,15 @@ class ScoringResult(BaseModel): :param aggregated_results: Map of metric name to aggregated value """ - score_rows: List[ScoringResultRow] + score_rows: list[ScoringResultRow] # aggregated metrics to value - aggregated_results: Dict[str, Any] + aggregated_results: dict[str, Any] @json_schema_type class ScoreBatchResponse(BaseModel): - dataset_id: Optional[str] = None - results: Dict[str, ScoringResult] + dataset_id: str | None = None + results: dict[str, ScoringResult] @json_schema_type @@ -44,7 +44,7 @@ class ScoreResponse(BaseModel): """ # each key in the dict is a scoring function name - results: Dict[str, ScoringResult] + results: dict[str, ScoringResult] class ScoringFunctionStore(Protocol): @@ -59,15 +59,15 @@ class Scoring(Protocol): async def score_batch( self, dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]], + scoring_functions: dict[str, ScoringFnParams | None], save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @webmethod(route="/scoring/score", method="POST") async def score( self, - input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]], + input_rows: list[dict[str, Any]], + scoring_functions: dict[str, ScoringFnParams | None], ) -> ScoreResponse: """Score a list of rows. diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 4f85947dd..6c7819965 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -6,18 +6,14 @@ from enum import Enum from typing import ( + Annotated, Any, - Dict, - List, Literal, - Optional, Protocol, - Union, runtime_checkable, ) from pydantic import BaseModel, Field -from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.resource import Resource, ResourceType @@ -46,12 +42,12 @@ class AggregationFunctionType(Enum): class LLMAsJudgeScoringFnParams(BaseModel): type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value judge_model: str - prompt_template: Optional[str] = None - judge_score_regexes: Optional[List[str]] = Field( + prompt_template: str | None = None + judge_score_regexes: list[str] | None = Field( description="Regexes to extract the answer from generated response", default_factory=list, ) - aggregation_functions: Optional[List[AggregationFunctionType]] = Field( + aggregation_functions: list[AggregationFunctionType] | None = Field( description="Aggregation functions to apply to the scores of each row", default_factory=list, ) @@ -60,11 +56,11 @@ class LLMAsJudgeScoringFnParams(BaseModel): @json_schema_type class RegexParserScoringFnParams(BaseModel): type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value - parsing_regexes: Optional[List[str]] = Field( + parsing_regexes: list[str] | None = Field( description="Regex to extract the answer from generated response", default_factory=list, ) - aggregation_functions: 
Optional[List[AggregationFunctionType]] = Field( + aggregation_functions: list[AggregationFunctionType] | None = Field( description="Aggregation functions to apply to the scores of each row", default_factory=list, ) @@ -73,33 +69,29 @@ class RegexParserScoringFnParams(BaseModel): @json_schema_type class BasicScoringFnParams(BaseModel): type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value - aggregation_functions: Optional[List[AggregationFunctionType]] = Field( + aggregation_functions: list[AggregationFunctionType] | None = Field( description="Aggregation functions to apply to the scores of each row", default_factory=list, ) ScoringFnParams = Annotated[ - Union[ - LLMAsJudgeScoringFnParams, - RegexParserScoringFnParams, - BasicScoringFnParams, - ], + LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams, Field(discriminator="type"), ] register_schema(ScoringFnParams, name="ScoringFnParams") class CommonScoringFnFields(BaseModel): - description: Optional[str] = None - metadata: Dict[str, Any] = Field( + description: str | None = None + metadata: dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this definition", ) return_type: ParamType = Field( description="The return type of the deterministic function", ) - params: Optional[ScoringFnParams] = Field( + params: ScoringFnParams | None = Field( description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval", default=None, ) @@ -120,12 +112,12 @@ class ScoringFn(CommonScoringFnFields, Resource): class ScoringFnInput(CommonScoringFnFields, BaseModel): scoring_fn_id: str - provider_id: Optional[str] = None - provider_scoring_fn_id: Optional[str] = None + provider_id: str | None = None + provider_scoring_fn_id: str | None = None class ListScoringFunctionsResponse(BaseModel): - data: List[ScoringFn] + data: list[ScoringFn] @runtime_checkable @@ -142,7 +134,7 @@ class ScoringFunctions(Protocol): scoring_fn_id: str, description: str, return_type: ParamType, - provider_scoring_fn_id: Optional[str] = None, - provider_id: Optional[str] = None, - params: Optional[ScoringFnParams] = None, + provider_scoring_fn_id: str | None = None, + provider_id: str | None = None, + params: ScoringFnParams | None = None, ) -> None: ... diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 67f3bd27b..4172fcbd1 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable +from typing import Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel @@ -14,7 +14,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod class CommonShieldFields(BaseModel): - params: Optional[Dict[str, Any]] = None + params: dict[str, Any] | None = None @json_schema_type @@ -34,12 +34,12 @@ class Shield(CommonShieldFields, Resource): class ShieldInput(CommonShieldFields): shield_id: str - provider_id: Optional[str] = None - provider_shield_id: Optional[str] = None + provider_id: str | None = None + provider_shield_id: str | None = None class ListShieldsResponse(BaseModel): - data: List[Shield] + data: list[Shield] @runtime_checkable @@ -55,7 +55,7 @@ class Shields(Protocol): async def register_shield( self, shield_id: str, - provider_shield_id: Optional[str] = None, - provider_id: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, ) -> Shield: ... diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index 7b41192af..91e550da9 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, Union +from typing import Any, Protocol from pydantic import BaseModel @@ -28,24 +28,24 @@ class FilteringFunction(Enum): class SyntheticDataGenerationRequest(BaseModel): """Request to generate synthetic data. A small batch of prompts and a filtering function""" - dialogs: List[Message] + dialogs: list[Message] filtering_function: FilteringFunction = FilteringFunction.none - model: Optional[str] = None + model: str | None = None @json_schema_type class SyntheticDataGenerationResponse(BaseModel): """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.""" - synthetic_data: List[Dict[str, Any]] - statistics: Optional[Dict[str, Any]] = None + synthetic_data: list[dict[str, Any]] + statistics: dict[str, Any] | None = None class SyntheticDataGeneration(Protocol): @webmethod(route="/synthetic-data-generation/generate") def synthetic_data_generate( self, - dialogs: List[Message], + dialogs: list[Message], filtering_function: FilteringFunction = FilteringFunction.none, - model: Optional[str] = None, - ) -> Union[SyntheticDataGenerationResponse]: ... + model: str | None = None, + ) -> SyntheticDataGenerationResponse: ... 
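The telemetry diff that follows converts its discriminated unions (`StructuredLogPayload`, `Event`) to the pipe syntax. Pydantic v2 treats `A | B` and `Union[A, B]` identically, so `Field(discriminator=...)` keeps working unchanged. A minimal sketch of the pattern, using hypothetical `SpanStart`/`SpanEnd` models rather than the PR's actual classes:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


# Hypothetical example models, not the PR's payload classes.
class SpanStart(BaseModel):
    type: Literal["span_start"] = "span_start"
    name: str


class SpanEnd(BaseModel):
    type: Literal["span_end"] = "span_end"
    status: str


# PEP 604 pipe union with a Pydantic discriminator; equivalent to the
# old Annotated[Union[SpanStart, SpanEnd], Field(discriminator="type")].
Payload = Annotated[SpanStart | SpanEnd, Field(discriminator="type")]

# The two spellings produce the same union object at runtime.
assert (SpanStart | SpanEnd) == Union[SpanStart, SpanEnd]

# The discriminator dispatches validation on the "type" field.
event = TypeAdapter(Payload).validate_python({"type": "span_start", "name": "query"})
assert isinstance(event, SpanStart)
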
diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index d57c311b2..af0469d2b 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -7,18 +7,14 @@ from datetime import datetime from enum import Enum from typing import ( + Annotated, Any, - Dict, - List, Literal, - Optional, Protocol, - Union, runtime_checkable, ) from pydantic import BaseModel, Field -from typing_extensions import Annotated from llama_stack.models.llama.datatypes import Primitive from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @@ -37,11 +33,11 @@ class SpanStatus(Enum): class Span(BaseModel): span_id: str trace_id: str - parent_span_id: Optional[str] = None + parent_span_id: str | None = None name: str start_time: datetime - end_time: Optional[datetime] = None - attributes: Optional[Dict[str, Any]] = Field(default_factory=dict) + end_time: datetime | None = None + attributes: dict[str, Any] | None = Field(default_factory=dict) def set_attribute(self, key: str, value: Any): if self.attributes is None: @@ -54,7 +50,7 @@ class Trace(BaseModel): trace_id: str root_span_id: str start_time: datetime - end_time: Optional[datetime] = None + end_time: datetime | None = None @json_schema_type @@ -78,7 +74,7 @@ class EventCommon(BaseModel): trace_id: str span_id: str timestamp: datetime - attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict) + attributes: dict[str, Primitive] | None = Field(default_factory=dict) @json_schema_type @@ -92,15 +88,15 @@ class UnstructuredLogEvent(EventCommon): class MetricEvent(EventCommon): type: Literal[EventType.METRIC.value] = EventType.METRIC.value metric: str # this would be an enum - value: Union[int, float] + value: int | float unit: str @json_schema_type class MetricInResponse(BaseModel): metric: str - value: Union[int, float] - unit: Optional[str] = None + value: int | float + unit: str | None = None # This is a short term solution to allow inference API to return metrics @@ -124,7 +120,7 @@ class MetricInResponse(BaseModel): class MetricResponseMixin(BaseModel): - metrics: Optional[List[MetricInResponse]] = None + metrics: list[MetricInResponse] | None = None @json_schema_type @@ -137,7 +133,7 @@ class StructuredLogType(Enum): class SpanStartPayload(BaseModel): type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value name: str - parent_span_id: Optional[str] = None + parent_span_id: str | None = None @json_schema_type @@ -147,10 +143,7 @@ class SpanEndPayload(BaseModel): StructuredLogPayload = Annotated[ - Union[ - SpanStartPayload, - SpanEndPayload, - ], + SpanStartPayload | SpanEndPayload, Field(discriminator="type"), ] register_schema(StructuredLogPayload, name="StructuredLogPayload") @@ -163,11 +156,7 @@ class StructuredLogEvent(EventCommon): Event = Annotated[ - Union[ - UnstructuredLogEvent, - MetricEvent, - StructuredLogEvent, - ], + UnstructuredLogEvent | MetricEvent | StructuredLogEvent, Field(discriminator="type"), ] register_schema(Event, name="Event") @@ -184,7 +173,7 @@ class EvalTrace(BaseModel): @json_schema_type class SpanWithStatus(Span): - status: Optional[SpanStatus] = None + status: SpanStatus | None = None @json_schema_type @@ -203,15 +192,15 @@ class QueryCondition(BaseModel): class QueryTracesResponse(BaseModel): - data: List[Trace] + data: list[Trace] class QuerySpansResponse(BaseModel): - data: List[Span] + data: list[Span] class QuerySpanTreeResponse(BaseModel): - data: Dict[str, SpanWithStatus] 
+ data: dict[str, SpanWithStatus] @runtime_checkable @@ -222,10 +211,10 @@ class Telemetry(Protocol): @webmethod(route="/telemetry/traces", method="POST") async def query_traces( self, - attribute_filters: Optional[List[QueryCondition]] = None, - limit: Optional[int] = 100, - offset: Optional[int] = 0, - order_by: Optional[List[str]] = None, + attribute_filters: list[QueryCondition] | None = None, + limit: int | None = 100, + offset: int | None = 0, + order_by: list[str] | None = None, ) -> QueryTracesResponse: ... @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") @@ -238,23 +227,23 @@ class Telemetry(Protocol): async def get_span_tree( self, span_id: str, - attributes_to_return: Optional[List[str]] = None, - max_depth: Optional[int] = None, + attributes_to_return: list[str] | None = None, + max_depth: int | None = None, ) -> QuerySpanTreeResponse: ... @webmethod(route="/telemetry/spans", method="POST") async def query_spans( self, - attribute_filters: List[QueryCondition], - attributes_to_return: List[str], - max_depth: Optional[int] = None, + attribute_filters: list[QueryCondition], + attributes_to_return: list[str], + max_depth: int | None = None, ) -> QuerySpansResponse: ... @webmethod(route="/telemetry/spans/export", method="POST") async def save_spans_to_dataset( self, - attribute_filters: List[QueryCondition], - attributes_to_save: List[str], + attribute_filters: list[QueryCondition], + attributes_to_save: list[str], dataset_id: str, - max_depth: Optional[int] = None, + max_depth: int | None = None, ) -> None: ... diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 73b36e050..fdf199b1a 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -5,10 +5,10 @@ # the root directory of this source tree. 
from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Annotated, Any, Literal from pydantic import BaseModel, Field -from typing_extensions import Annotated, Protocol, runtime_checkable +from typing_extensions import Protocol, runtime_checkable from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -29,13 +29,13 @@ class RAGDocument(BaseModel): document_id: str content: InterleavedContent | URL mime_type: str | None = None - metadata: Dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) @json_schema_type class RAGQueryResult(BaseModel): - content: Optional[InterleavedContent] = None - metadata: Dict[str, Any] = Field(default_factory=dict) + content: InterleavedContent | None = None + metadata: dict[str, Any] = Field(default_factory=dict) @json_schema_type @@ -59,10 +59,7 @@ class LLMRAGQueryGeneratorConfig(BaseModel): RAGQueryGeneratorConfig = Annotated[ - Union[ - DefaultRAGQueryGeneratorConfig, - LLMRAGQueryGeneratorConfig, - ], + DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig, Field(discriminator="type"), ] register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig") @@ -83,7 +80,7 @@ class RAGToolRuntime(Protocol): @webmethod(route="/tool-runtime/rag-tool/insert", method="POST") async def insert( self, - documents: List[RAGDocument], + documents: list[RAGDocument], vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: @@ -94,8 +91,8 @@ class RAGToolRuntime(Protocol): async def query( self, content: InterleavedContent, - vector_db_ids: List[str], - query_config: Optional[RAGQueryConfig] = None, + vector_db_ids: list[str], + query_config: RAGQueryConfig | None = None, ) -> RAGQueryResult: """Query the RAG system for context; typically invoked by the agent""" ... diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 4ca72f71d..eda627932 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
from enum import Enum -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable @@ -24,7 +24,7 @@ class ToolParameter(BaseModel): parameter_type: str description: str required: bool = Field(default=True) - default: Optional[Any] = None + default: Any | None = None @json_schema_type @@ -40,39 +40,39 @@ class Tool(Resource): toolgroup_id: str tool_host: ToolHost description: str - parameters: List[ToolParameter] - metadata: Optional[Dict[str, Any]] = None + parameters: list[ToolParameter] + metadata: dict[str, Any] | None = None @json_schema_type class ToolDef(BaseModel): name: str - description: Optional[str] = None - parameters: Optional[List[ToolParameter]] = None - metadata: Optional[Dict[str, Any]] = None + description: str | None = None + parameters: list[ToolParameter] | None = None + metadata: dict[str, Any] | None = None @json_schema_type class ToolGroupInput(BaseModel): toolgroup_id: str provider_id: str - args: Optional[Dict[str, Any]] = None - mcp_endpoint: Optional[URL] = None + args: dict[str, Any] | None = None + mcp_endpoint: URL | None = None @json_schema_type class ToolGroup(Resource): type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value - mcp_endpoint: Optional[URL] = None - args: Optional[Dict[str, Any]] = None + mcp_endpoint: URL | None = None + args: dict[str, Any] | None = None @json_schema_type class ToolInvocationResult(BaseModel): - content: Optional[InterleavedContent] = None - error_message: Optional[str] = None - error_code: Optional[int] = None - metadata: Optional[Dict[str, Any]] = None + content: InterleavedContent | None = None + error_message: str | None = None + error_code: int | None = None + metadata: dict[str, Any] | None = None class ToolStore(Protocol): @@ -81,11 +81,11 @@ class ToolStore(Protocol): class ListToolGroupsResponse(BaseModel): - data: List[ToolGroup] + data: list[ToolGroup] class ListToolsResponse(BaseModel): - data: List[Tool] + data: list[Tool] class ListToolDefsResponse(BaseModel): @@ -100,8 +100,8 @@ class ToolGroups(Protocol): self, toolgroup_id: str, provider_id: str, - mcp_endpoint: Optional[URL] = None, - args: Optional[Dict[str, Any]] = None, + mcp_endpoint: URL | None = None, + args: dict[str, Any] | None = None, ) -> None: """Register a tool group""" ... @@ -118,7 +118,7 @@ class ToolGroups(Protocol): ... @webmethod(route="/tools", method="GET") - async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse: + async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: """List tools with optional tool group""" ... @@ -151,10 +151,10 @@ class ToolRuntime(Protocol): # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. @webmethod(route="/tool-runtime/list-tools", method="GET") async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: ... @webmethod(route="/tool-runtime/invoke", method="POST") - async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: """Run a tool with the given arguments""" ... 
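Across the API modules above, the recurring field shape is `value: T | None = Field(default_factory=...)` or `= None`. The pipe spelling is behavior-identical to `Optional[T]` under Pydantic v2: same defaults, same acceptance of an explicit `None`, and the two annotations even compare equal. A small equivalence check (hypothetical models, assuming Pydantic v2 on Python 3.10+):

from typing import Any, Optional

from pydantic import BaseModel, Field


# Hypothetical before/after models for illustration only.
class OldStyle(BaseModel):
    metadata: Optional[dict[str, Any]] = Field(default_factory=dict)


class NewStyle(BaseModel):
    metadata: dict[str, Any] | None = Field(default_factory=dict)


# Same default, and explicit None is still accepted.
assert OldStyle().metadata == NewStyle().metadata == {}
assert NewStyle(metadata=None).metadata is None

# The resolved annotations compare equal: Optional[T] == T | None.
assert OldStyle.model_fields["metadata"].annotation == NewStyle.model_fields["metadata"].annotation
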
diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index fe6c33919..6224566cd 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Literal, Optional, Protocol, runtime_checkable +from typing import Literal, Protocol, runtime_checkable from pydantic import BaseModel @@ -33,11 +33,11 @@ class VectorDBInput(BaseModel): vector_db_id: str embedding_model: str embedding_dimension: int - provider_vector_db_id: Optional[str] = None + provider_vector_db_id: str | None = None class ListVectorDBsResponse(BaseModel): - data: List[VectorDB] + data: list[VectorDB] @runtime_checkable @@ -57,9 +57,9 @@ class VectorDBs(Protocol): self, vector_db_id: str, embedding_model: str, - embedding_dimension: Optional[int] = 384, - provider_id: Optional[str] = None, - provider_vector_db_id: Optional[str] = None, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, ) -> VectorDB: ... @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE") diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index ab0a4a20a..bfae0f802 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -8,7 +8,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -20,17 +20,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod class Chunk(BaseModel): content: InterleavedContent - metadata: Dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) @json_schema_type class QueryChunksResponse(BaseModel): - chunks: List[Chunk] - scores: List[float] + chunks: list[Chunk] + scores: list[float] class VectorDBStore(Protocol): - def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ... + def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ... @runtime_checkable @@ -44,8 +44,8 @@ class VectorIO(Protocol): async def insert_chunks( self, vector_db_id: str, - chunks: List[Chunk], - ttl_seconds: Optional[int] = None, + chunks: list[Chunk], + ttl_seconds: int | None = None, ) -> None: ... @webmethod(route="/vector-io/query", method="POST") @@ -53,5 +53,5 @@ class VectorIO(Protocol): self, vector_db_id: str, query: InterleavedContent, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> QueryChunksResponse: ... 
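The CLI diffs that follow mix pyupgrade-style cleanups in with the typing changes: `open(path, "r")` drops the redundant mode argument, `@lru_cache()` becomes the bare `@lru_cache`, and `str.format` calls turn into f-strings. All three rewrites are behavior-preserving, as this short sketch shows (hypothetical `read_manifest` helper, not code from the PR):

from functools import lru_cache


@lru_cache  # bare decorator form, identical to @lru_cache() on Python 3.8+
def read_manifest(path: str) -> str:
    # "r" is open()'s default mode, so passing it was redundant
    with open(path) as f:
        return f.read()


api = "inference"
# f-strings and str.format render the same text; the f-string is just clearer
assert f"> Enter provider for API {api}: " == "> Enter provider for API {}: ".format(api)
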
diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index deccc4508..09c753776 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -13,7 +13,6 @@ from dataclasses import dataclass from datetime import datetime, timezone from functools import partial from pathlib import Path -from typing import Dict, List, Optional import httpx from pydantic import BaseModel, ConfigDict @@ -102,7 +101,7 @@ class DownloadTask: output_file: str total_size: int = 0 downloaded_size: int = 0 - task_id: Optional[int] = None + task_id: int | None = None retries: int = 0 max_retries: int = 3 @@ -262,7 +261,7 @@ class ParallelDownloader: self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]") raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e - def has_disk_space(self, tasks: List[DownloadTask]) -> bool: + def has_disk_space(self, tasks: list[DownloadTask]) -> bool: try: total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks) dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file)) @@ -282,7 +281,7 @@ class ParallelDownloader: except Exception as e: raise DownloadError(f"Failed to check disk space: {str(e)}") from e - async def download_all(self, tasks: List[DownloadTask]) -> None: + async def download_all(self, tasks: list[DownloadTask]) -> None: if not tasks: raise ValueError("No download tasks provided") @@ -391,20 +390,20 @@ def _meta_download( class ModelEntry(BaseModel): model_id: str - files: Dict[str, str] + files: dict[str, str] model_config = ConfigDict(protected_namespaces=()) class Manifest(BaseModel): - models: List[ModelEntry] + models: list[ModelEntry] expires_on: datetime def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int): from llama_stack.distribution.utils.model_utils import model_local_dir - with open(manifest_file, "r") as f: + with open(manifest_file) as f: d = json.load(f) manifest = Manifest(**d) diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py index aaee1adcb..e31767f13 100644 --- a/llama_stack/cli/model/safety_models.py +++ b/llama_stack/cli/model/safety_models.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict +from typing import Any from pydantic import BaseModel, ConfigDict, Field @@ -22,7 +22,7 @@ class PromptGuardModel(BaseModel): max_seq_length: int = 512 is_instruct_model: bool = False quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 - arch_args: Dict[str, Any] = Field(default_factory=dict) + arch_args: dict[str, Any] = Field(default_factory=dict) def descriptor(self) -> str: return self.model_id @@ -44,11 +44,11 @@ def prompt_guard_model_skus(): ] -def prompt_guard_model_sku_map() -> Dict[str, Any]: +def prompt_guard_model_sku_map() -> dict[str, Any]: return {model.model_id: model for model in prompt_guard_model_skus()} -def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]: +def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]: return { model.model_id: LlamaDownloadInfo( folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id, diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index f3a29b947..ae4a39ce2 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -13,7 +13,6 @@ import sys import textwrap from functools import lru_cache from pathlib import Path -from typing import Dict, Optional import yaml from prompt_toolkit import prompt @@ -46,14 +45,14 @@ from llama_stack.providers.datatypes import Api TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" -@lru_cache() -def available_templates_specs() -> Dict[str, BuildConfig]: +@lru_cache +def available_templates_specs() -> dict[str, BuildConfig]: import yaml template_specs = {} for p in TEMPLATES_PATH.rglob("*build.yaml"): template_name = p.parent.name - with open(p, "r") as f: + with open(p) as f: build_config = BuildConfig(**yaml.safe_load(f)) template_specs[template_name] = build_config return template_specs @@ -178,7 +177,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: if not available_providers: continue api_provider = prompt( - "> Enter provider for API {}: ".format(api.value), + f"> Enter provider for API {api.value}: ", completer=WordCompleter(available_providers), complete_while_typing=True, validator=Validator.from_callable( @@ -201,7 +200,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec) else: - with open(args.config, "r") as f: + with open(args.config) as f: try: build_config = BuildConfig(**yaml.safe_load(f)) except Exception as e: @@ -332,9 +331,9 @@ def _generate_run_config( def _run_stack_build_command_from_build_config( build_config: BuildConfig, - image_name: Optional[str] = None, - template_name: Optional[str] = None, - config_path: Optional[str] = None, + image_name: str | None = None, + template_name: str | None = None, + config_path: str | None = None, ) -> str: image_name = image_name or build_config.image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value: diff --git a/llama_stack/cli/table.py b/llama_stack/cli/table.py index bf59e6103..86c3adff2 100644 --- a/llama_stack/cli/table.py +++ b/llama_stack/cli/table.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Iterable +from collections.abc import Iterable from rich.console import Console from rich.table import Table diff --git a/llama_stack/cli/verify_download.py b/llama_stack/cli/verify_download.py index 1229e8601..3a1af3cbc 100644 --- a/llama_stack/cli/verify_download.py +++ b/llama_stack/cli/verify_download.py @@ -9,7 +9,6 @@ import hashlib from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Dict, List, Optional from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn @@ -21,7 +20,7 @@ from llama_stack.cli.subcommand import Subcommand class VerificationResult: filename: str expected_hash: str - actual_hash: Optional[str] + actual_hash: str | None exists: bool matches: bool @@ -60,9 +59,9 @@ def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str: return md5_hash.hexdigest() -def load_checksums(checklist_path: Path) -> Dict[str, str]: +def load_checksums(checklist_path: Path) -> dict[str, str]: checksums = {} - with open(checklist_path, "r") as f: + with open(checklist_path) as f: for line in f: if line.strip(): md5sum, filepath = line.strip().split(" ", 1) @@ -72,7 +71,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]: return checksums -def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]: +def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]: results = [] with Progress( diff --git a/llama_stack/distribution/access_control.py b/llama_stack/distribution/access_control.py index 0651ab6eb..d560ec80f 100644 --- a/llama_stack/distribution/access_control.py +++ b/llama_stack/distribution/access_control.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.log import get_logger @@ -14,8 +14,8 @@ logger = get_logger(__name__, category="core") def check_access( obj_identifier: str, - obj_attributes: Optional[AccessAttributes], - user_attributes: Optional[Dict[str, Any]] = None, + obj_attributes: AccessAttributes | None, + user_attributes: dict[str, Any] | None = None, ) -> bool: """Check if the current user has access to the given object, based on access attributes. diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index 1925b864f..9fde8a157 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -8,7 +8,7 @@ import inspect import json from collections.abc import AsyncIterator from enum import Enum -from typing import Any, Type, Union, get_args, get_origin +from typing import Any, Union, get_args, get_origin import httpx from pydantic import BaseModel, parse_obj_as @@ -27,7 +27,7 @@ async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any): return impl -def create_api_client_class(protocol) -> Type: +def create_api_client_class(protocol) -> type: if protocol in _CLIENT_CLASSES: return _CLIENT_CLASSES[protocol] diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 2a3bf7053..76167258a 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import logging import textwrap -from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import ( LLAMA_STACK_RUN_CONFIG_VERSION, @@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec logger = logging.getLogger(__name__) -def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider: +def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider: provider_spec = registry[provider.provider_type] config_type = instantiate_class_type(provider_spec.config_class) try: @@ -120,8 +120,8 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec def upgrade_from_routing_table( - config_dict: Dict[str, Any], -) -> Dict[str, Any]: + config_dict: dict[str, Any], +) -> dict[str, Any]: def get_providers(entries): return [ Provider( @@ -163,7 +163,7 @@ def upgrade_from_routing_table( return config_dict -def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig: +def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig: version = config_dict.get("version", None) if version == LLAMA_STACK_RUN_CONFIG_VERSION: return StackRunConfig(**config_dict) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 8e0e2f1f7..c127136c9 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Annotated, Any, Dict, List, Optional, Union +from typing import Annotated, Any from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" -RoutingKey = Union[str, List[str]] +RoutingKey = str | list[str] class AccessAttributes(BaseModel): @@ -47,17 +47,17 @@ class AccessAttributes(BaseModel): """ # Standard attribute categories - the minimal set we need now - roles: Optional[List[str]] = Field( + roles: list[str] | None = Field( default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')" ) - teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')") + teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')") - projects: Optional[List[str]] = Field( + projects: list[str] | None = Field( default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')" ) - namespaces: Optional[List[str]] = Field( + namespaces: list[str] | None = Field( default=None, description="Namespace-based access control for resource isolation" ) @@ -106,7 +106,7 @@ class ResourceWithACL(Resource): # ^ User must have access to the customer-insights project AND have confidential namespace """ - access_attributes: Optional[AccessAttributes] = None + access_attributes: AccessAttributes | None = None # Use the extended Resource for all routable objects @@ -142,41 +142,21 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL): pass -RoutableObject = Union[ - Model, - Shield, - VectorDB, - Dataset, - ScoringFn, - Benchmark, - Tool, - ToolGroup, -] - +RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup RoutableObjectWithProvider = Annotated[ - Union[ - ModelWithACL, - ShieldWithACL, - VectorDBWithACL, - DatasetWithACL, - ScoringFnWithACL, - BenchmarkWithACL, - ToolWithACL, - ToolGroupWithACL, - ], + 
ModelWithACL + | ShieldWithACL + | VectorDBWithACL + | DatasetWithACL + | ScoringFnWithACL + | BenchmarkWithACL + | ToolWithACL + | ToolGroupWithACL, Field(discriminator="type"), ] -RoutedProtocol = Union[ - Inference, - Safety, - VectorIO, - DatasetIO, - Scoring, - Eval, - ToolRuntime, -] +RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime # Example: /inference, /safety @@ -184,15 +164,15 @@ class AutoRoutedProviderSpec(ProviderSpec): provider_type: str = "router" config_class: str = "" - container_image: Optional[str] = None + container_image: str | None = None routing_table_api: Api module: str - provider_data_validator: Optional[str] = Field( + provider_data_validator: str | None = Field( default=None, ) @property - def pip_packages(self) -> List[str]: + def pip_packages(self) -> list[str]: raise AssertionError("Should not be called on AutoRoutedProviderSpec") @@ -200,20 +180,20 @@ class AutoRoutedProviderSpec(ProviderSpec): class RoutingTableProviderSpec(ProviderSpec): provider_type: str = "routing_table" config_class: str = "" - container_image: Optional[str] = None + container_image: str | None = None router_api: Api module: str - pip_packages: List[str] = Field(default_factory=list) + pip_packages: list[str] = Field(default_factory=list) class DistributionSpec(BaseModel): - description: Optional[str] = Field( + description: str | None = Field( default="", description="Description of the distribution", ) - container_image: Optional[str] = None - providers: Dict[str, Union[str, List[str]]] = Field( + container_image: str | None = None + providers: dict[str, str | list[str]] = Field( default_factory=dict, description=""" Provider Types for each of the APIs provided by this distribution. If you @@ -225,12 +205,12 @@ in the runtime configuration to help route to the correct provider.""", class Provider(BaseModel): provider_id: str provider_type: str - config: Dict[str, Any] + config: dict[str, Any] class LoggingConfig(BaseModel): - category_levels: Dict[str, str] = Field( - default_factory=Dict, + category_levels: dict[str, str] = Field( + default_factory=dict, description=""" Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""", ) @@ -248,7 +228,7 @@ class AuthenticationConfig(BaseModel): ..., description="Type of authentication provider (e.g., 'kubernetes', 'custom')", ) - config: Dict[str, str] = Field( + config: dict[str, str] = Field( ..., description="Provider-specific configuration", ) @@ -261,15 +241,15 @@ class ServerConfig(BaseModel): ge=1024, le=65535, ) - tls_certfile: Optional[str] = Field( + tls_certfile: str | None = Field( default=None, description="Path to TLS certificate file for HTTPS", ) - tls_keyfile: Optional[str] = Field( + tls_keyfile: str | None = Field( default=None, description="Path to TLS key file for HTTPS", ) - auth: Optional[AuthenticationConfig] = Field( + auth: AuthenticationConfig | None = Field( default=None, description="Authentication configuration for the server", ) @@ -285,23 +265,23 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p this could be just a hash """, ) - container_image: Optional[str] = Field( + container_image: str | None = Field( default=None, description="Reference to the container image if this package refers to a container", ) - apis: List[str] = Field( + apis: list[str] = Field( default_factory=list, description=""" The list of APIs to serve. 
If not specified, all APIs specified in the provider_map will be served""", ) - providers: Dict[str, List[Provider]] = Field( + providers: dict[str, list[Provider]] = Field( description=""" One or more providers to use for each API. The same provider_type (e.g., meta-reference) can be instantiated multiple times (with different configs) if necessary. """, ) - metadata_store: Optional[KVStoreConfig] = Field( + metadata_store: KVStoreConfig | None = Field( default=None, description=""" Configuration for the persistence store used by the distribution registry. If not specified, @@ -309,22 +289,22 @@ a default SQLite store will be used.""", ) # registry of "resources" in the distribution - models: List[ModelInput] = Field(default_factory=list) - shields: List[ShieldInput] = Field(default_factory=list) - vector_dbs: List[VectorDBInput] = Field(default_factory=list) - datasets: List[DatasetInput] = Field(default_factory=list) - scoring_fns: List[ScoringFnInput] = Field(default_factory=list) - benchmarks: List[BenchmarkInput] = Field(default_factory=list) - tool_groups: List[ToolGroupInput] = Field(default_factory=list) + models: list[ModelInput] = Field(default_factory=list) + shields: list[ShieldInput] = Field(default_factory=list) + vector_dbs: list[VectorDBInput] = Field(default_factory=list) + datasets: list[DatasetInput] = Field(default_factory=list) + scoring_fns: list[ScoringFnInput] = Field(default_factory=list) + benchmarks: list[BenchmarkInput] = Field(default_factory=list) + tool_groups: list[ToolGroupInput] = Field(default_factory=list) - logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging") + logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging") server: ServerConfig = Field( default_factory=ServerConfig, description="Configuration for the HTTP(S) server", ) - external_providers_dir: Optional[str] = Field( + external_providers_dir: str | None = Field( default=None, description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.", ) @@ -338,11 +318,11 @@ class BuildConfig(BaseModel): default="conda", description="Type of package to build (conda | container | venv)", ) - image_name: Optional[str] = Field( + image_name: str | None = Field( default=None, description="Name of the distribution to build", ) - external_providers_dir: Optional[str] = Field( + external_providers_dir: str | None = Field( default=None, description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. 
" "pip_packages MUST contain the provider package name.", diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index f948ddf1c..07a91478a 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -7,7 +7,7 @@ import glob import importlib import os -from typing import Any, Dict, List +from typing import Any import yaml from pydantic import BaseModel @@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import ( logger = get_logger(name=__name__, category="core") -def stack_apis() -> List[Api]: +def stack_apis() -> list[Api]: return list(Api) @@ -33,7 +33,7 @@ class AutoRoutedApiInfo(BaseModel): router_api: Api -def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: +def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]: return [ AutoRoutedApiInfo( routing_table_api=Api.models, @@ -66,12 +66,12 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: ] -def providable_apis() -> List[Api]: +def providable_apis() -> list[Api]: routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()} return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers] -def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec: +def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec: adapter = AdapterSpec(**spec_data["adapter"]) spec = remote_provider_spec( api=api, @@ -81,7 +81,7 @@ def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderS return spec -def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec: +def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec: spec = InlineProviderSpec( api=api, provider_type=f"inline::{provider_name}", @@ -98,7 +98,7 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam def get_provider_registry( config=None, -) -> Dict[Api, Dict[str, ProviderSpec]]: +) -> dict[Api, dict[str, ProviderSpec]]: """Get the provider registry, optionally including external providers. This function loads both built-in providers and external providers from YAML files. 
@@ -133,7 +133,7 @@ def get_provider_registry( ValueError: If any provider spec is invalid """ - ret: Dict[Api, Dict[str, ProviderSpec]] = {} + ret: dict[Api, dict[str, ProviderSpec]] = {} for api in providable_apis(): name = api.name.lower() logger.debug(f"Importing module {name}") diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index f426bcafe..b2d16d74c 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -12,7 +12,7 @@ import os from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, Optional, TypeVar, Union, get_args, get_origin +from typing import Any, TypeVar, Union, get_args, get_origin import httpx import yaml @@ -119,8 +119,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient): self, config_path_or_template_name: str, skip_logger_removal: bool = False, - custom_provider_registry: Optional[ProviderRegistry] = None, - provider_data: Optional[dict[str, Any]] = None, + custom_provider_registry: ProviderRegistry | None = None, + provider_data: dict[str, Any] | None = None, ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( @@ -181,8 +181,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): def __init__( self, config_path_or_template_name: str, - custom_provider_registry: Optional[ProviderRegistry] = None, - provider_data: Optional[dict[str, Any]] = None, + custom_provider_registry: ProviderRegistry | None = None, + provider_data: dict[str, Any] | None = None, ): super().__init__() # when using the library client, we should not log to console since many @@ -371,7 +371,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return await response.parse() - def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict: + def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict: if not body: return {} diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py index 1c00ce264..157fd1e0e 100644 --- a/llama_stack/distribution/providers.py +++ b/llama_stack/distribution/providers.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import asyncio -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -73,14 +73,14 @@ class ProviderImpl(Providers): raise ValueError(f"Provider {provider_id} not found") - async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]: + async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]: """Get health status for all providers. Returns: Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses. Each API maps to a dictionary of provider IDs to their health responses. 
""" - providers_health: Dict[str, Dict[str, HealthResponse]] = {} + providers_health: dict[str, dict[str, HealthResponse]] = {} timeout = 1.0 async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None: diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index f9cde2cdf..bc15776ec 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -7,7 +7,8 @@ import contextvars import json import logging -from typing import Any, ContextManager, Dict, List, Optional +from contextlib import AbstractContextManager +from typing import Any from .utils.dynamic import instantiate_class_type @@ -17,11 +18,11 @@ log = logging.getLogger(__name__) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) -class RequestProviderDataContext(ContextManager): +class RequestProviderDataContext(AbstractContextManager): """Context manager for request provider data""" def __init__( - self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None + self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None ): self.provider_data = provider_data or {} if auth_attributes: @@ -63,7 +64,7 @@ class NeedsRequestProviderData: return None -def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]: +def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None: """Parse provider data from request headers""" keys = [ "X-LlamaStack-Provider-Data", @@ -86,14 +87,14 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A def request_provider_data_context( - headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None -) -> ContextManager: + headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None +) -> AbstractContextManager: """Context manager that sets request provider data from headers and auth attributes for the duration of the context""" provider_data = parse_request_provider_data(headers) return RequestProviderDataContext(provider_data, auth_attributes) -def get_auth_attributes() -> Optional[Dict[str, List[str]]]: +def get_auth_attributes() -> dict[str, list[str]] | None: """Helper to retrieve auth attributes from the provider data context""" provider_data = PROVIDER_DATA_VAR.get() if not provider_data: diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index e9a594eba..37588ea64 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import importlib import inspect -from typing import Any, Dict, List, Set, Tuple +from typing import Any from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks @@ -58,7 +58,7 @@ class InvalidProviderError(Exception): pass -def api_protocol_map() -> Dict[Api, Any]: +def api_protocol_map() -> dict[Api, Any]: return { Api.providers: ProvidersAPI, Api.agents: Agents, @@ -83,7 +83,7 @@ def api_protocol_map() -> Dict[Api, Any]: } -def additional_protocols_map() -> Dict[Api, Any]: +def additional_protocols_map() -> dict[Api, Any]: return { Api.inference: (ModelsProtocolPrivate, Models, Api.models), Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups), @@ -104,14 +104,14 @@ class ProviderWithSpec(Provider): spec: ProviderSpec -ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]] +ProviderRegistry = dict[Api, dict[str, ProviderSpec]] async def resolve_impls( run_config: StackRunConfig, provider_registry: ProviderRegistry, dist_registry: DistributionRegistry, -) -> Dict[Api, Any]: +) -> dict[Api, Any]: """ Resolves provider implementations by: 1. Validating and organizing providers. @@ -136,7 +136,7 @@ async def resolve_impls( return await instantiate_providers(sorted_providers, router_apis, dist_registry) -def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]: +def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: """Generates specifications for automatically routed APIs.""" specs = {} for info in builtin_automatically_routed_apis(): @@ -178,10 +178,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, def validate_and_prepare_providers( - run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api] -) -> Dict[str, Dict[str, ProviderWithSpec]]: + run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api] +) -> dict[str, dict[str, ProviderWithSpec]]: """Validates providers, handles deprecations, and organizes them into a spec dictionary.""" - providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {} + providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {} for api_str, providers in run_config.providers.items(): api = Api(api_str) @@ -222,10 +222,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR def sort_providers_by_deps( - providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig -) -> List[Tuple[str, ProviderWithSpec]]: + providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig +) -> list[tuple[str, ProviderWithSpec]]: """Sorts providers based on their dependencies.""" - sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort( + sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort( {k: list(v.values()) for k, v in providers_with_specs.items()} ) @@ -236,11 +236,11 @@ def sort_providers_by_deps( async def instantiate_providers( - sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry -) -> Dict: + sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry +) -> dict: """Instantiates providers asynchronously while managing dependencies.""" - impls: Dict[Api, Any] = {} - inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = 
{f"inner-{x.value}": {} for x in router_apis} + impls: dict[Api, Any] = {} + inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} for api_str, provider in sorted_providers: deps = {a: impls[a] for a in provider.spec.api_dependencies} for a in provider.spec.optional_api_dependencies: @@ -263,9 +263,9 @@ async def instantiate_providers( def topological_sort( - providers_with_specs: Dict[str, List[ProviderWithSpec]], -) -> List[Tuple[str, ProviderWithSpec]]: - def dfs(kv, visited: Set[str], stack: List[str]): + providers_with_specs: dict[str, list[ProviderWithSpec]], +) -> list[tuple[str, ProviderWithSpec]]: + def dfs(kv, visited: set[str], stack: list[str]): api_str, providers = kv visited.add(api_str) @@ -280,8 +280,8 @@ def topological_sort( stack.append(api_str) - visited: Set[str] = set() - stack: List[str] = [] + visited: set[str] = set() + stack: list[str] = [] for api_str, providers in providers_with_specs.items(): if api_str not in visited: @@ -298,8 +298,8 @@ def topological_sort( # returns a class implementing the protocol corresponding to the Api async def instantiate_provider( provider: ProviderWithSpec, - deps: Dict[Api, Any], - inner_impls: Dict[str, Any], + deps: dict[Api, Any], + inner_impls: dict[str, Any], dist_registry: DistributionRegistry, ): protocols = api_protocol_map() @@ -391,8 +391,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: async def resolve_remote_stack_impls( config: RemoteProviderConfig, - apis: List[str], -) -> Dict[Api, Any]: + apis: list[str], +) -> dict[Api, Any]: protocols = api_protocol_map() additional_protocols = additional_protocols_map() diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index d0fca8771..cd2a296f2 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import RoutedProtocol from llama_stack.distribution.store import DistributionRegistry @@ -23,7 +23,7 @@ from .routing_tables import ( async def get_routing_table_impl( api: Api, - impls_by_provider_id: Dict[str, RoutedProtocol], + impls_by_provider_id: dict[str, RoutedProtocol], _deps, dist_registry: DistributionRegistry, ) -> Any: @@ -45,7 +45,7 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any: +async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any: from .routers import ( DatasetIORouter, EvalRouter, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index d88df00bd..737a384d8 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,12 +6,12 @@ import asyncio import time -from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Annotated, Any from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam from pydantic import Field, TypeAdapter -from typing_extensions import Annotated from llama_stack.apis.common.content_types import ( URL, @@ -100,9 +100,9 @@ class VectorIORouter(VectorIO): self, vector_db_id: str, embedding_model: str, - embedding_dimension: Optional[int] = 384, - provider_id: Optional[str] = None, - provider_vector_db_id: Optional[str] = None, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, ) -> None: logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") await self.routing_table.register_vector_db( @@ -116,8 +116,8 @@ class VectorIORouter(VectorIO): async def insert_chunks( self, vector_db_id: str, - chunks: List[Chunk], - ttl_seconds: Optional[int] = None, + chunks: list[Chunk], + ttl_seconds: int | None = None, ) -> None: logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' 
if len(chunks) > 3 else ''}", @@ -128,7 +128,7 @@ class VectorIORouter(VectorIO): self, vector_db_id: str, query: InterleavedContent, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) @@ -140,7 +140,7 @@ class InferenceRouter(Inference): def __init__( self, routing_table: RoutingTable, - telemetry: Optional[Telemetry] = None, + telemetry: Telemetry | None = None, ) -> None: logger.debug("Initializing InferenceRouter") self.routing_table = routing_table @@ -160,10 +160,10 @@ class InferenceRouter(Inference): async def register_model( self, model_id: str, - provider_model_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - model_type: Optional[ModelType] = None, + provider_model_id: str | None = None, + provider_id: str | None = None, + metadata: dict[str, Any] | None = None, + model_type: ModelType | None = None, ) -> None: logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", @@ -176,7 +176,7 @@ class InferenceRouter(Inference): completion_tokens: int, total_tokens: int, model: Model, - ) -> List[MetricEvent]: + ) -> list[MetricEvent]: """Constructs a list of MetricEvent objects containing token usage metrics. Args: @@ -221,7 +221,7 @@ class InferenceRouter(Inference): completion_tokens: int, total_tokens: int, model: Model, - ) -> List[MetricInResponse]: + ) -> list[MetricInResponse]: metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) if self.telemetry: for metric in metrics: @@ -230,9 +230,9 @@ class InferenceRouter(Inference): async def _count_tokens( self, - messages: List[Message] | InterleavedContent, - tool_prompt_format: Optional[ToolPromptFormat] = None, - ) -> Optional[int]: + messages: list[Message] | InterleavedContent, + tool_prompt_format: ToolPromptFormat | None = None, + ) -> int | None: if isinstance(messages, list): encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) else: @@ -242,16 +242,16 @@ class InferenceRouter(Inference): async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = None, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + messages: list[Message], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = None, + tool_prompt_format: ToolPromptFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, + ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) @@ -351,12 +351,12 @@ class InferenceRouter(Inference): async def batch_chat_completion( self, model_id: str, - messages_batch: 
List[List[Message]], - tools: Optional[List[ToolDefinition]] = None, - tool_config: Optional[ToolConfig] = None, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - logprobs: Optional[LogProbConfig] = None, + messages_batch: list[list[Message]], + tools: list[ToolDefinition] | None = None, + tool_config: ToolConfig | None = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ) -> BatchChatCompletionResponse: logger.debug( f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", @@ -376,10 +376,10 @@ class InferenceRouter(Inference): self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -439,10 +439,10 @@ class InferenceRouter(Inference): async def batch_completion( self, model_id: str, - content_batch: List[InterleavedContent], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - logprobs: Optional[LogProbConfig] = None, + content_batch: list[InterleavedContent], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ) -> BatchCompletionResponse: logger.debug( f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", @@ -453,10 +453,10 @@ class InferenceRouter(Inference): async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: logger.debug(f"InferenceRouter.embeddings: {model_id}") model = await self.routing_table.get_model(model_id) @@ -475,24 +475,24 @@ class InferenceRouter(Inference): async def openai_completion( self, model: str, - prompt: Union[str, List[str], List[int], List[List[int]]], - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - guided_choice: Optional[List[str]] = None, - prompt_logprobs: Optional[int] = None, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = 
None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, ) -> OpenAICompletion: logger.debug( f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", @@ -531,29 +531,29 @@ class InferenceRouter(Inference): async def openai_chat_completion( self, model: str, - messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[OpenAIResponseFormatParam] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: logger.debug( f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", ) @@ -602,7 +602,7 @@ class InferenceRouter(Inference): provider = self.routing_table.get_provider_impl(model_obj.identifier) return await provider.openai_chat_completion(**params) - async def health(self) -> Dict[str, HealthResponse]: + async def health(self) -> dict[str, HealthResponse]: health_statuses = {} timeout = 0.5 for provider_id, impl in self.routing_table.impls_by_provider_id.items(): @@ -645,9 +645,9 @@ class SafetyRouter(Safety): async def register_shield( self, shield_id: str, - provider_shield_id: Optional[str] = None, - provider_id: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, ) -> 
Shield:
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
@@ -655,8 +655,8 @@ class SafetyRouter(Safety):
     async def run_shield(
         self,
         shield_id: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        messages: list[Message],
+        params: dict[str, Any] | None = None,
     ) -> RunShieldResponse:
         logger.debug(f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
@@ -686,8 +686,8 @@ class DatasetIORouter(DatasetIO):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
-        dataset_id: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
     ) -> None:
         logger.debug(
             f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
@@ -702,8 +702,8 @@ class DatasetIORouter(DatasetIO):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
-        limit: Optional[int] = None,
+        start_index: int | None = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         logger.debug(
             f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
@@ -714,7 +714,7 @@ class DatasetIORouter(DatasetIO):
             limit=limit,
         )

-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
         logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
@@ -741,7 +741,7 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        scoring_functions: dict[str, ScoringFnParams | None] | None = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
@@ -762,8 +762,8 @@ class ScoringRouter(Scoring):

     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None] | None = None,
     ) -> ScoreResponse:
         logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
@@ -808,8 +808,8 @@ class EvalRouter(Eval):
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
@@ -863,8 +863,8 @@ class ToolRuntimeRouter(ToolRuntime):
     async def query(
         self,
         content: InterleavedContent,
-        vector_db_ids: List[str],
-        query_config: Optional[RAGQueryConfig] = None,
+        vector_db_ids: list[str],
+        query_config: RAGQueryConfig | None = None,
     ) -> RAGQueryResult:
         logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
         return await self.routing_table.get_provider_impl("knowledge_search").query(
@@ -873,7 +873,7 @@ class ToolRuntimeRouter(ToolRuntime):
     async def insert(
         self,
-        documents: List[RAGDocument],
+        documents: list[RAGDocument],
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
@@ -904,7 +904,7 @@ class ToolRuntimeRouter(ToolRuntime):
         logger.debug("ToolRuntimeRouter.shutdown")
         pass

-    async def invoke_tool(self, tool_name: str,
kwargs: Dict[str, Any]) -> Any: + async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") return await self.routing_table.get_provider_impl(tool_name).invoke_tool( tool_name=tool_name, @@ -912,7 +912,7 @@ class ToolRuntimeRouter(ToolRuntime): ) async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 68ee837bf..c04562197 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -7,7 +7,7 @@ import logging import time import uuid -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import TypeAdapter @@ -106,20 +106,20 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: raise ValueError(f"Unregister not supported for {api}") -Registry = Dict[str, List[RoutableObjectWithProvider]] +Registry = dict[str, list[RoutableObjectWithProvider]] class CommonRoutingTableImpl(RoutingTable): def __init__( self, - impls_by_provider_id: Dict[str, RoutedProtocol], + impls_by_provider_id: dict[str, RoutedProtocol], dist_registry: DistributionRegistry, ) -> None: self.impls_by_provider_id = impls_by_provider_id self.dist_registry = dist_registry async def initialize(self) -> None: - async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: + async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: for obj in objs: if cls is None: obj.provider_id = provider_id @@ -154,7 +154,7 @@ class CommonRoutingTableImpl(RoutingTable): for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: + def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") @@ -192,7 +192,7 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider not found for `{routing_key}`") - async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: # Get from disk registry obj = await self.dist_registry.get(type, identifier) if not obj: @@ -236,7 +236,7 @@ class CommonRoutingTableImpl(RoutingTable): await self.dist_registry.register(obj) return obj - async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]: + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() filtered_objs = [obj for obj in objs if obj.type == type] @@ -277,10 +277,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def register_model( self, model_id: str, - provider_model_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - model_type: Optional[ModelType] = None, + provider_model_id: str | None = None, + 
provider_id: str | None = None, + metadata: dict[str, Any] | None = None, + model_type: ModelType | None = None, ) -> Model: if provider_model_id is None: provider_model_id = model_id @@ -328,9 +328,9 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def register_shield( self, shield_id: str, - provider_shield_id: Optional[str] = None, - provider_id: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, ) -> Shield: if provider_shield_id is None: provider_shield_id = shield_id @@ -368,9 +368,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): self, vector_db_id: str, embedding_model: str, - embedding_dimension: Optional[int] = 384, - provider_id: Optional[str] = None, - provider_vector_db_id: Optional[str] = None, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, ) -> VectorDB: if provider_vector_db_id is None: provider_vector_db_id = vector_db_id @@ -423,8 +423,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): self, purpose: DatasetPurpose, source: DataSource, - metadata: Optional[Dict[str, Any]] = None, - dataset_id: Optional[str] = None, + metadata: dict[str, Any] | None = None, + dataset_id: str | None = None, ) -> Dataset: if isinstance(source, dict): if source["type"] == "uri": @@ -489,9 +489,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): scoring_fn_id: str, description: str, return_type: ParamType, - provider_scoring_fn_id: Optional[str] = None, - provider_id: Optional[str] = None, - params: Optional[ScoringFnParams] = None, + provider_scoring_fn_id: str | None = None, + provider_id: str | None = None, + params: ScoringFnParams | None = None, ) -> None: if provider_scoring_fn_id is None: provider_scoring_fn_id = scoring_fn_id @@ -528,10 +528,10 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): self, benchmark_id: str, dataset_id: str, - scoring_functions: List[str], - metadata: Optional[Dict[str, Any]] = None, - provider_benchmark_id: Optional[str] = None, - provider_id: Optional[str] = None, + scoring_functions: list[str], + metadata: dict[str, Any] | None = None, + provider_benchmark_id: str | None = None, + provider_id: str | None = None, ) -> None: if metadata is None: metadata = {} @@ -556,7 +556,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): - async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse: + async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: tools = await self.get_all_with_type("tool") if toolgroup_id: tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] @@ -578,8 +578,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): self, toolgroup_id: str, provider_id: str, - mcp_endpoint: Optional[URL] = None, - args: Optional[Dict[str, Any]] = None, + mcp_endpoint: URL | None = None, + args: dict[str, Any] | None = None, ) -> None: tools = [] tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index d94b16858..1b19f8923 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ 
b/llama_stack/distribution/server/auth_providers.py @@ -7,7 +7,6 @@ import json from abc import ABC, abstractmethod from enum import Enum -from typing import Dict, List, Optional from urllib.parse import parse_qs import httpx @@ -22,7 +21,7 @@ logger = get_logger(name=__name__, category="auth") class AuthResponse(BaseModel): """The format of the authentication response from the auth endpoint.""" - access_attributes: Optional[AccessAttributes] = Field( + access_attributes: AccessAttributes | None = Field( default=None, description=""" Structured user attributes for attribute-based access control. @@ -44,7 +43,7 @@ class AuthResponse(BaseModel): """, ) - message: Optional[str] = Field( + message: str | None = Field( default=None, description="Optional message providing additional context about the authentication result." ) @@ -52,9 +51,9 @@ class AuthResponse(BaseModel): class AuthRequestContext(BaseModel): path: str = Field(description="The path of the request being authenticated") - headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)") + headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)") - params: Dict[str, List[str]] = Field( + params: dict[str, list[str]] = Field( description="Query parameters from the original request, parsed as dictionary of lists" ) @@ -76,14 +75,14 @@ class AuthProviderConfig(BaseModel): """Base configuration for authentication providers.""" provider_type: AuthProviderType = Field(..., description="Type of authentication provider") - config: Dict[str, str] = Field(..., description="Provider-specific configuration") + config: dict[str, str] = Field(..., description="Provider-specific configuration") class AuthProvider(ABC): """Abstract base class for authentication providers.""" @abstractmethod - async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]: + async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: """Validate a token and return access attributes.""" pass @@ -96,7 +95,7 @@ class AuthProvider(ABC): class KubernetesAuthProvider(AuthProvider): """Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" - def __init__(self, config: Dict[str, str]): + def __init__(self, config: dict[str, str]): self.api_server_url = config["api_server_url"] self.ca_cert_path = config.get("ca_cert_path") self._client = None @@ -120,7 +119,7 @@ class KubernetesAuthProvider(AuthProvider): self._client = ApiClient(configuration) return self._client - async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]: + async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: """Validate a Kubernetes token and return access attributes.""" try: client = await self._get_client() @@ -166,11 +165,11 @@ class KubernetesAuthProvider(AuthProvider): class CustomAuthProvider(AuthProvider): """Custom authentication provider that uses an external endpoint.""" - def __init__(self, config: Dict[str, str]): + def __init__(self, config: dict[str, str]): self.endpoint = config["endpoint"] self._client = None - async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]: + async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: """Validate a token using the custom authentication endpoint.""" if not 
self.endpoint: raise ValueError("Authentication endpoint not configured") diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py index 98f01c067..ec1f7e083 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/endpoints.py @@ -6,7 +6,6 @@ import inspect import re -from typing import Dict, List from pydantic import BaseModel @@ -29,7 +28,7 @@ def toolgroup_protocol_map(): } -def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]: +def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]: apis = {} protocols = api_protocol_map() diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 0b21936c2..0c8c70306 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -15,7 +15,7 @@ import warnings from contextlib import asynccontextmanager from importlib.metadata import version as parse_version from pathlib import Path -from typing import Any, List, Optional, Union +from typing import Annotated, Any import yaml from fastapi import Body, FastAPI, HTTPException, Request @@ -24,7 +24,6 @@ from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError -from typing_extensions import Annotated from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis @@ -91,7 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception): return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}) -def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]: +def translate_exception(exc: Exception) -> HTTPException | RequestValidationError: if isinstance(exc, ValidationError): exc = RequestValidationError(exc.errors()) @@ -315,7 +314,7 @@ class ClientVersionMiddleware: return await self.app(scope, receive, send) -def main(args: Optional[argparse.Namespace] = None): +def main(args: argparse.Namespace | None = None): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser.add_argument( @@ -385,7 +384,7 @@ def main(args: Optional[argparse.Namespace] = None): raise ValueError("Either --yaml-config or --template must be provided") logger_config = None - with open(config_file, "r") as fp: + with open(config_file) as fp: config_contents = yaml.safe_load(fp) if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): logger_config = LoggingConfig(**cfg) @@ -517,7 +516,7 @@ def main(args: Optional[argparse.Namespace] = None): uvicorn.run(**uvicorn_config) -def extract_path_params(route: str) -> List[str]: +def extract_path_params(route: str) -> list[str]: segments = route.split("/") params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")] # to handle path params like {param:path} diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index a6dc3d2a0..fc68dc016 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -8,7 +8,7 @@ import importlib.resources import os import re import tempfile -from typing import Any, Dict, Optional +from typing import Any import yaml @@ -90,7 +90,7 @@ RESOURCES = [ ] -async def register_resources(run_config: 
StackRunConfig, impls: Dict[Api, Any]): +async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): for rsrc, api, register_method, list_method in RESOURCES: objects = getattr(run_config, rsrc) if api not in impls: @@ -197,7 +197,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]: ) from e -def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None: +def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None: """Add internal implementations (inspect and providers) to the implementations dictionary. Args: @@ -220,8 +220,8 @@ def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConf # Produces a stack of providers for the given run config. Not all APIs may be # asked for in the run config. async def construct_stack( - run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None -) -> Dict[Api, Any]: + run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None +) -> dict[Api, Any]: dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) @@ -244,7 +244,7 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig: def run_config_from_adhoc_config_spec( - adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None + adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None ) -> StackRunConfig: """ Create an adhoc distribution from a list of API providers. diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 76b66cc7a..ae97f600c 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -6,7 +6,7 @@ import asyncio from contextlib import asynccontextmanager -from typing import Dict, List, Optional, Protocol, Tuple +from typing import Protocol import pydantic @@ -20,13 +20,13 @@ logger = get_logger(__name__, category="core") class DistributionRegistry(Protocol): - async def get_all(self) -> List[RoutableObjectWithProvider]: ... + async def get_all(self) -> list[RoutableObjectWithProvider]: ... async def initialize(self) -> None: ... - async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... + async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ... - def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... + def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ... async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ... 
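Aside: the registry hunks that follow derive kvstore range bounds by appending "\xff" to a key prefix. A minimal sketch of why that brackets a prefix scan (toy in-memory store and hypothetical keys, not the real KVStore API):

    # Any key starting with `prefix` sorts at or after `prefix` and before
    # `prefix + "\xff"` (0xff sorts above every character used in the keys),
    # so a lexicographic range scan returns exactly the prefixed entries.
    store = {
        "distributions:v8::model:llama": "...",
        "distributions:v8::shield:guard": "...",
        "other:v1::model:x": "...",
    }
    start_key = "distributions:v8"
    end_key = start_key + "\xff"
    matches = [k for k in sorted(store) if start_key <= k <= end_key]
    print(matches)  # only the two "distributions:v8" keys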
@@ -40,13 +40,13 @@ KEY_VERSION = "v8" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" -def _get_registry_key_range() -> Tuple[str, str]: +def _get_registry_key_range() -> tuple[str, str]: """Returns the start and end keys for the registry range query.""" start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}" return start_key, f"{start_key}\xff" -def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]: +def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]: """Utility function to parse registry values into RoutableObjectWithProvider objects.""" all_objects = [] for value in values: @@ -67,16 +67,16 @@ class DiskDistributionRegistry(DistributionRegistry): async def initialize(self) -> None: pass - def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: # Disk registry does not have a cache raise NotImplementedError("Disk registry does not have a cache") - async def get_all(self) -> List[RoutableObjectWithProvider]: + async def get_all(self) -> list[RoutableObjectWithProvider]: start_key, end_key = _get_registry_key_range() values = await self.kvstore.range(start_key, end_key) return _parse_registry_values(values) - async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier)) if not json_str: return None @@ -113,7 +113,7 @@ class DiskDistributionRegistry(DistributionRegistry): class CachedDiskDistributionRegistry(DiskDistributionRegistry): def __init__(self, kvstore: KVStore): super().__init__(kvstore) - self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {} + self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {} self._initialized = False self._initialize_lock = asyncio.Lock() self._cache_lock = asyncio.Lock() @@ -147,15 +147,15 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def initialize(self) -> None: await self._ensure_initialized() - def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: return self.cache.get((type, identifier), None) - async def get_all(self) -> List[RoutableObjectWithProvider]: + async def get_all(self) -> list[RoutableObjectWithProvider]: await self._ensure_initialized() async with self._locked_cache() as cache: return list(cache.values()) - async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: await self._ensure_initialized() cache_key = (type, identifier) @@ -189,7 +189,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): async def create_dist_registry( - metadata_store: Optional[KVStoreConfig], + metadata_store: KVStoreConfig | None, image_name: str, ) -> tuple[CachedDiskDistributionRegistry, KVStore]: # instantiate kvstore for storing and retrieving distribution metadata diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py index d5395c5b9..11455ed46 100644 --- a/llama_stack/distribution/ui/modules/api.py +++ b/llama_stack/distribution/ui/modules/api.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import os -from typing import Optional from llama_stack_client import LlamaStackClient @@ -23,7 +22,7 @@ class LlamaStackApi: }, ) - def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]): + def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None): """Run scoring on a single row""" if not scoring_params: scoring_params = {fn_id: None for fn_id in scoring_function_ids} diff --git a/llama_stack/distribution/utils/config.py b/llama_stack/distribution/utils/config.py index 5e78289b7..dece52460 100644 --- a/llama_stack/distribution/utils/config.py +++ b/llama_stack/distribution/utils/config.py @@ -4,10 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any -def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: +def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]: """Redact sensitive information from config before printing.""" sensitive_patterns = ["api_key", "api_token", "password", "secret"] @@ -18,7 +18,7 @@ def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: return [_redact_value(i) for i in v] return v - def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]: + def _redact_dict(d: dict[str, Any]) -> dict[str, Any]: result = {} for k, v in d.items(): if any(pattern in k.lower() for pattern in sensitive_patterns): diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/distribution/utils/context.py index c34079ac6..3fcd3315f 100644 --- a/llama_stack/distribution/utils/context.py +++ b/llama_stack/distribution/utils/context.py @@ -4,14 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import AsyncGenerator from contextvars import ContextVar -from typing import AsyncGenerator, List, TypeVar +from typing import TypeVar T = TypeVar("T") def preserve_contexts_async_generator( - gen: AsyncGenerator[T, None], context_vars: List[ContextVar] + gen: AsyncGenerator[T, None], context_vars: list[ContextVar] ) -> AsyncGenerator[T, None]: """ Wraps an async generator to preserve context variables across iterations. 
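Aside: a runnable sketch of the pattern `preserve_contexts_async_generator` implements, written as a toy standalone version; the wrapper name and the re-apply-on-every-step strategy are illustrative assumptions, not the library code:

    import asyncio
    from collections.abc import AsyncGenerator
    from contextvars import ContextVar

    request_id: ContextVar[str] = ContextVar("request_id", default="-")

    def preserve(gen: AsyncGenerator, context_vars: list[ContextVar]) -> AsyncGenerator:
        # Snapshot the variables when the wrapper is created, then re-apply the
        # snapshot before every step, since each __anext__ may otherwise run with
        # whatever context the consumer happens to have at that moment.
        captured = {var: var.get() for var in context_vars}

        async def wrapper():
            while True:
                for var, value in captured.items():
                    var.set(value)
                try:
                    item = await gen.__anext__()
                except StopAsyncIteration:
                    return
                yield item

        return wrapper()

    async def main():
        async def produce():
            for i in range(2):
                yield f"{request_id.get()}:{i}"

        request_id.set("req-42")
        wrapped = preserve(produce(), [request_id])
        request_id.set("changed-later")  # the consumer mutates its context afterwards
        async for item in wrapped:
            print(item)  # req-42:0, req-42:1; the snapshot wins inside the generator

    asyncio.run(main())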
diff --git a/llama_stack/distribution/utils/prompt_for_config.py b/llama_stack/distribution/utils/prompt_for_config.py
index 9b2b99022..26f6920e0 100644
--- a/llama_stack/distribution/utils/prompt_for_config.py
+++ b/llama_stack/distribution/utils/prompt_for_config.py
@@ -8,12 +8,11 @@ import inspect
 import json
 import logging
 from enum import Enum
-from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
+from typing import Annotated, Any, Literal, Union, get_args, get_origin

 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefinedType
-from typing_extensions import Annotated

 log = logging.getLogger(__name__)

@@ -21,7 +20,7 @@ log = logging.getLogger(__name__)
 def is_list_of_primitives(field_type):
     """Check if a field type is a List of primitive types."""
     origin = get_origin(field_type)
-    if origin is List or origin is list:
+    if origin is list:
         args = get_args(field_type)
         if len(args) == 1 and args[0] in (int, float, str, bool):
             return True
@@ -53,7 +52,7 @@ def get_non_none_type(field_type):
     return next(arg for arg in get_args(field_type) if arg is not type(None))

-def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
+def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
     validators = model.__pydantic_decorators__.field_validators
     for _name, validator in validators.items():
         if field_name in validator.info.fields:
@@ -126,7 +125,7 @@ def prompt_for_discriminated_union(
 #
 # doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
 # unit tests for coverage.
-def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
+def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
     """
     Recursively prompt the user for configuration values based on a Pydantic BaseModel.
diff --git a/llama_stack/log.py b/llama_stack/log.py
index 3835b74a1..98858d208 100644
--- a/llama_stack/log.py
+++ b/llama_stack/log.py
@@ -7,7 +7,6 @@ import logging
 import os
 from logging.config import dictConfig
-from typing import Dict, Optional

 from rich.console import Console
 from rich.errors import MarkupError
@@ -33,7 +32,7 @@ CATEGORIES = [
 ]

 # Initialize category levels with default level
-_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
+_category_levels: dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}

 def config_to_category_levels(category: str, level: str):
@@ -49,7 +48,7 @@ def config_to_category_levels(category: str, level: str):
         Dict[str, int]: A dictionary mapping categories to their log levels.
     """
-    category_levels: Dict[str, int] = {}
+    category_levels: dict[str, int] = {}
     level_value = logging._nameToLevel.get(str(level).upper())
     if level_value is None:
         logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")
@@ -69,7 +68,7 @@
     return category_levels

-def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
+def parse_yaml_config(yaml_config: LoggingConfig) -> dict[str, int]:
     """
     Helper function to parse a yaml logging configuration found in the run.yaml
@@ -86,7 +85,7 @@ def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
     return category_levels

-def parse_environment_config(env_config: str) -> Dict[str, int]:
+def parse_environment_config(env_config: str) -> dict[str, int]:
     """
     Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
@@ -131,7 +130,7 @@ class CustomRichHandler(RichHandler):
             self.markup = original_markup

-def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
+def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
     """
     Configure logging based on the provided category log levels and an optional log file.
@@ -211,7 +210,7 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None
 def get_logger(
-    name: str, category: str = "uncategorized", config: Optional[LoggingConfig] | None = None
+    name: str, category: str = "uncategorized", config: LoggingConfig | None = None
 ) -> logging.LoggerAdapter:
     """
     Returns a logger with the specified name and category.
diff --git a/llama_stack/models/llama/checkpoint.py b/llama_stack/models/llama/checkpoint.py
index 2bae08a69..c9e0030e3 100644
--- a/llama_stack/models/llama/checkpoint.py
+++ b/llama_stack/models/llama/checkpoint.py
@@ -7,14 +7,14 @@ import concurrent.futures
 import re
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

 import numpy as np
 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size

-def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
+def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> list[int]:
     """Map a new MP rank to a list of old MP ranks given a change in MP size."""
     if new_mp_size % old_mp_size == 0:
         # Read old MP shard and split it into smaller ones
@@ -31,12 +31,12 @@ def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[in
 def maybe_reshard_state_dict(
-    ckpt_paths: List[Path],
+    ckpt_paths: list[Path],
     n_kv_heads: int,
-    moe_num_experts: Optional[int] = None,
-    map_location: Union[str, torch.device] = "cpu",
+    moe_num_experts: int | None = None,
+    map_location: str | torch.device = "cpu",
     mmap: bool = True,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     if str(map_location) == "cpu":
         torch.set_default_tensor_type(torch.BFloat16Tensor)
     else:
@@ -97,18 +97,18 @@ _MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
 def reshard_mp(
-    state_dicts: List[Dict[str, torch.Tensor]],
+    state_dicts: list[dict[str, torch.Tensor]],
     size: int,
     rank: int,
     repeat_qk_qv: int = 1,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     """
     Reshard a list of state dicts into a single state dict given a change in MP size.
    If the list has more than one state dict, we concatenate the values of the same
    key across all state dicts. Otherwise, we just slice it for the current MP rank.
""" - def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: + def concat_or_chunk(tensors: list[torch.Tensor], dim: int) -> torch.Tensor: if len(tensors) > 1: return torch.cat(tensors, dim=dim) return tensors[0].chunk(size, dim=dim)[rank].clone() @@ -144,7 +144,7 @@ def reshard_mp( column_regex = re.compile("|".join(column_keys)) row_regex = re.compile("|".join(row_keys)) - output: Dict[str, torch.Tensor] = {} + output: dict[str, torch.Tensor] = {} with concurrent.futures.ThreadPoolExecutor() as executor: # Note: only processes keys in the first state dict. # Assumes keys are the same across all state dicts. @@ -154,7 +154,7 @@ def reshard_mp( return output -def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]: +def convert_moe_weights(state_dict: dict[str, Any], num_experts: int) -> dict[str, Any]: routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY routed_regex = re.compile("|".join(routed_keys)) keys = list(state_dict.keys()) diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index 48cb51005..f9f094c3d 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -7,10 +7,9 @@ import base64 from enum import Enum from io import BytesIO -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Annotated, Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator -from typing_extensions import Annotated # The goal is that these set of types are relevant for all Llama models. # That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to @@ -31,21 +30,21 @@ class BuiltinTool(Enum): code_interpreter = "code_interpreter" -Primitive = Union[str, int, float, bool, None] -RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]] +Primitive = str | int | float | bool | None +RecursiveType = Primitive | list[Primitive] | dict[str, Primitive] class ToolCall(BaseModel): call_id: str - tool_name: Union[BuiltinTool, str] + tool_name: BuiltinTool | str # Plan is to deprecate the Dict in favor of a JSON string # that is parsed on the client side instead of trying to manage # the recursive type here. # Making this a union so that client side can start prepping for this change. 
# Eventually, we will remove both the Dict and arguments_json field, # and arguments will just be a str - arguments: Union[str, Dict[str, RecursiveType]] - arguments_json: Optional[str] = None + arguments: str | dict[str, RecursiveType] + arguments_json: str | None = None @field_validator("tool_name", mode="before") @classmethod @@ -91,15 +90,15 @@ class StopReason(Enum): class ToolParamDefinition(BaseModel): param_type: str - description: Optional[str] = None - required: Optional[bool] = True - default: Optional[Any] = None + description: str | None = None + required: bool | None = True + default: Any | None = None class ToolDefinition(BaseModel): - tool_name: Union[BuiltinTool, str] - description: Optional[str] = None - parameters: Optional[Dict[str, ToolParamDefinition]] = None + tool_name: BuiltinTool | str + description: str | None = None + parameters: dict[str, ToolParamDefinition] | None = None @field_validator("tool_name", mode="before") @classmethod @@ -119,7 +118,7 @@ class RawMediaItem(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @field_serializer("data") - def serialize_data(self, data: Optional[bytes], _info): + def serialize_data(self, data: bytes | None, _info): if data is None: return None return base64.b64encode(data).decode("utf-8") @@ -137,9 +136,9 @@ class RawTextItem(BaseModel): text: str -RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")] +RawContentItem = Annotated[RawTextItem | RawMediaItem, Field(discriminator="type")] -RawContent = str | RawContentItem | List[RawContentItem] +RawContent = str | RawContentItem | list[RawContentItem] class RawMessage(BaseModel): @@ -147,17 +146,17 @@ class RawMessage(BaseModel): content: RawContent # This is for RAG but likely should be absorbed into content - context: Optional[RawContent] = None + context: RawContent | None = None # These are for the output message coming from the assistant - stop_reason: Optional[StopReason] = None - tool_calls: List[ToolCall] = Field(default_factory=list) + stop_reason: StopReason | None = None + tool_calls: list[ToolCall] = Field(default_factory=list) class GenerationResult(BaseModel): token: int text: str - logprobs: Optional[List[float]] = None + logprobs: list[float] | None = None source: Literal["input"] | Literal["output"] diff --git a/llama_stack/models/llama/llama3/args.py b/llama_stack/models/llama/llama3/args.py index f7e4b4557..4f92874f5 100644 --- a/llama_stack/models/llama/llama3/args.py +++ b/llama_stack/models/llama/llama3/args.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional class QuantizationScheme(Enum): @@ -15,8 +14,8 @@ class QuantizationScheme(Enum): @dataclass class QuantizationArgs: - scheme: Optional[QuantizationScheme] = None - group_size: Optional[int] = None + scheme: QuantizationScheme | None = None + group_size: int | None = None spinquant: bool = False def __init__(self, **kwargs): @@ -39,10 +38,10 @@ class ModelArgs: dim: int = 4096 n_layers: int = 32 n_heads: int = 32 - n_kv_heads: Optional[int] = None + n_kv_heads: int | None = None vocab_size: int = -1 multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None + ffn_dim_multiplier: float | None = None norm_eps: float = 1e-5 rope_theta: float = 500000 use_scaled_rope: bool = False @@ -55,8 +54,8 @@ class ModelArgs: vision_max_num_chunks: int = 4 vision_num_cross_attention_layers: int = -1 - quantization_args: 
Optional[QuantizationArgs] = None - lora_args: Optional[LoRAArgs] = None + quantization_args: QuantizationArgs | None = None + lora_args: LoRAArgs | None = None def __init__(self, **kwargs): for k, v in kwargs.items(): diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index fe7a7a898..7bb05d8db 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -8,7 +8,6 @@ import io import json import uuid from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple from PIL import Image as PIL_Image @@ -29,14 +28,14 @@ from .tool_utils import ToolUtils @dataclass class VisionInput: - mask: List[List[int]] - images: List[PIL_Image.Image] + mask: list[list[int]] + images: list[PIL_Image.Image] @dataclass class LLMInput: - tokens: List[int] - vision: Optional[VisionInput] = None + tokens: list[int] + vision: VisionInput | None = None def role_str(role: Role) -> str: @@ -50,7 +49,7 @@ def role_str(role: Role) -> str: class ChatFormat: - possible_headers: Dict[Role, str] + possible_headers: dict[Role, str] def __init__(self, tokenizer: Tokenizer): self.tokenizer = tokenizer @@ -58,7 +57,7 @@ class ChatFormat: self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role} self.vision_token = self.tokenizer.special_tokens["<|image|>"] - def _encode_header(self, role: str) -> List[int]: + def _encode_header(self, role: str) -> list[int]: tokens = [] tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False)) @@ -70,7 +69,7 @@ class ChatFormat: tokens, images = self._encode_content(content, bos=True) return self._model_input_from_tokens_images(tokens, images) - def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]: + def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[PIL_Image.Image]]: tokens = [] images = [] @@ -107,7 +106,7 @@ class ChatFormat: def encode_message( self, message: RawMessage, tool_prompt_format: ToolPromptFormat - ) -> Tuple[List[int], List[PIL_Image.Image]]: + ) -> tuple[list[int], list[PIL_Image.Image]]: tokens = self._encode_header(message.role) images = [] @@ -145,8 +144,8 @@ class ChatFormat: def encode_dialog_prompt( self, - messages: List[RawMessage], - tool_prompt_format: Optional[ToolPromptFormat] = None, + messages: list[RawMessage], + tool_prompt_format: ToolPromptFormat | None = None, ) -> LLMInput: tool_prompt_format = tool_prompt_format or ToolPromptFormat.json tokens = [] @@ -163,7 +162,7 @@ class ChatFormat: return self._model_input_from_tokens_images(tokens, images) # TODO(this should be generic, not only for assistant messages) - def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage: + def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage: content = self.tokenizer.decode(tokens) return self.decode_assistant_message_from_content(content, stop_reason) @@ -234,7 +233,7 @@ class ChatFormat: tool_calls=tool_calls, ) - def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput: + def _model_input_from_tokens_images(self, tokens: list[int], images: list[PIL_Image.Image]) -> LLMInput: vision_input = None if len(images) > 0: vision_input = VisionInput( @@ -249,9 +248,9 @@ class 
ChatFormat:


 def create_vision_mask(
-    tokens: List[int],
+    tokens: list[int],
     vision_token: int,
-) -> List[List[int]]:
+) -> list[list[int]]:
     vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
     if len(vision_token_locations) == 0:
         return []
diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py
index 35c140707..c6d618818 100644
--- a/llama_stack/models/llama/llama3/generation.py
+++ b/llama_stack/models/llama/llama3/generation.py
@@ -15,8 +15,8 @@ import json
 import os
 import sys
 import time
+from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Callable, Generator, List, Optional

 import torch
 import torch.nn.functional as F
@@ -41,8 +41,8 @@ class Llama3:
         ckpt_dir: str,
         max_seq_len: int,
         max_batch_size: int,
-        world_size: Optional[int] = None,
-        quantization_mode: Optional[QuantizationMode] = None,
+        world_size: int | None = None,
+        quantization_mode: QuantizationMode | None = None,
         seed: int = 1,
         device: str = "cuda",
     ):
@@ -82,7 +82,7 @@ class Llama3:
         ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
         print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
-        with open(Path(ckpt_dir) / "params.json", "r") as f:
+        with open(Path(ckpt_dir) / "params.json") as f:
             params = json.loads(f.read())

         model_args: ModelArgs = ModelArgs(
@@ -154,15 +154,15 @@ class Llama3:
     @torch.inference_mode()
     def generate(
         self,
-        llm_inputs: List[LLMInput],
+        llm_inputs: list[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
         print_model_input: bool = False,
-        logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-    ) -> Generator[List[GenerationResult], None, None]:
+        logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+    ) -> Generator[list[GenerationResult], None, None]:
         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
             max_gen_len = self.args.max_seq_len - 1
         params = self.model.params
@@ -302,13 +302,13 @@ class Llama3:

     def completion(
         self,
-        contents: List[RawContent],
+        contents: list[RawContent],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
-    ) -> Generator[List[GenerationResult], None, None]:
+    ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
             llm_inputs=model_inputs,
@@ -324,14 +324,14 @@ class Llama3:

     def chat_completion(
         self,
-        messages_batch: List[List[RawMessage]],
+        messages_batch: list[list[RawMessage]],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
         echo: bool = False,
-    ) -> Generator[List[GenerationResult], None, None]:
+    ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
             llm_inputs=model_inputs,
diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py
index 8684237df..b63ba4847 100644
---
a/llama_stack/models/llama/llama3/interface.py +++ b/llama_stack/models/llama/llama3/interface.py @@ -12,7 +12,6 @@ # the top-level of this source tree. from pathlib import Path -from typing import List, Optional from termcolor import colored @@ -131,7 +130,7 @@ class LLama31Interface: self.formatter = ChatFormat(self.tokenizer) self.tool_prompt_format = tool_prompt_format - def get_tokens(self, messages: List[RawMessage]) -> List[int]: + def get_tokens(self, messages: list[RawMessage]) -> list[int]: model_input = self.formatter.encode_dialog_prompt( messages, self.tool_prompt_format, @@ -149,10 +148,10 @@ class LLama31Interface: def system_messages( self, - builtin_tools: List[BuiltinTool], - custom_tools: List[ToolDefinition], - instruction: Optional[str] = None, - ) -> List[RawMessage]: + builtin_tools: list[BuiltinTool], + custom_tools: list[ToolDefinition], + instruction: str | None = None, + ) -> list[RawMessage]: messages = [] default_gen = SystemDefaultGenerator() @@ -194,8 +193,8 @@ class LLama31Interface: self, content: str, stop_reason: StopReason, - tool_call: Optional[ToolCall] = None, - ) -> List[RawMessage]: + tool_call: ToolCall | None = None, + ) -> list[RawMessage]: tool_calls = [] if tool_call: tool_calls.append(tool_call) @@ -208,7 +207,7 @@ class LLama31Interface: ) ] - def user_message(self, content: str) -> List[RawMessage]: + def user_message(self, content: str) -> list[RawMessage]: return [RawMessage(role="user", content=content)] def display_message_as_tokens(self, message: RawMessage) -> None: @@ -228,7 +227,7 @@ class LLama31Interface: print("\n", end="") -def list_jinja_templates() -> List[Template]: +def list_jinja_templates() -> list[Template]: return TEMPLATES diff --git a/llama_stack/models/llama/llama3/model.py b/llama_stack/models/llama/llama3/model.py index 2562673e2..88f748c1d 100644 --- a/llama_stack/models/llama/llama3/model.py +++ b/llama_stack/models/llama/llama3/model.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
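The Optional[X] to X | None and Union[...] to | rewrites running through these hunks are behavior-preserving on Python 3.10+. A small standalone sketch (plain stdlib, invented values) of the equivalence, and of the one introspection difference worth knowing about:

    import types
    import typing

    # The two spellings compare equal and carry the same arguments...
    assert (int | None) == typing.Optional[int]
    assert typing.get_args(int | None) == (int, type(None))

    # ...but they expose different origins, so any utility that walks
    # annotations has to accept both typing.Union and types.UnionType.
    assert typing.get_origin(typing.Optional[int]) is typing.Union
    assert typing.get_origin(int | None) is types.UnionType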
import math -from typing import Optional, Tuple import fairscale.nn.model_parallel.initialize as fs_init import torch @@ -80,7 +79,7 @@ def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) @@ -162,7 +161,7 @@ class Attention(nn.Module): x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], + mask: torch.Tensor | None, ): bsz, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) @@ -204,7 +203,7 @@ class FeedForward(nn.Module): dim: int, hidden_dim: int, multiple_of: int, - ffn_dim_multiplier: Optional[float], + ffn_dim_multiplier: float | None, ): super().__init__() hidden_dim = int(2 * hidden_dim / 3) @@ -243,7 +242,7 @@ class TransformerBlock(nn.Module): x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor], + mask: torch.Tensor | None, ): h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) out = h + self.feed_forward(self.ffn_norm(h)) diff --git a/llama_stack/models/llama/llama3/multimodal/image_transform.py b/llama_stack/models/llama/llama3/multimodal/image_transform.py index c156d6d2e..f2761ee47 100644 --- a/llama_stack/models/llama/llama3/multimodal/image_transform.py +++ b/llama_stack/models/llama/llama3/multimodal/image_transform.py @@ -14,7 +14,7 @@ import math from collections import defaultdict from logging import getLogger -from typing import Any, Optional, Set, Tuple +from typing import Any import torch import torchvision.transforms as tv @@ -26,7 +26,7 @@ IMAGE_RES = 224 logger = getLogger() -class VariableSizeImageTransform(object): +class VariableSizeImageTransform: """ This class accepts images of any size and dynamically resize, pads and chunks it based on the image aspect ratio and the number of image chunks we allow. @@ -75,7 +75,7 @@ class VariableSizeImageTransform(object): self.resample = tv.InterpolationMode.BILINEAR @staticmethod - def get_factors(n: int) -> Set[int]: + def get_factors(n: int) -> set[int]: """ Calculate all factors of a given number, i.e. a dividor that leaves no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}. @@ -145,9 +145,9 @@ class VariableSizeImageTransform(object): @staticmethod def get_max_res_without_distortion( - image_size: Tuple[int, int], - target_size: Tuple[int, int], - ) -> Tuple[int, int]: + image_size: tuple[int, int], + target_size: tuple[int, int], + ) -> tuple[int, int]: """ Determines the maximum resolution to which an image can be resized to without distorting its aspect ratio, based on the target resolution. @@ -198,8 +198,8 @@ class VariableSizeImageTransform(object): def resize_without_distortion( self, image: torch.Tensor, - target_size: Tuple[int, int], - max_upscaling_size: Optional[int], + target_size: tuple[int, int], + max_upscaling_size: int | None, ) -> torch.Tensor: """ Used to resize an image to target_resolution, without distortion. 
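get_factors above only changes its annotation from Set[int] to the builtin set[int]; the behavior its docstring describes is unchanged. A self-contained sketch (a simplified reimplementation, not the library's exact loop) for reference:

    def get_factors(n: int) -> set[int]:
        """Return every divisor of n, e.g. {1, 2, 3, 4, 6, 12} for n=12."""
        return {d for d in range(1, n + 1) if n % d == 0}

    assert get_factors(12) == {1, 2, 3, 4, 6, 12}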
@@ -261,10 +261,10 @@ class VariableSizeImageTransform(object): def get_best_fit( self, - image_size: Tuple[int, int], + image_size: tuple[int, int], possible_resolutions: torch.Tensor, resize_to_max_canvas: bool = False, - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """ Determines the best canvas possible from a list of possible resolutions to, without distortion, resize an image to. @@ -364,7 +364,7 @@ class VariableSizeImageTransform(object): max_num_chunks: int, normalize_img: bool = True, resize_to_max_canvas: bool = False, - ) -> Tuple[Any, Any]: + ) -> tuple[Any, Any]: """ Args: image (PIL.Image): Image to be resized. diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py index 0cb18b948..5f1c3605c 100644 --- a/llama_stack/models/llama/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -6,8 +6,9 @@ import logging import math +from collections.abc import Callable from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any import fairscale.nn.model_parallel.initialize as fs_init import torch @@ -104,9 +105,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module): self, in_channels: int, out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]], - bias: Optional[bool] = False, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int], + bias: bool | None = False, ) -> None: super().__init__() if isinstance(kernel_size, int): @@ -390,13 +391,13 @@ class VisionEncoder(nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool = True, - missing_keys: List[str] = None, - unexpected_keys: List[str] = None, - error_msgs: List[str] = None, + missing_keys: list[str] = None, + unexpected_keys: list[str] = None, + error_msgs: list[str] = None, return_state_dict: bool = False, ) -> None: orig_pos_embed = state_dict.get(prefix + "positional_embedding") @@ -641,7 +642,7 @@ class FeedForward(nn.Module): dim: int, hidden_dim: int, multiple_of: int, - ffn_dim_multiplier: Optional[float], + ffn_dim_multiplier: float | None, ): """ Initialize the FeedForward module. 
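ColumnParallelConv2dPatch above accepts kernel_size and stride as int | tuple[int, int] and promotes bare ints to pairs. A framework-free sketch of that normalization idiom (the helper name is invented):

    def to_2tuple(value: int | tuple[int, int]) -> tuple[int, int]:
        # A scalar means "square": promote it to an (h, w) pair.
        if isinstance(value, int):
            return (value, value)
        return value

    assert to_2tuple(14) == (14, 14)
    assert to_2tuple((14, 7)) == (14, 7)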
@@ -983,7 +984,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module): self, x: torch.Tensor, xattn_mask: torch.Tensor, - full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor], xattn_cache: torch.Tensor, ) -> torch.Tensor: _attn_out = self.attention( @@ -1144,7 +1145,7 @@ class CrossAttentionTransformerText(torch.nn.Module): def _init_fusion_schedule( self, num_layers: int, - ) -> List[int]: + ) -> list[int]: llama_layers = list(range(self.n_llama_layers)) # uniformly spread the layers @@ -1231,7 +1232,7 @@ class CrossAttentionTransformerText(torch.nn.Module): text_dtype, vision_tokens, cross_attention_masks, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: assert vision_tokens is not None, "Vision tokens must be provided" vision_seqlen = vision_tokens.shape[3] assert vision_tokens.shape[1] == cross_attention_masks.shape[2], ( @@ -1280,11 +1281,11 @@ class CrossAttentionTransformer(torch.nn.Module): def compute_vision_tokens_masks( self, - batch_images: List[List[PIL_Image.Image]], - batch_masks: List[List[List[int]]], + batch_images: list[list[PIL_Image.Image]], + batch_masks: list[list[list[int]]], total_len: int, device: torch.device, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: skip_vision_encoder = False assert len(batch_images) == len(batch_masks), "Images and masks must have the same length" @@ -1371,11 +1372,11 @@ class CrossAttentionTransformer(torch.nn.Module): def _stack_images( - images: List[List[PIL_Image.Image]], + images: list[list[PIL_Image.Image]], max_num_chunks: int, image_res: int, max_num_images: int, -) -> Tuple[torch.Tensor, List[int]]: +) -> tuple[torch.Tensor, list[int]]: """ Takes a list of list of images and stacks them into a tensor. This function is needed since images can be of completely @@ -1400,8 +1401,8 @@ def _stack_images( def _pad_masks( - all_masks: List[List[List[int]]], - all_num_chunks: List[List[int]], + all_masks: list[list[list[int]]], + all_num_chunks: list[list[int]], total_len: int, max_num_chunks: int, ) -> torch.Tensor: diff --git a/llama_stack/models/llama/llama3/prompt_templates/base.py b/llama_stack/models/llama/llama3/prompt_templates/base.py index bff2a21e1..0081443be 100644 --- a/llama_stack/models/llama/llama3/prompt_templates/base.py +++ b/llama_stack/models/llama/llama3/prompt_templates/base.py @@ -12,7 +12,7 @@ # the top-level of this source tree. 
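_stack_images and _pad_masks above turn ragged per-sample lists into rectangular batches. A torch-free sketch of the same padding idea (invented helper and toy data; the real code operates on image tensors and chunk counts):

    def stack_ragged(batch: list[list[int]], pad: int = 0) -> tuple[list[list[int]], list[int]]:
        lengths = [len(row) for row in batch]
        width = max(lengths)
        # Pad every row to the longest one so the batch becomes rectangular.
        padded = [row + [pad] * (width - len(row)) for row in batch]
        return padded, lengths

    rows, lengths = stack_ragged([[1, 2, 3], [4]])
    assert rows == [[1, 2, 3], [4, 0, 0]] and lengths == [3, 1]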
from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any from jinja2 import Template @@ -20,7 +20,7 @@ from jinja2 import Template @dataclass class PromptTemplate: template: str - data: Dict[str, Any] + data: dict[str, Any] def render(self): template = Template(self.template) @@ -35,5 +35,5 @@ class PromptTemplateGeneratorBase: def gen(self, *args, **kwargs) -> PromptTemplate: raise NotImplementedError() - def data_examples(self) -> List[Any]: + def data_examples(self) -> list[Any]: raise NotImplementedError() diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index fbc0127fd..8e6f97012 100644 --- a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -13,7 +13,7 @@ import textwrap from datetime import datetime -from typing import Any, List, Optional +from typing import Any from llama_stack.apis.inference import ( BuiltinTool, @@ -39,12 +39,12 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase): }, ) - def data_examples(self) -> List[Any]: + def data_examples(self) -> list[Any]: return [None] class BuiltinToolGenerator(PromptTemplateGeneratorBase): - def _tool_breakdown(self, tools: List[ToolDefinition]): + def _tool_breakdown(self, tools: list[ToolDefinition]): builtin_tools, custom_tools = [], [] for dfn in tools: if isinstance(dfn.tool_name, BuiltinTool): @@ -54,7 +54,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): return builtin_tools, custom_tools - def gen(self, tools: List[ToolDefinition]) -> PromptTemplate: + def gen(self, tools: list[ToolDefinition]) -> PromptTemplate: builtin_tools, custom_tools = self._tool_breakdown(tools) template_str = textwrap.dedent( """ @@ -75,7 +75,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): }, ) - def data_examples(self) -> List[List[ToolDefinition]]: + def data_examples(self) -> list[list[ToolDefinition]]: return [ # builtin tools [ @@ -91,7 +91,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase): class JsonCustomToolGenerator(PromptTemplateGeneratorBase): - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate: template_str = textwrap.dedent( """ Answer the user's question by making use of the following functions if needed. 
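PromptTemplate above is a thin dataclass around a Jinja template string plus its dict[str, Any] of data; render() instantiates the template. A runnable sketch of the same pattern (requires the jinja2 package; the example template and data are invented):

    from dataclasses import dataclass
    from typing import Any

    from jinja2 import Template

    @dataclass
    class PromptTemplate:
        template: str
        data: dict[str, Any]

        def render(self) -> str:
            return Template(self.template).render(**self.data)

    pt = PromptTemplate("Hello {{ name }}!", {"name": "Llama"})
    assert pt.render() == "Hello Llama!"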
@@ -137,7 +137,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase): {"custom_tools": [t.model_dump() for t in custom_tools]}, ) - def data_examples(self) -> List[List[ToolDefinition]]: + def data_examples(self) -> list[list[ToolDefinition]]: return [ [ ToolDefinition( @@ -161,7 +161,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase): class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): - def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate: template_str = textwrap.dedent( """ You have access to the following functions: @@ -199,7 +199,7 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): {"custom_tools": [t.model_dump() for t in custom_tools]}, ) - def data_examples(self) -> List[List[ToolDefinition]]: + def data_examples(self) -> list[list[ToolDefinition]]: return [ [ ToolDefinition( @@ -238,14 +238,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 """.strip("\n") ) - def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: + def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate: system_prompt = system_prompt or self.DEFAULT_PROMPT return PromptTemplate( system_prompt, {"function_description": self._gen_function_description(custom_tools)}, ) - def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate: template_str = textwrap.dedent( """ Here is a list of functions in JSON format that you can invoke. @@ -291,7 +291,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 {"tools": [t.model_dump() for t in custom_tools]}, ).render() - def data_examples(self) -> List[List[ToolDefinition]]: + def data_examples(self) -> list[list[ToolDefinition]]: return [ [ ToolDefinition( diff --git a/llama_stack/models/llama/llama3/prompt_templates/tool_response.py b/llama_stack/models/llama/llama3/prompt_templates/tool_response.py index 3df4dac14..4da171279 100644 --- a/llama_stack/models/llama/llama3/prompt_templates/tool_response.py +++ b/llama_stack/models/llama/llama3/prompt_templates/tool_response.py @@ -12,7 +12,6 @@ # the top-level of this source tree. 
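gen(..., system_prompt: str | None = None) above uses the common "None means use the default" idiom, resolved with `or`. A tiny sketch (the default text is invented):

    DEFAULT_PROMPT = "You are a helpful assistant."

    def gen(system_prompt: str | None = None) -> str:
        # Note: an empty string also falls back to the default, by design of `or`.
        return system_prompt or DEFAULT_PROMPT

    assert gen() == DEFAULT_PROMPT
    assert gen("Custom instructions") == "Custom instructions"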
import textwrap -from typing import Optional from .base import PromptTemplate, PromptTemplateGeneratorBase @@ -21,8 +20,8 @@ class ToolResponseGenerator(PromptTemplateGeneratorBase): def gen( self, status: str, - stdout: Optional[str] = None, - stderr: Optional[str] = None, + stdout: str | None = None, + stderr: str | None = None, ): assert status in [ "success", diff --git a/llama_stack/models/llama/llama3/quantization/loader.py b/llama_stack/models/llama/llama3/quantization/loader.py index 771fd02be..436cfa6fa 100644 --- a/llama_stack/models/llama/llama3/quantization/loader.py +++ b/llama_stack/models/llama/llama3/quantization/loader.py @@ -6,7 +6,7 @@ # type: ignore import os -from typing import Any, Dict, List, Optional, cast +from typing import Any, cast import torch from fairscale.nn.model_parallel.initialize import get_model_parallel_rank @@ -37,9 +37,9 @@ def swiglu_wrapper( def convert_to_quantized_model( model: Transformer | CrossAttentionTransformer, checkpoint_dir: str, - quantization_mode: Optional[str] = None, - fp8_activation_scale_ub: Optional[float] = 1200.0, - device: Optional[torch.device] = None, + quantization_mode: str | None = None, + fp8_activation_scale_ub: float | None = 1200.0, + device: torch.device | None = None, ) -> Transformer | CrossAttentionTransformer: if quantization_mode == QuantizationMode.fp8_mixed: return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device) @@ -52,8 +52,8 @@ def convert_to_quantized_model( def convert_to_fp8_quantized_model( model: Transformer, checkpoint_dir: str, - fp8_activation_scale_ub: Optional[float] = 1200.0, - device: Optional[torch.device] = None, + fp8_activation_scale_ub: float | None = 1200.0, + device: torch.device | None = None, ) -> Transformer: # Move weights to GPU with quantization fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt") @@ -122,8 +122,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear): precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, # LoRA parameters - lora_rank: Optional[int] = None, - lora_scale: Optional[float] = None, + lora_rank: int | None = None, + lora_scale: float | None = None, ) -> None: super().__init__( in_features, @@ -134,8 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear): precision=precision, scales_precision=scales_precision, ) - self.lora_scale: Optional[float] = None - self.adaptor: Optional[nn.Sequential] = None + self.lora_scale: float | None = None + self.adaptor: nn.Sequential | None = None if lora_rank is not None: assert lora_scale is not None, "Please specify lora scale for LoRA." # Low-rank adaptation. 
See paper for more details: https://arxiv.org/abs/2106.09685 @@ -147,13 +147,13 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: """A hook to load the quantized weights from the state dict.""" if prefix + "zeros" not in state_dict: @@ -191,13 +191,13 @@ class Int8WeightEmbedding(torch.nn.Embedding): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: """A hook to load the quantized embedding weight and scales from the state dict.""" weights = state_dict.pop(prefix + "weight") @@ -221,13 +221,13 @@ class Int8WeightLinear(torch.nn.Linear): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: """A hook to load the quantized linear weight and scales from the state dict.""" weights = state_dict.pop(prefix + "weight") @@ -238,8 +238,8 @@ class Int8WeightLinear(torch.nn.Linear): def _prepare_model_int4_weight_int8_dynamic_activation( model: torch.nn.Module, group_size: int, - lora_rank: Optional[int], - lora_scale: Optional[float], + lora_rank: int | None, + lora_scale: float | None, ): """Prepare the model for int4 weight and int8 dynamic activation quantization. @@ -265,7 +265,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation( ) del module setattr(model, module_name, quantized_module) - elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)): + elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear): quantized_module = Int8DynActInt4WeightLinearLoRA( in_features=module.in_features, out_features=module.out_features, @@ -286,7 +286,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation( def convert_to_int4_quantized_model( model: Transformer | CrossAttentionTransformer, checkpoint_dir: str, - device: Optional[torch.device] = None, + device: torch.device | None = None, ) -> Transformer | CrossAttentionTransformer: """Convert the model to int4 quantized model.""" model_args = model.params diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py index d3cc4fc07..e5ada3599 100644 --- a/llama_stack/models/llama/llama3/tokenizer.py +++ b/llama_stack/models/llama/llama3/tokenizer.py @@ -5,18 +5,11 @@ # the root directory of this source tree. import os +from collections.abc import Collection, Iterator, Sequence, Set from logging import getLogger from pathlib import Path from typing import ( - AbstractSet, - Collection, - Dict, - Iterator, - List, Literal, - Optional, - Sequence, - Union, cast, ) @@ -44,7 +37,7 @@ class Tokenizer: Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 
""" - special_tokens: Dict[str, int] + special_tokens: dict[str, int] num_reserved_special_tokens = 256 @@ -116,9 +109,9 @@ class Tokenizer: *, bos: bool, eos: bool, - allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: + allowed_special: Literal["all"] | Set[str] | None = None, + disallowed_special: Literal["all"] | Collection[str] = (), + ) -> list[int]: """ Encodes a string into a list of token IDs. @@ -151,7 +144,7 @@ class Tokenizer: s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS ) ) - t: List[int] = [] + t: list[int] = [] for substr in substrs: t.extend( self.model.encode( @@ -177,7 +170,7 @@ class Tokenizer: str: The decoded string. """ # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. - return self.model.decode(cast(List[int], t)) + return self.model.decode(cast(list[int], t)) @staticmethod def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]: diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 91b46ec98..574080184 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -6,7 +6,6 @@ import json import re -from typing import Optional, Tuple from llama_stack.log import get_logger @@ -172,7 +171,7 @@ class ToolUtils: return match is not None @staticmethod - def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]: + def maybe_extract_builtin_tool_call(message_body: str) -> tuple[str, str] | None: # Find the first match in the text match = re.search(BUILTIN_TOOL_PATTERN, message_body) @@ -185,7 +184,7 @@ class ToolUtils: return None @staticmethod - def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]: + def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None: # NOTE: Custom function too calls are still experimental # Sometimes, response is of the form # {"type": "function", "name": "function_name", "parameters": {...} @@ -252,7 +251,7 @@ class ToolUtils: def format_value(value: RecursiveType) -> str: if isinstance(value, str): return f'"{value}"' - elif isinstance(value, (int, float, bool)) or value is None: + elif isinstance(value, int | float | bool) or value is None: return str(value) elif isinstance(value, list): return f"[{', '.join(format_value(v) for v in value)}]" diff --git a/llama_stack/models/llama/llama3_1/prompts.py b/llama_stack/models/llama/llama3_1/prompts.py index 9dcc51dc8..579a5ee02 100644 --- a/llama_stack/models/llama/llama3_1/prompts.py +++ b/llama_stack/models/llama/llama3_1/prompts.py @@ -12,7 +12,6 @@ # the top-level of this source tree. import textwrap -from typing import List from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -73,7 +72,7 @@ def wolfram_alpha_response(): ) -def usecases() -> List[UseCase | str]: +def usecases() -> list[UseCase | str]: return [ textwrap.dedent( """ diff --git a/llama_stack/models/llama/llama3_3/prompts.py b/llama_stack/models/llama/llama3_3/prompts.py index 194e4fa26..60349e578 100644 --- a/llama_stack/models/llama/llama3_3/prompts.py +++ b/llama_stack/models/llama/llama3_3/prompts.py @@ -12,7 +12,6 @@ # the top-level of this source tree. 
import textwrap -from typing import List from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -74,7 +73,7 @@ def wolfram_alpha_response(): ) -def usecases() -> List[UseCase | str]: +def usecases() -> list[UseCase | str]: return [ textwrap.dedent( """ diff --git a/llama_stack/models/llama/llama4/args.py b/llama_stack/models/llama/llama4/args.py index dd5f7cbde..523d6ed10 100644 --- a/llama_stack/models/llama/llama4/args.py +++ b/llama_stack/models/llama/llama4/args.py @@ -5,7 +5,6 @@ # the root directory of this source tree. from enum import Enum -from typing import Optional from pydantic import BaseModel, model_validator @@ -15,8 +14,8 @@ class QuantizationScheme(Enum): class QuantizationArgs(BaseModel): - scheme: Optional[QuantizationScheme] = None - group_size: Optional[int] = None + scheme: QuantizationScheme | None = None + group_size: int | None = None spinquant: bool = False @@ -58,32 +57,32 @@ class ModelArgs(BaseModel): dim: int = -1 n_layers: int = -1 n_heads: int = -1 - n_kv_heads: Optional[int] = None - head_dim: Optional[int] = None + n_kv_heads: int | None = None + head_dim: int | None = None vocab_size: int = -1 multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - ffn_exp: Optional[float] = None + ffn_dim_multiplier: float | None = None + ffn_exp: float | None = None norm_eps: float = 1e-5 - attention_chunk_size: Optional[int] = None + attention_chunk_size: int | None = None rope_theta: float = 500000 use_scaled_rope: bool = False - rope_scaling_factor: Optional[float] = None - rope_high_freq_factor: Optional[float] = None + rope_scaling_factor: float | None = None + rope_high_freq_factor: float | None = None - nope_layer_interval: Optional[int] = None # No position encoding in every n layers + nope_layer_interval: int | None = None # No position encoding in every n layers use_qk_norm: bool = False # Set to True to enable inference-time temperature tuning (useful for very long context) attn_temperature_tuning: bool = False floor_scale: float = 8192.0 attn_scale: float = 0.1 - vision_args: Optional[VisionArgs] = None - moe_args: Optional[MoEArgs] = None - quantization_args: Optional[QuantizationArgs] = None - lora_args: Optional[LoRAArgs] = None + vision_args: VisionArgs | None = None + moe_args: MoEArgs | None = None + quantization_args: QuantizationArgs | None = None + lora_args: LoRAArgs | None = None max_batch_size: int = 32 max_seq_len: int = 2048 diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py index 1574eeb5e..96ebd0881 100644 --- a/llama_stack/models/llama/llama4/chat_format.py +++ b/llama_stack/models/llama/llama4/chat_format.py @@ -8,7 +8,6 @@ import io import json import uuid from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple import torch from PIL import Image as PIL_Image @@ -46,10 +45,10 @@ def role_str(role: Role) -> str: class TransformedImage: image_tiles: torch.Tensor # is the aspect ratio needed anywhere? 
- aspect_ratio: Tuple[int, int] + aspect_ratio: tuple[int, int] -def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image: +def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image: if image.mode == "RGBA": image.load() # for png.split() new_img = PIL_Image.new("RGB", image.size, bg) @@ -59,12 +58,12 @@ def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255 class ChatFormat: - possible_headers: Dict[Role, str] + possible_headers: dict[Role, str] def __init__( self, tokenizer: Tokenizer, - vision_args: Optional[VisionArgs] = None, + vision_args: VisionArgs | None = None, max_num_chunks: int = 16, ): self.tokenizer = tokenizer @@ -81,7 +80,7 @@ class ChatFormat: vision_args.image_size.width, vision_args.image_size.height ) - def _encode_header(self, role: str) -> List[int]: + def _encode_header(self, role: str) -> list[int]: tokens = [] tokens.append(self.tokenizer.special_tokens["<|header_start|>"]) @@ -98,7 +97,7 @@ class ChatFormat: def _encode_image( self, transformed_image: TransformedImage, - ) -> List[int]: + ) -> list[int]: assert self.vision_args is not None, "The model is not vision-enabled" image_tensor = transformed_image.image_tiles @@ -140,7 +139,7 @@ class ChatFormat: return tokens - def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]: + def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[TransformedImage]]: tokens = [] tranformed_images = [] @@ -189,7 +188,7 @@ class ChatFormat: def encode_message( self, message: RawMessage, tool_prompt_format: ToolPromptFormat - ) -> Tuple[List[int], List[TransformedImage]]: + ) -> tuple[list[int], list[TransformedImage]]: tokens = self._encode_header(message.role) images = [] @@ -223,7 +222,7 @@ class ChatFormat: def encode_dialog_prompt( self, - messages: List[RawMessage], + messages: list[RawMessage], tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, ) -> LLMInput: tokens = [] @@ -240,7 +239,7 @@ class ChatFormat: return self._model_input_from_tokens_images(tokens, images) # TODO(this should be generic, not only for assistant messages) - def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage: + def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage: content = self.tokenizer.decode(tokens) return self.decode_assistant_message_from_content(content, stop_reason) @@ -312,7 +311,7 @@ class ChatFormat: tool_calls=tool_calls, ) - def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput: + def _model_input_from_tokens_images(self, tokens: list[int], images: list[TransformedImage]) -> LLMInput: return LLMInput( tokens=tokens, images=[x.image_tiles for x in images] if len(images) > 0 else None, diff --git a/llama_stack/models/llama/llama4/datatypes.py b/llama_stack/models/llama/llama4/datatypes.py index 27174db63..24d8ae948 100644 --- a/llama_stack/models/llama/llama4/datatypes.py +++ b/llama_stack/models/llama/llama4/datatypes.py @@ -5,7 +5,6 @@ # the root directory of this source tree. from dataclasses import dataclass -from typing import List, Optional, Union import torch @@ -30,7 +29,7 @@ class LLMInput: tokens: torch.Tensor # images are already pre-processed (resized, tiled, etc.) 
- images: Optional[List[torch.Tensor]] = None + images: list[torch.Tensor] | None = None @dataclass @@ -45,8 +44,8 @@ class TransformerInput: # tokens_position defines the position of the tokens in each batch, # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch # - when it is an int, the start position are the same for all batches - tokens_position: Union[torch.Tensor, int] - image_embedding: Optional[MaskedEmbedding] = None + tokens_position: torch.Tensor | int + image_embedding: MaskedEmbedding | None = None @dataclass diff --git a/llama_stack/models/llama/llama4/ffn.py b/llama_stack/models/llama/llama4/ffn.py index 9c9fca5fc..6584f1a2a 100644 --- a/llama_stack/models/llama/llama4/ffn.py +++ b/llama_stack/models/llama/llama4/ffn.py @@ -11,7 +11,7 @@ # top-level folder for each specific model found within the models/ directory at # the top-level of this source tree. -from typing import Any, Dict, List +from typing import Any from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region @@ -36,13 +36,13 @@ class FeedForward(nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: if prefix + "mlp.fc1_weight" in state_dict: w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0) diff --git a/llama_stack/models/llama/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py index 8e94bb33a..476761209 100644 --- a/llama_stack/models/llama/llama4/generation.py +++ b/llama_stack/models/llama/llama4/generation.py @@ -10,8 +10,8 @@ import json import os import sys import time +from collections.abc import Callable, Generator from pathlib import Path -from typing import Callable, Generator, List, Optional import torch import torch.nn.functional as F @@ -38,8 +38,8 @@ class Llama4: ckpt_dir: str, max_seq_len: int, max_batch_size: int, - world_size: Optional[int] = None, - quantization_mode: Optional[QuantizationMode] = None, + world_size: int | None = None, + quantization_mode: QuantizationMode | None = None, seed: int = 1, ): if not torch.distributed.is_initialized(): @@ -63,7 +63,7 @@ class Llama4: ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth")) assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}" print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})") - with open(Path(ckpt_dir) / "params.json", "r") as f: + with open(Path(ckpt_dir) / "params.json") as f: params = json.loads(f.read()) model_args: ModelArgs = ModelArgs( @@ -117,15 +117,15 @@ class Llama4: @torch.inference_mode() def generate( self, - llm_inputs: List[LLMInput], + llm_inputs: list[LLMInput], temperature: float = 0.6, top_p: float = 0.9, - max_gen_len: Optional[int] = None, + max_gen_len: int | None = None, logprobs: bool = False, echo: bool = False, print_model_input: bool = False, - logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - ) -> Generator[List[GenerationResult], None, None]: + logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + ) -> Generator[list[GenerationResult], None, None]: if max_gen_len is None or max_gen_len == 0 or 
max_gen_len >= self.model.args.max_seq_len: max_gen_len = self.model.args.max_seq_len - 1 @@ -245,13 +245,13 @@ class Llama4: def completion( self, - contents: List[RawContent], + contents: list[RawContent], temperature: float = 0.6, top_p: float = 0.9, - max_gen_len: Optional[int] = None, + max_gen_len: int | None = None, logprobs: bool = False, echo: bool = False, - ) -> Generator[List[GenerationResult], None, None]: + ) -> Generator[list[GenerationResult], None, None]: llm_inputs = [self.formatter.encode_content(c) for c in contents] for result in self.generate( llm_inputs=llm_inputs, @@ -267,13 +267,13 @@ class Llama4: def chat_completion( self, - messages_batch: List[List[RawMessage]], + messages_batch: list[list[RawMessage]], temperature: float = 0.6, top_p: float = 0.9, - max_gen_len: Optional[int] = None, + max_gen_len: int | None = None, logprobs: bool = False, echo: bool = False, - ) -> Generator[List[GenerationResult], None, None]: + ) -> Generator[list[GenerationResult], None, None]: llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch] for result in self.generate( llm_inputs=llm_inputs, diff --git a/llama_stack/models/llama/llama4/model.py b/llama_stack/models/llama/llama4/model.py index 2272b868d..4fb1181f7 100644 --- a/llama_stack/models/llama/llama4/model.py +++ b/llama_stack/models/llama/llama4/model.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import math -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import fairscale.nn.model_parallel.initialize as fs_init import torch @@ -89,7 +89,7 @@ def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) @@ -174,13 +174,13 @@ class Attention(nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: if prefix + "wqkv.weight" in state_dict: wqkv = state_dict.pop(prefix + "wqkv.weight") @@ -200,7 +200,7 @@ class Attention(nn.Module): x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, - mask: Optional[torch.Tensor] = None, + mask: torch.Tensor | None = None, ): bsz, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) @@ -288,13 +288,13 @@ class TransformerBlock(nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: if prefix + "attention.wqkv.layer_norm_weight" in state_dict: state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight") @@ -318,8 +318,8 @@ class TransformerBlock(nn.Module): x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, - global_attn_mask: Optional[torch.Tensor], - local_attn_mask: Optional[torch.Tensor], + global_attn_mask: torch.Tensor | None, + local_attn_mask: 
torch.Tensor | None, ): # The iRoPE architecture uses global attention mask for NoPE layers or # if chunked local attention is not used @@ -374,13 +374,13 @@ class Transformer(nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: if prefix + "rope.freqs" in state_dict: state_dict.pop(prefix + "rope.freqs") diff --git a/llama_stack/models/llama/llama4/moe.py b/llama_stack/models/llama/llama4/moe.py index 2ce49e915..7475963d3 100644 --- a/llama_stack/models/llama/llama4/moe.py +++ b/llama_stack/models/llama/llama4/moe.py @@ -6,7 +6,7 @@ # ruff: noqa: N806 # pyre-strict -from typing import Any, Dict, List +from typing import Any import fairscale.nn.model_parallel.initialize as fs_init import torch @@ -63,13 +63,13 @@ class Experts(nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: self.prefix = prefix if prefix + "moe_w_in_eD_F" in state_dict: @@ -158,13 +158,13 @@ class MoE(torch.nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], ) -> None: if prefix + "w_in_shared_FD.weight" in state_dict: state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight") @@ -210,5 +210,5 @@ class MoE(torch.nn.Module): def divide_exact(numerator: int, denominator: int) -> int: - assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) + assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" return numerator // denominator diff --git a/llama_stack/models/llama/llama4/preprocess.py b/llama_stack/models/llama/llama4/preprocess.py index 689680779..7527a9987 100644 --- a/llama_stack/models/llama/llama4/preprocess.py +++ b/llama_stack/models/llama/llama4/preprocess.py @@ -13,7 +13,6 @@ import math from collections import defaultdict -from typing import Optional, Set, Tuple import torch import torchvision.transforms as tv @@ -52,7 +51,7 @@ class ResizeNormalizeImageTransform: return self.tv_transform(image) -class VariableSizeImageTransform(object): +class VariableSizeImageTransform: """ This class accepts images of any size and dynamically resize, pads and chunks it based on the image aspect ratio and the number of image chunks we allow. @@ -100,7 +99,7 @@ class VariableSizeImageTransform(object): self.resample = tv.InterpolationMode.BILINEAR @staticmethod - def get_factors(n: int) -> Set[int]: + def get_factors(n: int) -> set[int]: """ Calculate all factors of a given number, i.e. a dividor that leaves no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}. 
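divide_exact above swaps str.format for an f-string carrying the identical message; behavior is unchanged. Restated as a runnable whole:

    def divide_exact(numerator: int, denominator: int) -> int:
        assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
        return numerator // denominator

    assert divide_exact(12, 4) == 3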
@@ -170,9 +169,9 @@ class VariableSizeImageTransform(object): @staticmethod def get_max_res_without_distortion( - image_size: Tuple[int, int], - target_size: Tuple[int, int], - ) -> Tuple[int, int]: + image_size: tuple[int, int], + target_size: tuple[int, int], + ) -> tuple[int, int]: """ Determines the maximum resolution to which an image can be resized to without distorting its aspect ratio, based on the target resolution. @@ -223,8 +222,8 @@ class VariableSizeImageTransform(object): def resize_without_distortion( self, image: torch.Tensor, - target_size: Tuple[int, int], - max_upscaling_size: Optional[int], + target_size: tuple[int, int], + max_upscaling_size: int | None, ) -> torch.Tensor: """ Used to resize an image to target_resolution, without distortion. @@ -289,10 +288,10 @@ class VariableSizeImageTransform(object): def get_best_fit( self, - image_size: Tuple[int, int], + image_size: tuple[int, int], possible_resolutions: torch.Tensor, resize_to_max_canvas: bool = False, - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """ Determines the best canvas possible from a list of possible resolutions to, without distortion, resize an image to. @@ -392,7 +391,7 @@ class VariableSizeImageTransform(object): max_num_chunks: int, normalize_img: bool = True, resize_to_max_canvas: bool = False, - ) -> Tuple[torch.Tensor, Tuple[int, int]]: + ) -> tuple[torch.Tensor, tuple[int, int]]: """ Args: image (PIL.Image): Image to be resized. diff --git a/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py index 139e204ad..ceb28300f 100644 --- a/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +++ b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py @@ -12,7 +12,6 @@ # the top-level of this source tree. import textwrap -from typing import List, Optional from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition from llama_stack.models.llama.llama3.prompt_templates.base import ( @@ -67,14 +66,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 """.strip("\n") ) - def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: + def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate: system_prompt = system_prompt or self.DEFAULT_PROMPT return PromptTemplate( system_prompt, {"function_description": self._gen_function_description(custom_tools)}, ) - def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate: template_str = textwrap.dedent( """ Here is a list of functions in JSON format that you can invoke. 
@@ -120,7 +119,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 {"tools": [t.model_dump() for t in custom_tools]}, ).render() - def data_examples(self) -> List[List[ToolDefinition]]: + def data_examples(self) -> list[list[ToolDefinition]]: return [ [ ToolDefinition( diff --git a/llama_stack/models/llama/llama4/prompts.py b/llama_stack/models/llama/llama4/prompts.py index fe9a59130..2da94db7b 100644 --- a/llama_stack/models/llama/llama4/prompts.py +++ b/llama_stack/models/llama/llama4/prompts.py @@ -7,7 +7,6 @@ import textwrap from io import BytesIO from pathlib import Path -from typing import List from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( PythonListCustomToolGenerator, @@ -23,7 +22,7 @@ from ..prompt_format import ( THIS_DIR = Path(__file__).parent -def usecases(base_model: bool = False) -> List[UseCase | str]: +def usecases(base_model: bool = False) -> list[UseCase | str]: with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f: img_small_dog = f.read() with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f: diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index f11d83c60..223744a5f 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -6,7 +6,7 @@ import logging import os -from typing import Callable, Optional +from collections.abc import Callable import torch from fairscale.nn.model_parallel.initialize import get_model_parallel_rank @@ -45,8 +45,8 @@ def experts_batched_swiglu_wrapper( def convert_to_quantized_model( model: Transformer, checkpoint_dir: str, - quantization_mode: Optional[str] = None, - fp8_activation_scale_ub: Optional[float] = 1200.0, + quantization_mode: str | None = None, + fp8_activation_scale_ub: float | None = 1200.0, use_rich_progress: bool = True, ) -> Transformer: from ...quantize_impls import ( @@ -213,7 +213,7 @@ def logging_callbacks( ) task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting") - def update_status(message: Optional[str], completed: Optional[int] = None) -> None: + def update_status(message: str | None, completed: int | None = None) -> None: if use_rich_progress: if message is not None: progress.update(task_id, status=message) diff --git a/llama_stack/models/llama/llama4/tokenizer.py b/llama_stack/models/llama/llama4/tokenizer.py index 0d2cc7ce5..74070d43e 100644 --- a/llama_stack/models/llama/llama4/tokenizer.py +++ b/llama_stack/models/llama/llama4/tokenizer.py @@ -5,18 +5,11 @@ # the root directory of this source tree. import os +from collections.abc import Collection, Iterator, Sequence, Set from logging import getLogger from pathlib import Path from typing import ( - AbstractSet, - Collection, - Dict, - Iterator, - List, Literal, - Optional, - Sequence, - Union, cast, ) @@ -114,7 +107,7 @@ class Tokenizer: Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 
""" - special_tokens: Dict[str, int] + special_tokens: dict[str, int] num_reserved_special_tokens = 2048 @@ -182,9 +175,9 @@ class Tokenizer: *, bos: bool, eos: bool, - allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: + allowed_special: Literal["all"] | Set[str] | None = None, + disallowed_special: Literal["all"] | Collection[str] = (), + ) -> list[int]: """ Encodes a string into a list of token IDs. @@ -217,7 +210,7 @@ class Tokenizer: s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS ) ) - t: List[int] = [] + t: list[int] = [] for substr in substrs: t.extend( self.model.encode( @@ -243,7 +236,7 @@ class Tokenizer: str: The decoded string. """ # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. - return self.model.decode(cast(List[int], t)) + return self.model.decode(cast(list[int], t)) @staticmethod def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]: diff --git a/llama_stack/models/llama/llama4/vision/embedding.py b/llama_stack/models/llama/llama4/vision/embedding.py index ed7659a73..c7dd81965 100644 --- a/llama_stack/models/llama/llama4/vision/embedding.py +++ b/llama_stack/models/llama/llama4/vision/embedding.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import math -from typing import Any, Callable, Dict, List +from collections.abc import Callable +from typing import Any import torch import torch.nn as nn @@ -136,13 +137,13 @@ class VisionEmbeddings(torch.nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool = True, - missing_keys: List[str] = None, - unexpected_keys: List[str] = None, - error_msgs: List[str] = None, + missing_keys: list[str] = None, + unexpected_keys: list[str] = None, + error_msgs: list[str] = None, return_state_dict: bool = False, ) -> None: original_sd = self.state_dict() @@ -163,7 +164,7 @@ class VisionEmbeddings(torch.nn.Module): # each image is a tensor of shape [num_tiles, C, H, W] def forward( self, - image_batch: List[List[torch.Tensor]], + image_batch: list[list[torch.Tensor]], image_mask: torch.Tensor, h_ref: torch.Tensor, ) -> torch.Tensor: diff --git a/llama_stack/models/llama/llama4/vision/encoder.py b/llama_stack/models/llama/llama4/vision/encoder.py index 4baf03d8d..4b66f1411 100644 --- a/llama_stack/models/llama/llama4/vision/encoder.py +++ b/llama_stack/models/llama/llama4/vision/encoder.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from collections.abc import Callable +from typing import Any import fairscale.nn.model_parallel.initialize as fs_init import torch @@ -42,9 +43,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module): self, in_channels: int, out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]], - bias: Optional[bool] = False, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int], + bias: bool | None = False, ) -> None: super().__init__() if isinstance(kernel_size, int): @@ -134,15 +135,15 @@ class _TransformerBlock(nn.Module): def attention( self, x: torch.Tensor, - freq_cis: Optional[torch.Tensor] = None, + freq_cis: torch.Tensor | None = None, ): return self.attn(x=x, start_pos=0, freqs_cis=freq_cis) def forward( self, x: torch.Tensor, - mask: Optional[torch.Tensor] = None, - freq_cis: Optional[torch.Tensor] = None, + mask: torch.Tensor | None = None, + freq_cis: torch.Tensor | None = None, ): _gate_attn = 1 if not self.gated else self.gate_attn.tanh() _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh() @@ -210,8 +211,8 @@ class PackingIndex: class VisionEncoder(nn.Module): def __init__( self, - image_size: Tuple[int, int], - patch_size: Tuple[int, int], + image_size: tuple[int, int], + patch_size: tuple[int, int], dim: int, layers: int, heads: int, @@ -299,13 +300,13 @@ class VisionEncoder(nn.Module): def load_hook( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, - local_metadata: Dict[str, Any], + local_metadata: dict[str, Any], strict: bool = True, - missing_keys: List[str] = None, - unexpected_keys: List[str] = None, - error_msgs: List[str] = None, + missing_keys: list[str] = None, + unexpected_keys: list[str] = None, + error_msgs: list[str] = None, return_state_dict: bool = False, ) -> None: orig_pos_embed = state_dict.get(prefix + "positional_embedding") diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py index edb34620c..6191df61a 100644 --- a/llama_stack/models/llama/prompt_format.py +++ b/llama_stack/models/llama/prompt_format.py @@ -14,7 +14,6 @@ import json import textwrap from pathlib import Path -from typing import List from pydantic import BaseModel, Field @@ -44,7 +43,7 @@ class TextCompletionContent(BaseModel): class UseCase(BaseModel): title: str = "" description: str = "" - dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list) + dialogs: list[list[RawMessage] | TextCompletionContent | str] = Field(default_factory=list) notes: str = "" tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json max_gen_len: int = 512 diff --git a/llama_stack/models/llama/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py index a5da01588..a6400c5c9 100644 --- a/llama_stack/models/llama/quantize_impls.py +++ b/llama_stack/models/llama/quantize_impls.py @@ -7,7 +7,6 @@ # type: ignore import collections import logging -from typing import Optional, Tuple, Type, Union log = logging.getLogger(__name__) @@ -27,7 +26,7 @@ class Fp8ScaledWeights: # TODO: Ugly trick so torch allows us to replace parameters # with our custom Fp8Weights instance. Do this properly. @property - def __class__(self) -> Type[nn.parameter.Parameter]: + def __class__(self) -> type[nn.parameter.Parameter]: return nn.Parameter @property @@ -51,7 +50,7 @@ class Int4ScaledWeights: # TODO: Ugly trick so torch allows us to replace parameters # with our custom Int4Weights instance. 
Do this properly. @property - def __class__(self) -> Type[nn.parameter.Parameter]: + def __class__(self) -> type[nn.parameter.Parameter]: return nn.Parameter @property @@ -74,7 +73,7 @@ class Int4Weights( def int4_row_quantize( x: torch.Tensor, group_size: int = 128, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: n_bit = 4 # Number of target bits. to_quant = x.reshape(-1, group_size).to(torch.float) @@ -115,8 +114,8 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor: def bmm_nt( x: Tensor, - w: Union[Fp8RowwiseWeights, Int4Weights], - num_tokens: Optional[Tensor] = None, + w: Fp8RowwiseWeights | Int4Weights, + num_tokens: Tensor | None = None, ) -> Tensor: if isinstance(w, Fp8ScaledWeights): xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub) @@ -129,10 +128,10 @@ def bmm_nt( def ffn_swiglu( x: Tensor, - w1: Union[Fp8RowwiseWeights, Int4Weights], - w3: Union[Fp8RowwiseWeights, Int4Weights], - w2: Union[Fp8RowwiseWeights, Int4Weights], - num_tokens: Optional[Tensor] = None, + w1: Fp8RowwiseWeights | Int4Weights, + w3: Fp8RowwiseWeights | Int4Weights, + w2: Fp8RowwiseWeights | Int4Weights, + num_tokens: Tensor | None = None, is_memory_bounded: bool = False, ) -> Tensor: if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or ( @@ -158,7 +157,7 @@ def ffn_swiglu( def quantize_fp8( w: Tensor, fp8_activation_scale_ub: float, - output_device: Optional[torch.device] = None, + output_device: torch.device | None = None, ) -> Fp8RowwiseWeights: """Quantize [n, k] weight tensor. @@ -184,7 +183,7 @@ def quantize_fp8( @torch.inference_mode() def quantize_int4( w: Tensor, - output_device: Optional[torch.device] = None, + output_device: torch.device | None = None, ) -> Int4Weights: """Quantize [n, k/2] weight tensor. @@ -213,7 +212,7 @@ def load_fp8( w: Tensor, w_scale: Tensor, fp8_activation_scale_ub: float, - output_device: Optional[torch.device] = None, + output_device: torch.device | None = None, ) -> Fp8RowwiseWeights: """Load FP8 [n, k] weight tensor. @@ -239,7 +238,7 @@ def load_int4( w: Tensor, scale: Tensor, zero_point: Tensor, - output_device: Optional[torch.device] = None, + output_device: torch.device | None = None, ) -> Int4Weights: """Load INT4 [n, k/2] weight tensor. 
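The INT4 helpers above quantize rows in groups (group_size=128 by default) and then pack two signed 4-bit values into each int8 byte, which is why the weight tensors in these docstrings have shape [n, k/2]. A hedged sketch of the nibble packing; the exact layout consumed by the fbgemm kernels is an assumption:

```python
# Hedged sketch of int4 nibble-packing, assuming even elements land in the low
# nibble and odd elements in the high nibble; the real kernel layout may differ.
import torch


def pack_int4_sketch(x: torch.Tensor) -> torch.Tensor:
    nibbles = (x & 0x0F).to(torch.uint8)             # keep the low 4 bits of each value
    packed = nibbles[..., ::2] | (nibbles[..., 1::2] << 4)
    return packed.view(torch.int8)                   # reinterpret bytes as int8 storage


w = torch.randint(-8, 8, (2, 8), dtype=torch.int8)  # already-quantized int4 values
assert pack_int4_sketch(w).shape == (2, 4)          # k -> k/2, matching the [n, k/2] docstrings
```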
@@ -256,9 +255,9 @@ def load_int4( def fc_dynamic( x: Tensor, - w: Union[Fp8RowwiseWeights, Int4Weights], - activation_scale_ub: Optional[Tensor] = None, - num_tokens: Optional[Tensor] = None, + w: Fp8RowwiseWeights | Int4Weights, + activation_scale_ub: Tensor | None = None, + num_tokens: Tensor | None = None, is_memory_bounded: bool = False, ) -> Tensor: """ @@ -275,11 +274,11 @@ def fc_dynamic( def ffn_swiglu_dynamic( x: Tensor, - w1: Union[Fp8RowwiseWeights, Int4Weights], - w3: Union[Fp8RowwiseWeights, Int4Weights], - w2: Union[Fp8RowwiseWeights, Int4Weights], - activation_scale_ub: Optional[Tensor] = None, - num_tokens: Optional[Tensor] = None, + w1: Fp8RowwiseWeights | Int4Weights, + w3: Fp8RowwiseWeights | Int4Weights, + w2: Fp8RowwiseWeights | Int4Weights, + activation_scale_ub: Tensor | None = None, + num_tokens: Tensor | None = None, is_memory_bounded: bool = False, ) -> Tensor: assert x.dim() == 3 or x.dim() == 2 diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py index 1e3218d04..a82cbf708 100644 --- a/llama_stack/models/llama/sku_list.py +++ b/llama_stack/models/llama/sku_list.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from functools import lru_cache -from typing import List, Optional from .sku_types import ( CheckpointQuantizationFormat, @@ -19,14 +18,14 @@ LLAMA2_VOCAB_SIZE = 32000 LLAMA3_VOCAB_SIZE = 128256 -def resolve_model(descriptor: str) -> Optional[Model]: +def resolve_model(descriptor: str) -> Model | None: for m in all_registered_models(): if descriptor in (m.descriptor(), m.huggingface_repo): return m return None -def all_registered_models() -> List[Model]: +def all_registered_models() -> list[Model]: return ( llama2_family() + llama3_family() @@ -38,48 +37,48 @@ def all_registered_models() -> List[Model]: ) -def llama2_family() -> List[Model]: +def llama2_family() -> list[Model]: return [ *llama2_base_models(), *llama2_instruct_models(), ] -def llama3_family() -> List[Model]: +def llama3_family() -> list[Model]: return [ *llama3_base_models(), *llama3_instruct_models(), ] -def llama3_1_family() -> List[Model]: +def llama3_1_family() -> list[Model]: return [ *llama3_1_base_models(), *llama3_1_instruct_models(), ] -def llama3_2_family() -> List[Model]: +def llama3_2_family() -> list[Model]: return [ *llama3_2_base_models(), *llama3_2_instruct_models(), ] -def llama3_3_family() -> List[Model]: +def llama3_3_family() -> list[Model]: return [ *llama3_3_instruct_models(), ] -def llama4_family() -> List[Model]: +def llama4_family() -> list[Model]: return [ *llama4_base_models(), *llama4_instruct_models(), ] -def llama4_base_models() -> List[Model]: +def llama4_base_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama4_scout_17b_16e, @@ -98,7 +97,7 @@ def llama4_base_models() -> List[Model]: ] -def llama4_instruct_models() -> List[Model]: +def llama4_instruct_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama4_scout_17b_16e_instruct, @@ -126,7 +125,7 @@ def llama4_instruct_models() -> List[Model]: ] -def llama2_base_models() -> List[Model]: +def llama2_base_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama2_7b, @@ -185,7 +184,7 @@ def llama2_base_models() -> List[Model]: ] -def llama3_base_models() -> List[Model]: +def llama3_base_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama3_8b, @@ -226,7 +225,7 @@ def llama3_base_models() -> List[Model]: ] -def llama3_1_base_models() -> List[Model]: +def llama3_1_base_models() -> list[Model]: return 
[ Model( core_model_id=CoreModelId.llama3_1_8b, @@ -324,7 +323,7 @@ def llama3_1_base_models() -> List[Model]: ] -def llama3_2_base_models() -> List[Model]: +def llama3_2_base_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama3_2_1b, @@ -407,7 +406,7 @@ def llama3_2_base_models() -> List[Model]: ] -def llama2_instruct_models() -> List[Model]: +def llama2_instruct_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama2_7b_chat, @@ -466,7 +465,7 @@ def llama2_instruct_models() -> List[Model]: ] -def llama3_instruct_models() -> List[Model]: +def llama3_instruct_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama3_8b_instruct, @@ -507,7 +506,7 @@ def llama3_instruct_models() -> List[Model]: ] -def llama3_1_instruct_models() -> List[Model]: +def llama3_1_instruct_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama3_1_8b_instruct, @@ -635,7 +634,7 @@ def arch_args_3b() -> dict: } -def llama3_2_quantized_models() -> List[Model]: +def llama3_2_quantized_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama3_2_1b_instruct, @@ -704,7 +703,7 @@ def llama3_2_quantized_models() -> List[Model]: ] -def llama3_2_instruct_models() -> List[Model]: +def llama3_2_instruct_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama3_2_1b_instruct, @@ -766,7 +765,7 @@ def llama3_2_instruct_models() -> List[Model]: ] -def llama3_3_instruct_models() -> List[Model]: +def llama3_3_instruct_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama3_3_70b_instruct, @@ -790,7 +789,7 @@ def llama3_3_instruct_models() -> List[Model]: @lru_cache -def safety_models() -> List[Model]: +def safety_models() -> list[Model]: return [ Model( core_model_id=CoreModelId.llama_guard_4_12b, @@ -919,7 +918,7 @@ def safety_models() -> List[Model]: @dataclass class LlamaDownloadInfo: folder: str - files: List[str] + files: list[str] pth_size: int diff --git a/llama_stack/models/llama/sku_types.py b/llama_stack/models/llama/sku_types.py index 2796c4c4b..4147707d5 100644 --- a/llama_stack/models/llama/sku_types.py +++ b/llama_stack/models/llama/sku_types.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field @@ -159,13 +159,13 @@ def model_family(model_id) -> ModelFamily: class Model(BaseModel): core_model_id: CoreModelId description: str - huggingface_repo: Optional[str] = None - arch_args: Dict[str, Any] + huggingface_repo: str | None = None + arch_args: dict[str, Any] variant: str = "" quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 pth_file_count: int - metadata: Dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) # silence pydantic until we remove the `model_` fields model_config = ConfigDict(protected_namespaces=()) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 3482acb31..3e9806f23 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
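The registry above is plain data: each family function returns a list[Model], and resolve_model matches a descriptor string against either the model's own descriptor() or its huggingface_repo. A hedged usage sketch; the repo string below is an illustrative assumption:

```python
from llama_stack.models.llama.sku_list import all_registered_models, resolve_model

# resolve_model accepts either a descriptor or an HF repo; this name is an assumed example.
model = resolve_model("meta-llama/Llama-3.1-8B-Instruct")
if model is None:
    known = sorted(m.descriptor() for m in all_registered_models())
    raise ValueError(f"Unknown model descriptor; known: {known[:3]} ...")
print(model.core_model_id, model.quantization_format, model.pth_file_count)
```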
from enum import Enum -from typing import Any, List, Optional, Protocol +from typing import Any, Protocol from urllib.parse import urlparse from pydantic import BaseModel, Field @@ -65,7 +65,7 @@ class DatasetsProtocolPrivate(Protocol): class ScoringFunctionsProtocolPrivate(Protocol): - async def list_scoring_functions(self) -> List[ScoringFn]: ... + async def list_scoring_functions(self) -> list[ScoringFn]: ... async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ... @@ -88,24 +88,24 @@ class ProviderSpec(BaseModel): ..., description="Fully-qualified classname of the config for this provider", ) - api_dependencies: List[Api] = Field( + api_dependencies: list[Api] = Field( default_factory=list, description="Higher-level API surfaces may depend on other providers to provide their functionality", ) - optional_api_dependencies: List[Api] = Field( + optional_api_dependencies: list[Api] = Field( default_factory=list, ) - deprecation_warning: Optional[str] = Field( + deprecation_warning: str | None = Field( default=None, description="If this provider is deprecated, specify the warning message here", ) - deprecation_error: Optional[str] = Field( + deprecation_error: str | None = Field( default=None, description="If this provider is deprecated and does NOT work, specify the error message here", ) # used internally by the resolver; this is a hack for now - deps__: List[str] = Field(default_factory=list) + deps__: list[str] = Field(default_factory=list) @property def is_sample(self) -> bool: @@ -131,25 +131,25 @@ Fully-qualified name of the module to import. The module is expected to have: - `get_adapter_impl(config, deps)`: returns the adapter implementation """, ) - pip_packages: List[str] = Field( + pip_packages: list[str] = Field( default_factory=list, description="The pip dependencies needed for this implementation", ) config_class: str = Field( description="Fully-qualified classname of the config for this provider", ) - provider_data_validator: Optional[str] = Field( + provider_data_validator: str | None = Field( default=None, ) @json_schema_type class InlineProviderSpec(ProviderSpec): - pip_packages: List[str] = Field( + pip_packages: list[str] = Field( default_factory=list, description="The pip dependencies needed for this implementation", ) - container_image: Optional[str] = Field( + container_image: str | None = Field( default=None, description=""" The container image to use for this implementation. If one is provided, pip_packages will be ignored. @@ -164,14 +164,14 @@ Fully-qualified name of the module to import. The module is expected to have: - `get_provider_impl(config, deps)`: returns the local implementation """, ) - provider_data_validator: Optional[str] = Field( + provider_data_validator: str | None = Field( default=None, ) class RemoteProviderConfig(BaseModel): host: str = "localhost" - port: Optional[int] = None + port: int | None = None protocol: str = "http" @property @@ -197,7 +197,7 @@ API responses, specify the adapter here. ) @property - def container_image(self) -> Optional[str]: + def container_image(self) -> str | None: return None @property @@ -205,16 +205,16 @@ API responses, specify the adapter here. 
return self.adapter.module @property - def pip_packages(self) -> List[str]: + def pip_packages(self) -> list[str]: return self.adapter.pip_packages @property - def provider_data_validator(self) -> Optional[str]: + def provider_data_validator(self) -> str | None: return self.adapter.provider_data_validator def remote_provider_spec( - api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None + api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None ) -> RemoteProviderSpec: return RemoteProviderSpec( api=api, diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 4be064f1d..7503b8c90 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import Api from .config import MetaReferenceAgentsImplConfig -async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]): +async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]): from .agents import MetaReferenceAgentsImpl impl = MetaReferenceAgentsImpl( diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index b5714b438..3807ddb86 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -10,8 +10,8 @@ import re import secrets import string import uuid +from collections.abc import AsyncGenerator from datetime import datetime, timezone -from typing import AsyncGenerator, List, Optional, Union import httpx @@ -112,7 +112,7 @@ class ChatAgent(ShieldRunnerMixin): output_shields=agent_config.output_shields, ) - def turn_to_messages(self, turn: Turn) -> List[Message]: + def turn_to_messages(self, turn: Turn) -> list[Message]: messages = [] # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages @@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: return await self.storage.create_session(name) - async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]: + async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: messages = [] if self.agent_config.instructions != "": messages.append(SystemMessage(content=self.agent_config.instructions)) @@ -201,8 +201,8 @@ class ChatAgent(ShieldRunnerMixin): async def _run_turn( self, - request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest], - turn_id: Optional[str] = None, + request: AgentTurnCreateRequest | AgentTurnResumeRequest, + turn_id: str | None = None, ) -> AsyncGenerator: assert request.stream is True, "Non-streaming not supported" @@ -321,10 +321,10 @@ class ChatAgent(ShieldRunnerMixin): self, session_id: str, turn_id: str, - input_messages: List[Message], + input_messages: list[Message], sampling_params: SamplingParams, stream: bool = False, - documents: Optional[List[Document]] = None, + documents: list[Document] | None = None, ) -> AsyncGenerator: # Doing async generators makes downstream code much simpler and everything amenable to # streaming. 
However, it also makes things complicated here because AsyncGenerators cannot @@ -374,8 +374,8 @@ class ChatAgent(ShieldRunnerMixin): async def run_multiple_shields_wrapper( self, turn_id: str, - messages: List[Message], - shields: List[str], + messages: list[Message], + shields: list[str], touchpoint: str, ) -> AsyncGenerator: async with tracing.span("run_shields") as span: @@ -443,10 +443,10 @@ class ChatAgent(ShieldRunnerMixin): self, session_id: str, turn_id: str, - input_messages: List[Message], + input_messages: list[Message], sampling_params: SamplingParams, stream: bool = False, - documents: Optional[List[Document]] = None, + documents: list[Document] | None = None, ) -> AsyncGenerator: # if a document is passed in a turn, we parse the raw text of the document # and send it as a user message @@ -760,7 +760,7 @@ class ChatAgent(ShieldRunnerMixin): async def _initialize_tools( self, - toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, + toolgroups_for_turn: list[AgentToolGroup] | None = None, ) -> None: toolgroup_to_args = {} for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []): @@ -847,7 +847,7 @@ class ChatAgent(ShieldRunnerMixin): tool_name_to_args, ) - def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: + def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]: """Parse a toolgroup name into its components. Args: @@ -921,7 +921,7 @@ async def get_raw_document_text(document: Document) -> str: def _interpret_content_as_attachment( content: str, -) -> Optional[Attachment]: +) -> Attachment | None: match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) if match: snippet = match.group(1) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index b9e36f519..e0cfa5b25 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -8,7 +8,7 @@ import json import logging import shutil import uuid -from typing import AsyncGenerator, List, Optional, Union +from collections.abc import AsyncGenerator from llama_stack.apis.agents import ( Agent, @@ -142,16 +142,11 @@ class MetaReferenceAgentsImpl(Agents): self, agent_id: str, session_id: str, - messages: List[ - Union[ - UserMessage, - ToolResponseMessage, - ] - ], - toolgroups: Optional[List[AgentToolGroup]] = None, - documents: Optional[List[Document]] = None, - stream: Optional[bool] = False, - tool_config: Optional[ToolConfig] = None, + messages: list[UserMessage | ToolResponseMessage], + toolgroups: list[AgentToolGroup] | None = None, + documents: list[Document] | None = None, + stream: bool | None = False, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, @@ -180,8 +175,8 @@ class MetaReferenceAgentsImpl(Agents): agent_id: str, session_id: str, turn_id: str, - tool_responses: List[ToolResponse], - stream: Optional[bool] = False, + tool_responses: list[ToolResponse], + stream: bool | None = False, ) -> AsyncGenerator: request = AgentTurnResumeRequest( agent_id=agent_id, @@ -219,7 +214,7 @@ class MetaReferenceAgentsImpl(Agents): self, agent_id: str, session_id: str, - turn_ids: Optional[List[str]] = None, + turn_ids: list[str] | None = None, ) -> Session: agent = await self._get_agent_impl(agent_id) session_info = await agent.storage.get_session_info(session_id) @@ -265,13 +260,13
@@ class MetaReferenceAgentsImpl(Agents): async def create_openai_response( self, - input: Union[str, List[OpenAIResponseInputMessage]], + input: str | list[OpenAIResponseInputMessage], model: str, - previous_response_id: Optional[str] = None, - store: Optional[bool] = True, - stream: Optional[bool] = False, - temperature: Optional[float] = None, - tools: Optional[List[OpenAIResponseInputTool]] = None, + previous_response_id: str | None = None, + store: bool | None = True, + stream: bool | None = False, + temperature: float | None = None, + tools: list[OpenAIResponseInputTool] | None = None, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( input, model, previous_response_id, store, stream, temperature, tools diff --git a/llama_stack/providers/inline/agents/meta_reference/config.py b/llama_stack/providers/inline/agents/meta_reference/config.py index ff34e5d5f..c860e6df1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/config.py +++ b/llama_stack/providers/inline/agents/meta_reference/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -16,7 +16,7 @@ class MetaReferenceAgentsImplConfig(BaseModel): persistence_store: KVStoreConfig @classmethod - def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: return { "persistence_store": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 251acd853..24a99dd6e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -6,7 +6,8 @@ import json import uuid -from typing import AsyncIterator, List, Optional, Union, cast +from collections.abc import AsyncIterator +from typing import cast from openai.types.chat import ChatCompletionToolParam @@ -49,15 +50,15 @@ logger = get_logger(name=__name__, category="openai_responses") OPENAI_RESPONSES_PREFIX = "openai_responses:" -async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]: - messages: List[OpenAIMessageParam] = [] +async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]: + messages: list[OpenAIMessageParam] = [] for output_message in previous_response.output: if isinstance(output_message, OpenAIResponseOutputMessage): messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text)) return messages -async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]: +async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]: output_messages = [] for choice in choices: output_content = "" @@ -101,22 +102,22 @@ class OpenAIResponsesImpl: async def create_openai_response( self, - input: Union[str, List[OpenAIResponseInputMessage]], + input: str | list[OpenAIResponseInputMessage], model: str, - previous_response_id: Optional[str] = None, - store: Optional[bool] = True, - stream: Optional[bool] = False, - temperature: Optional[float] = None, - tools: Optional[List[OpenAIResponseInputTool]] = 
None, + previous_response_id: str | None = None, + store: bool | None = True, + stream: bool | None = False, + temperature: float | None = None, + tools: list[OpenAIResponseInputTool] | None = None, ): stream = False if stream is None else stream - messages: List[OpenAIMessageParam] = [] + messages: list[OpenAIMessageParam] = [] if previous_response_id: previous_response = await self.get_openai_response(previous_response_id) messages.extend(await _previous_response_to_messages(previous_response)) # TODO: refactor this user_content parsing out into a separate method - user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = "" + user_content: str | list[OpenAIChatCompletionContentPartParam] = "" if isinstance(input, list): user_content = [] for user_input in input: @@ -179,7 +180,7 @@ class OpenAIResponsesImpl: # dump and reload to map to our pydantic types chat_response = OpenAIChatCompletion(**chat_response.model_dump()) - output_messages: List[OpenAIResponseOutput] = [] + output_messages: list[OpenAIResponseOutput] = [] if chat_response.choices[0].message.tool_calls: output_messages.extend( await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature) @@ -215,9 +216,9 @@ class OpenAIResponsesImpl: return response async def _convert_response_tools_to_chat_tools( - self, tools: List[OpenAIResponseInputTool] - ) -> List[ChatCompletionToolParam]: - chat_tools: List[ChatCompletionToolParam] = [] + self, tools: list[OpenAIResponseInputTool] + ) -> list[ChatCompletionToolParam]: + chat_tools: list[ChatCompletionToolParam] = [] for input_tool in tools: # TODO: Handle other tool types if input_tool.type == "web_search": @@ -247,10 +248,10 @@ class OpenAIResponsesImpl: model_id: str, stream: bool, chat_response: OpenAIChatCompletion, - messages: List[OpenAIMessageParam], + messages: list[OpenAIMessageParam], temperature: float, - ) -> List[OpenAIResponseOutput]: - output_messages: List[OpenAIResponseOutput] = [] + ) -> list[OpenAIResponseOutput]: + output_messages: list[OpenAIResponseOutput] = [] choice = chat_response.choices[0] # If the choice is not an assistant message, we don't need to execute any tools @@ -314,7 +315,7 @@ class OpenAIResponsesImpl: async def _execute_tool_call( self, function: OpenAIChatCompletionToolCallFunction, - ) -> Optional[ToolInvocationResult]: + ) -> ToolInvocationResult | None: if not function.name: return None function_args = json.loads(function.arguments) if function.arguments else {} diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 202d43609..60ce91395 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -8,7 +8,6 @@ import json import logging import uuid from datetime import datetime, timezone -from typing import List, Optional from pydantic import BaseModel @@ -25,9 +24,9 @@ class AgentSessionInfo(BaseModel): session_id: str session_name: str # TODO: is this used anywhere? 
- vector_db_id: Optional[str] = None + vector_db_id: str | None = None started_at: datetime - access_attributes: Optional[AccessAttributes] = None + access_attributes: AccessAttributes | None = None class AgentPersistence: @@ -55,7 +54,7 @@ class AgentPersistence: ) return session_id - async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]: + async def get_session_info(self, session_id: str) -> AgentSessionInfo | None: value = await self.kvstore.get( key=f"session:{self.agent_id}:{session_id}", ) @@ -78,7 +77,7 @@ class AgentPersistence: return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes()) - async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]: + async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: """Get session info if the user has access to it. For internal use by sub-session methods.""" session_info = await self.get_session_info(session_id) if not session_info: @@ -106,7 +105,7 @@ class AgentPersistence: value=turn.model_dump_json(), ) - async def get_session_turns(self, session_id: str) -> List[Turn]: + async def get_session_turns(self, session_id: str) -> list[Turn]: if not await self.get_session_if_accessible(session_id): raise ValueError(f"Session {session_id} not found or access denied") @@ -125,7 +124,7 @@ class AgentPersistence: turns.sort(key=lambda x: (x.completed_at or datetime.min)) return turns - async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]: + async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None: if not await self.get_session_if_accessible(session_id): raise ValueError(f"Session {session_id} not found or access denied") @@ -145,7 +144,7 @@ class AgentPersistence: value=step.model_dump_json(), ) - async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: + async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None: if not await self.get_session_if_accessible(session_id): return None @@ -163,7 +162,7 @@ class AgentPersistence: value=str(num_infer_iters), ) - async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]: + async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None: if not await self.get_session_if_accessible(session_id): return None diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index bef16eaba..6b3573d8c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -6,7 +6,6 @@ import asyncio import logging -from typing import List from llama_stack.apis.inference import Message from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel @@ -25,14 +24,14 @@ class ShieldRunnerMixin: def __init__( self, safety_api: Safety, - input_shields: List[str] = None, - output_shields: List[str] = None, + input_shields: list[str] = None, + output_shields: list[str] = None, ): self.safety_api = safety_api self.input_shields = input_shields self.output_shields = output_shields - async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None: + async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None: async def run_shield_with_span(identifier: str): async with 
tracing.span(f"run_shield_{identifier}"): return await self.safety_api.run_shield( diff --git a/llama_stack/providers/inline/datasetio/localfs/__init__.py b/llama_stack/providers/inline/datasetio/localfs/__init__.py index 5a0876d79..58aa6ffaf 100644 --- a/llama_stack/providers/inline/datasetio/localfs/__init__.py +++ b/llama_stack/providers/inline/datasetio/localfs/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from .config import LocalFSDatasetIOConfig async def get_provider_impl( config: LocalFSDatasetIOConfig, - _deps: Dict[str, Any], + _deps: dict[str, Any], ): from .datasetio import LocalFSDatasetIOImpl diff --git a/llama_stack/providers/inline/datasetio/localfs/config.py b/llama_stack/providers/inline/datasetio/localfs/config.py index d74521f1f..b450e8777 100644 --- a/llama_stack/providers/inline/datasetio/localfs/config.py +++ b/llama_stack/providers/inline/datasetio/localfs/config.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -17,7 +17,7 @@ class LocalFSDatasetIOConfig(BaseModel): kvstore: KVStoreConfig @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { "kvstore": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index e71107d61..260a640bd 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional +from typing import Any import pandas @@ -92,8 +92,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): async def iterrows( self, dataset_id: str, - start_index: Optional[int] = None, - limit: Optional[int] = None, + start_index: int | None = None, + limit: int | None = None, ) -> PaginatedResponse: dataset_def = self.dataset_infos[dataset_id] dataset_impl = PandasDataframeDataset(dataset_def) @@ -102,7 +102,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): records = dataset_impl.df.to_dict("records") return paginate_records(records, start_index, limit) - async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: dataset_def = self.dataset_infos[dataset_id] dataset_impl = PandasDataframeDataset(dataset_def) await dataset_impl.load() diff --git a/llama_stack/providers/inline/eval/meta_reference/__init__.py b/llama_stack/providers/inline/eval/meta_reference/__init__.py index e2a7fc2cd..7afe7f33b 100644 --- a/llama_stack/providers/inline/eval/meta_reference/__init__.py +++ b/llama_stack/providers/inline/eval/meta_reference/__init__.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
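The iterrows method above delegates to paginate_records, which presumably reduces to slice-based pagination over the dataframe's records. A minimal sketch under that assumption; the PaginatedResponse field names used here (data, has_more) are illustrative, not the verified schema:

```python
# Sketch of slice-based pagination, written with the new-style annotations
# this PR adopts (list[...], int | None). Field names are assumptions.
from typing import Any


def paginate_records_sketch(
    records: list[dict[str, Any]],
    start_index: int | None = None,
    limit: int | None = None,
) -> dict[str, Any]:
    start = start_index or 0
    end = len(records) if limit is None else start + limit
    return {
        "data": records[start:end],
        "has_more": end < len(records),
    }


rows = [{"id": i} for i in range(10)]
page = paginate_records_sketch(rows, start_index=4, limit=3)
assert [r["id"] for r in page["data"]] == [4, 5, 6] and page["has_more"]
```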
-from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import Api @@ -12,7 +12,7 @@ from .config import MetaReferenceEvalConfig async def get_provider_impl( config: MetaReferenceEvalConfig, - deps: Dict[Api, Any], + deps: dict[Api, Any], ): from .eval import MetaReferenceEvalImpl diff --git a/llama_stack/providers/inline/eval/meta_reference/config.py b/llama_stack/providers/inline/eval/meta_reference/config.py index 5b2bec259..2a4a29998 100644 --- a/llama_stack/providers/inline/eval/meta_reference/config.py +++ b/llama_stack/providers/inline/eval/meta_reference/config.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -17,7 +17,7 @@ class MetaReferenceEvalConfig(BaseModel): kvstore: KVStoreConfig @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { "kvstore": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 7c28f1bb7..12dde66b5 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import json -from typing import Any, Dict, List +from typing import Any from tqdm import tqdm @@ -105,8 +105,8 @@ class MetaReferenceEvalImpl( return Job(job_id=job_id, status=JobStatus.completed) async def _run_agent_generation( - self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig - ) -> List[Dict[str, Any]]: + self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig + ) -> list[dict[str, Any]]: candidate = benchmark_config.eval_candidate create_response = await self.agents_api.create_agent(candidate.config) agent_id = create_response.agent_id @@ -148,8 +148,8 @@ class MetaReferenceEvalImpl( return generations async def _run_model_generation( - self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig - ) -> List[Dict[str, Any]]: + self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig + ) -> list[dict[str, Any]]: candidate = benchmark_config.eval_candidate assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided" @@ -185,8 +185,8 @@ class MetaReferenceEvalImpl( async def evaluate_rows( self, benchmark_id: str, - input_rows: List[Dict[str, Any]], - scoring_functions: List[str], + input_rows: list[dict[str, Any]], + scoring_functions: list[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: candidate = benchmark_config.eval_candidate diff --git a/llama_stack/providers/inline/inference/meta_reference/__init__.py b/llama_stack/providers/inline/inference/meta_reference/__init__.py index 3710766e2..5eb822429 100644 --- a/llama_stack/providers/inline/inference/meta_reference/__init__.py +++ b/llama_stack/providers/inline/inference/meta_reference/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
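The datasetio and eval entry points above (and the inference one that follows) all share one factory shape: an async get_provider_impl that imports the heavy implementation lazily and builds it from its typed config plus a deps mapping of API implementations. A runnable miniature of the pattern; every name here is a stand-in, not a real class:

```python
import asyncio
from typing import Any


class ExampleConfig:  # stand-in for a provider's pydantic config
    persistence_store: str = "sqlite"


class ExampleImpl:    # stand-in for e.g. a *Impl class
    def __init__(self, config: ExampleConfig, deps: dict[str, Any]):
        self.config, self.deps = config, deps


async def get_provider_impl(config: ExampleConfig, deps: dict[str, Any]) -> ExampleImpl:
    # Real providers perform the implementation import *inside* this function
    # so that merely enumerating providers stays cheap.
    return ExampleImpl(config, deps)


impl = asyncio.run(get_provider_impl(ExampleConfig(), {"inference": object()}))
assert isinstance(impl, ExampleImpl)
```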
-from typing import Any, Dict +from typing import Any from .config import MetaReferenceInferenceConfig async def get_provider_impl( config: MetaReferenceInferenceConfig, - _deps: Dict[str, Any], + _deps: dict[str, Any], ): from .inference import MetaReferenceInferenceImpl diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 6f796d0d4..7bc961443 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, field_validator @@ -17,11 +17,11 @@ class MetaReferenceInferenceConfig(BaseModel): # the actual inference model id is determined by the model id in the request # Note: you need to register the model before using it for inference # models in the resource list in the run.yaml config will be registered automatically - model: Optional[str] = None - torch_seed: Optional[int] = None + model: str | None = None + torch_seed: int | None = None max_seq_len: int = 4096 max_batch_size: int = 1 - model_parallel_size: Optional[int] = None + model_parallel_size: int | None = None # when this is False, we assume that the distributed process group is set up by someone # outside of this code (e.g., when run inside `torchrun`). That is useful for clients @@ -30,9 +30,9 @@ class MetaReferenceInferenceConfig(BaseModel): # By default, the implementation will look at ~/.llama/checkpoints/ but you # can override by specifying the directory explicitly - checkpoint_dir: Optional[str] = None + checkpoint_dir: str | None = None - quantization: Optional[QuantizationConfig] = None + quantization: QuantizationConfig | None = None @field_validator("model") @classmethod @@ -55,7 +55,7 @@ class MetaReferenceInferenceConfig(BaseModel): max_batch_size: str = "${env.MAX_BATCH_SIZE:1}", max_seq_len: str = "${env.MAX_SEQ_LEN:4096}", **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return { "model": model, "checkpoint_dir": checkpoint_dir, diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py index 0a928ce73..cb926f529 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -5,7 +5,8 @@ # the root directory of this source tree.
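Note that sample_run_config above emits "${env.NAME:default}" placeholders rather than concrete values. A hedged sketch of how such templates can be expanded at stack-build time; the real resolver lives elsewhere in the codebase and may behave differently:

```python
# "${env.NAME:default}" -> os.environ.get("NAME", "default"), as a sketch.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}")


def resolve_env_template(value: str) -> str:
    return _ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)


assert resolve_env_template("${env.MAX_SEQ_LEN:4096}") == os.environ.get("MAX_SEQ_LEN", "4096")
```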
import math -from typing import Generator, List, Optional, Tuple +from collections.abc import Generator +from typing import Optional import torch from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData @@ -39,7 +40,7 @@ Tokenizer = Llama4Tokenizer | Llama3Tokenizer class LogitsProcessor: def __init__(self, token_enforcer: TokenEnforcer): self.token_enforcer = token_enforcer - self.mask: Optional[torch.Tensor] = None + self.mask: torch.Tensor | None = None def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: token_sequence = tokens[0, :].tolist() @@ -58,7 +59,7 @@ class LogitsProcessor: def get_logits_processor( tokenizer: Tokenizer, vocab_size: int, - response_format: Optional[ResponseFormat], + response_format: ResponseFormat | None, ) -> Optional["LogitsProcessor"]: if response_format is None: return None @@ -76,7 +77,7 @@ def get_logits_processor( return LogitsProcessor(token_enforcer) -def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]: +def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]: token_0 = tokenizer.encode("0", bos=False, eos=False)[-1] regular_tokens = [] @@ -158,7 +159,7 @@ class LlamaGenerator: def completion( self, - request_batch: List[CompletionRequestWithRawContent], + request_batch: list[CompletionRequestWithRawContent], ) -> Generator: first_request = request_batch[0] sampling_params = first_request.sampling_params or SamplingParams() @@ -167,7 +168,7 @@ class LlamaGenerator: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) - for result in self.inner_generator.generate( + yield from self.inner_generator.generate( llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch], max_gen_len=max_gen_len, temperature=temperature, @@ -179,12 +180,11 @@ class LlamaGenerator: self.args.vocab_size, first_request.response_format, ), - ): - yield result + ) def chat_completion( self, - request_batch: List[ChatCompletionRequestWithRawContent], + request_batch: list[ChatCompletionRequestWithRawContent], ) -> Generator: first_request = request_batch[0] sampling_params = first_request.sampling_params or SamplingParams() @@ -193,7 +193,7 @@ class LlamaGenerator: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) - for result in self.inner_generator.generate( + yield from self.inner_generator.generate( llm_inputs=[ self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)) for request in request_batch @@ -208,5 +208,4 @@ class LlamaGenerator: self.args.vocab_size, first_request.response_format, ), - ): - yield result + ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 1bc098fab..8dd594869 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -6,7 +6,7 @@ import asyncio import os -from typing import AsyncGenerator, List, Optional, Union +from collections.abc import AsyncGenerator from pydantic import BaseModel from termcolor import cprint @@ -184,11 +184,11 @@ class MetaReferenceInferenceImpl( self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: 
Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + ) -> CompletionResponse | CompletionResponseStreamChunk: if sampling_params is None: sampling_params = SamplingParams() if logprobs: @@ -215,11 +215,11 @@ class MetaReferenceInferenceImpl( async def batch_completion( self, model_id: str, - content_batch: List[InterleavedContent], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + content_batch: list[InterleavedContent], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> BatchCompletionResponse: if sampling_params is None: sampling_params = SamplingParams() @@ -291,14 +291,14 @@ class MetaReferenceInferenceImpl( for x in impl(): yield x - async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]: + async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]: tokenizer = self.generator.formatter.tokenizer first_request = request_batch[0] class ItemState(BaseModel): - tokens: List[int] = [] - logprobs: List[TokenLogProbs] = [] + tokens: list[int] = [] + logprobs: list[TokenLogProbs] = [] stop_reason: StopReason | None = None finished: bool = False @@ -349,15 +349,15 @@ class MetaReferenceInferenceImpl( async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -395,13 +395,13 @@ class MetaReferenceInferenceImpl( async def batch_chat_completion( self, model_id: str, - messages_batch: List[List[Message]], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages_batch: list[list[Message]], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> BatchChatCompletionResponse: if sampling_params is None: sampling_params = SamplingParams() @@ -436,15 +436,15 @@ class MetaReferenceInferenceImpl( return 
BatchChatCompletionResponse(batch=results) async def _nonstream_chat_completion( - self, request_batch: List[ChatCompletionRequest] - ) -> List[ChatCompletionResponse]: + self, request_batch: list[ChatCompletionRequest] + ) -> list[ChatCompletionResponse]: tokenizer = self.generator.formatter.tokenizer first_request = request_batch[0] class ItemState(BaseModel): - tokens: List[int] = [] - logprobs: List[TokenLogProbs] = [] + tokens: list[int] = [] + logprobs: list[TokenLogProbs] = [] stop_reason: StopReason | None = None finished: bool = False diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index 50640c6d1..9031d36b3 100644 --- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import Callable, Generator from copy import deepcopy from functools import partial -from typing import Any, Callable, Generator, List +from typing import Any from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat @@ -82,7 +83,7 @@ class LlamaModelParallelGenerator: def completion( self, - request_batch: List[CompletionRequestWithRawContent], + request_batch: list[CompletionRequestWithRawContent], ) -> Generator: req_obj = deepcopy(request_batch) gen = self.group.run_inference(("completion", req_obj)) @@ -90,7 +91,7 @@ class LlamaModelParallelGenerator: def chat_completion( self, - request_batch: List[ChatCompletionRequestWithRawContent], + request_batch: list[ChatCompletionRequestWithRawContent], ) -> Generator: req_obj = deepcopy(request_batch) gen = self.group.run_inference(("chat_completion", req_obj)) diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 8c0ffc632..97e96b929 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -18,8 +18,9 @@ import os import tempfile import time import uuid +from collections.abc import Callable, Generator from enum import Enum -from typing import Callable, Generator, List, Literal, Optional, Tuple, Union +from typing import Annotated, Literal import torch import zmq @@ -30,7 +31,6 @@ from fairscale.nn.model_parallel.initialize import ( ) from pydantic import BaseModel, Field from torch.distributed.launcher.api import LaunchConfig, elastic_launch -from typing_extensions import Annotated from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -69,15 +69,15 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: Tuple[ + task: tuple[ str, - List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent], + list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], ] class TaskResponse(BaseModel): type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response - result: List[GenerationResult] + result: list[GenerationResult] class 
ExceptionResponse(BaseModel): @@ -85,15 +85,9 @@ class ExceptionResponse(BaseModel): error: str -ProcessingMessage = Union[ - ReadyRequest, - ReadyResponse, - EndSentinel, - CancelSentinel, - TaskRequest, - TaskResponse, - ExceptionResponse, -] +ProcessingMessage = ( + ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse +) class ProcessingMessageWrapper(BaseModel): @@ -203,7 +197,7 @@ def maybe_get_work(sock: zmq.Socket): return client_id, message -def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]: +def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None: if maybe_json is None: return None try: @@ -334,9 +328,9 @@ class ModelParallelProcessGroup: def run_inference( self, - req: Tuple[ + req: tuple[ str, - List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent], + list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], ], ) -> Generator: assert not self.running, "inference already running" diff --git a/llama_stack/providers/inline/inference/sentence_transformers/__init__.py b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py index c1d65d10c..1719cbacc 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/__init__.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.providers.inline.inference.sentence_transformers.config import ( SentenceTransformersInferenceConfig, @@ -13,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import async def get_provider_impl( config: SentenceTransformersInferenceConfig, - _deps: Dict[str, Any], + _deps: dict[str, Any], ): from .sentence_transformers import SentenceTransformersInferenceImpl diff --git a/llama_stack/providers/inline/inference/sentence_transformers/config.py b/llama_stack/providers/inline/inference/sentence_transformers/config.py index 93e0afe11..b03010b10 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/config.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/config.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel class SentenceTransformersInferenceConfig(BaseModel): @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index d717d055f..7b36b0997 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
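The ProcessingMessage rewrite above leans on pydantic v2 accepting PEP 604 unions of models: each message carries a Literal type tag, and a wrapper model picks the concrete member when parsing JSON off the socket. A self-contained miniature of the same pattern; the names and wire format here are illustrative, not the real protocol:

```python
from typing import Literal

from pydantic import BaseModel


class TaskRequest(BaseModel):
    type: Literal["task_request"] = "task_request"
    task: str


class CancelSentinel(BaseModel):
    type: Literal["cancel_sentinel"] = "cancel_sentinel"


Message = TaskRequest | CancelSentinel  # PEP 604 union of models


class Wrapper(BaseModel):
    payload: Message


msg = Wrapper.model_validate_json('{"payload": {"type": "task_request", "task": "completion"}}')
assert isinstance(msg.payload, TaskRequest)  # the Literal tag selects the member
```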
import logging -from typing import AsyncGenerator, List, Optional, Union +from collections.abc import AsyncGenerator from llama_stack.apis.inference import ( CompletionResponse, @@ -60,46 +60,46 @@ class SentenceTransformersInferenceImpl( self, model_id: str, content: str, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, AsyncGenerator]: + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + ) -> CompletionResponse | AsyncGenerator: raise ValueError("Sentence transformers don't support completion") async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") async def batch_completion( self, model_id: str, - content_batch: List[InterleavedContent], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - logprobs: Optional[LogProbConfig] = None, + content_batch: list[InterleavedContent], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ): raise NotImplementedError("Batch completion is not supported for Sentence Transformers") async def batch_chat_completion( self, model_id: str, - messages_batch: List[List[Message]], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_config: Optional[ToolConfig] = None, - response_format: Optional[ResponseFormat] = None, - logprobs: Optional[LogProbConfig] = None, + messages_batch: list[list[Message]], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_config: ToolConfig | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ): raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers") diff --git a/llama_stack/providers/inline/inference/vllm/__init__.py b/llama_stack/providers/inline/inference/vllm/__init__.py index bd0551e57..d0ec3e084 100644 --- a/llama_stack/providers/inline/inference/vllm/__init__.py +++ b/llama_stack/providers/inline/inference/vllm/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
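For reference, the import churn repeated across these files comes down to two facts: since Python 3.9 the typing aliases for ABCs (AsyncGenerator, Callable, Iterator, ...) are deprecated in favour of collections.abc, and builtin generics plus `X | None` unions (PEP 585 / PEP 604) work at runtime only on Python 3.9/3.10 and later, which this change presumably adopts as the floor. A small demonstration:

```python
from collections.abc import Callable, Iterator  # new style

# Old style, shown only as a comparison; no longer imported by the diff:
#   from typing import Callable, Iterator, List, Optional


def take(it: Iterator[int], n: int) -> list[int]:        # PEP 585 builtin generic
    return [x for _, x in zip(range(n), it)]


def pick(f: Callable[[int], int] | None = None) -> int:  # PEP 604 optional
    return (f or (lambda x: x + 1))(41)


assert take(iter(range(10)), 3) == [0, 1, 2] and pick() == 42
```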
-from typing import Any, Dict +from typing import Any from .config import VLLMConfig -async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]): +async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]): from .vllm import VLLMInferenceImpl impl = VLLMInferenceImpl(config) diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py index 51d48e6d5..ce8743c74 100644 --- a/llama_stack/providers/inline/inference/vllm/config.py +++ b/llama_stack/providers/inline/inference/vllm/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel, Field @@ -42,7 +42,7 @@ class VLLMConfig(BaseModel): ) @classmethod - def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: return { "tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}", "max_tokens": "${env.MAX_TOKENS:4096}", diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py index d34f5ad5f..77cbf0403 100644 --- a/llama_stack/providers/inline/inference/vllm/openai_utils.py +++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional import vllm @@ -55,8 +54,8 @@ def _merge_context_into_content(message: Message) -> Message: # type: ignore def _llama_stack_tools_to_openai_tools( - tools: Optional[List[ToolDefinition]] = None, -) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]: + tools: list[ToolDefinition] | None = None, +) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]: """ Convert the list of available tools from Llama Stack's format to vLLM's version of OpenAI's format. diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 9d742c39c..438cb14a0 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -7,7 +7,7 @@ import json import re import uuid -from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator, AsyncIterator # These vLLM modules contain names that overlap with Llama Stack names, so we import # fully-qualified names @@ -100,7 +100,7 @@ def _random_uuid_str() -> str: def _response_format_to_guided_decoding_params( - response_format: Optional[ResponseFormat], # type: ignore + response_format: ResponseFormat | None, # type: ignore ) -> vllm.sampling_params.GuidedDecodingParams: """ Translate constrained decoding parameters from Llama Stack's format to vLLM's format. 
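A hedged sketch of the translation the function above performs, assuming a JSON-schema response format maps onto vLLM's guided decoding. The GuidedDecodingParams kwargs shown (json=...) are an assumption about this vLLM version's API surface, not a verified signature:

```python
import vllm.sampling_params


def to_guided_decoding_sketch(json_schema: dict | None) -> "vllm.sampling_params.GuidedDecodingParams":
    if json_schema is None:
        # No constraint requested: an empty params object leaves decoding unguided.
        return vllm.sampling_params.GuidedDecodingParams()
    return vllm.sampling_params.GuidedDecodingParams(json=json_schema)


params = to_guided_decoding_sketch({"type": "object"})  # illustrative schema
```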
@@ -131,9 +131,9 @@ def _response_format_to_guided_decoding_params( def _convert_sampling_params( - sampling_params: Optional[SamplingParams], - response_format: Optional[ResponseFormat], # type: ignore - log_prob_config: Optional[LogProbConfig], + sampling_params: SamplingParams | None, + response_format: ResponseFormat | None, # type: ignore + log_prob_config: LogProbConfig | None, ) -> vllm.SamplingParams: """Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's format.""" @@ -370,11 +370,11 @@ class VLLMInferenceImpl( self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]: if model_id not in self.model_ids: raise ValueError( f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}" @@ -403,25 +403,25 @@ class VLLMInferenceImpl( async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() async def chat_completion( self, model_id: str, - messages: List[Message], # type: ignore - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, # type: ignore - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], # type: ignore + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, # type: ignore + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk: sampling_params = sampling_params or SamplingParams() if model_id not in self.model_ids: @@ -605,7 +605,7 @@ class VLLMInferenceImpl( async def _chat_completion_for_meta_llama( self, request: ChatCompletionRequest - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: """ Subroutine that routes chat completions for Meta Llama models through Llama Stack's chat template instead of using vLLM's version of that template. The Llama Stack version @@ -701,7 +701,7 @@ class VLLMInferenceImpl( # Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up # those chunks and output them at the end. 
# This data structure holds the current set of partial tool calls. - index_to_tool_call: Dict[int, Dict] = dict() + index_to_tool_call: dict[int, dict] = dict() # The Llama Stack event stream must always start with a start event. Use an empty one to # simplify logic below diff --git a/llama_stack/providers/inline/post_training/torchtune/__init__.py b/llama_stack/providers/inline/post_training/torchtune/__init__.py index ca7801be7..7a2f9eba2 100644 --- a/llama_stack/providers/inline/post_training/torchtune/__init__.py +++ b/llama_stack/providers/inline/post_training/torchtune/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import Api @@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig async def get_provider_impl( config: TorchtunePostTrainingConfig, - deps: Dict[Api, Any], + deps: dict[Api, Any], ): from .post_training import TorchtunePostTrainingImpl diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py index fcadd0884..af8bd2765 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py @@ -8,7 +8,7 @@ import json import os import shutil from pathlib import Path -from typing import Any, Dict, List +from typing import Any import torch from safetensors.torch import save_file @@ -34,7 +34,7 @@ class TorchtuneCheckpointer: model_id: str, training_algorithm: str, checkpoint_dir: str, - checkpoint_files: List[str], + checkpoint_files: list[str], output_dir: str, model_type: str, ): @@ -54,11 +54,11 @@ class TorchtuneCheckpointer: # get ckpt paths self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file) - def load_checkpoint(self) -> Dict[str, Any]: + def load_checkpoint(self) -> dict[str, Any]: """ Load Meta checkpoint from file. Currently only loading from a single file is supported. 
""" - state_dict: Dict[str, Any] = {} + state_dict: dict[str, Any] = {} model_state_dict = safe_torch_load(self._checkpoint_path) if self._model_type == ModelType.LLAMA3_VISION: from torchtune.models.llama3_2_vision._convert_weights import ( @@ -82,7 +82,7 @@ class TorchtuneCheckpointer: def save_checkpoint( self, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], epoch: int, adapter_only: bool = False, checkpoint_format: str | None = None, @@ -100,7 +100,7 @@ class TorchtuneCheckpointer: def _save_meta_format_checkpoint( self, model_file_path: Path, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], adapter_only: bool = False, ) -> None: model_file_path.mkdir(parents=True, exist_ok=True) @@ -168,7 +168,7 @@ class TorchtuneCheckpointer: def _save_hf_format_checkpoint( self, model_file_path: Path, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], ) -> None: # the config.json file contains model params needed for state dict conversion config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text()) @@ -179,7 +179,7 @@ class TorchtuneCheckpointer: repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json") self.repo_id = None if repo_id_path.exists(): - with open(repo_id_path, "r") as json_file: + with open(repo_id_path) as json_file: data = json.load(json_file) self.repo_id = data.get("repo_id") diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index a040ca1b0..f0fa052a2 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -10,7 +10,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Callable, Dict +from collections.abc import Callable import torch from pydantic import BaseModel @@ -35,7 +35,7 @@ class ModelConfig(BaseModel): checkpoint_type: str -MODEL_CONFIGS: Dict[str, ModelConfig] = { +MODEL_CONFIGS: dict[str, ModelConfig] = { "Llama3.2-3B-Instruct": ModelConfig( model_definition=lora_llama3_2_3b, tokenizer_type=llama3_tokenizer, @@ -48,7 +48,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = { ), } -DATA_FORMATS: Dict[str, Transform] = { +DATA_FORMATS: dict[str, Transform] = { "instruct": InputOutputToMessages, "dialog": ShareGPTToMessages, } diff --git a/llama_stack/providers/inline/post_training/torchtune/config.py b/llama_stack/providers/inline/post_training/torchtune/config.py index ee3504f9e..f3ce874aa 100644 --- a/llama_stack/providers/inline/post_training/torchtune/config.py +++ b/llama_stack/providers/inline/post_training/torchtune/config.py @@ -4,17 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel class TorchtunePostTrainingConfig(BaseModel): - torch_seed: Optional[int] = None - checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta" + torch_seed: int | None = None + checkpoint_format: Literal["meta", "huggingface"] | None = "meta" @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { "checkpoint_format": "meta", } diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py index 6b607f1c7..96dd8b8dd 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py @@ -11,7 +11,8 @@ # LICENSE file in the root directory of this source tree. import json -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from llama_stack.providers.utils.common.data_schema_validator import ColumnName diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py index 050996860..ae7faf31e 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py @@ -10,7 +10,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, List, Mapping +from collections.abc import Mapping +from typing import Any import numpy as np from torch.utils.data import Dataset @@ -27,7 +28,7 @@ from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapte class SFTDataset(Dataset): def __init__( self, - rows: List[Dict[str, Any]], + rows: list[dict[str, Any]], message_transform: Transform, model_transform: Transform, dataset_type: str, @@ -40,11 +41,11 @@ class SFTDataset(Dataset): def __len__(self): return len(self._rows) - def __getitem__(self, index: int) -> Dict[str, Any]: + def __getitem__(self, index: int) -> dict[str, Any]: sample = self._rows[index] return self._prepare_sample(sample) - def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]: + def _prepare_sample(self, sample: Mapping[str, Any]) -> dict[str, Any]: if self._dataset_type == "instruct": sample = llama_stack_instruct_to_torchtune_instruct(sample) elif self._dataset_type == "dialog": diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index cc1a6a5fe..c7d8d6758 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
from enum import Enum -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl: ) @staticmethod - def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact: + def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact: return JobArtifact( type=TrainingArtifactType.RESOURCES_STATS.value, name=TrainingArtifactType.RESOURCES_STATS.value, @@ -75,11 +75,11 @@ class TorchtunePostTrainingImpl: self, job_uuid: str, training_config: TrainingConfig, - hyperparam_search_config: Dict[str, Any], - logger_config: Dict[str, Any], + hyperparam_search_config: dict[str, Any], + logger_config: dict[str, Any], model: str, - checkpoint_dir: Optional[str], - algorithm_config: Optional[AlgorithmConfig], + checkpoint_dir: str | None, + algorithm_config: AlgorithmConfig | None, ) -> PostTrainingJob: if isinstance(algorithm_config, LoraFinetuningConfig): @@ -121,8 +121,8 @@ class TorchtunePostTrainingImpl: finetuned_model: str, algorithm_config: DPOAlignmentConfig, training_config: TrainingConfig, - hyperparam_search_config: Dict[str, Any], - logger_config: Dict[str, Any], + hyperparam_search_config: dict[str, Any], + logger_config: dict[str, Any], ) -> PostTrainingJob: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse: @@ -144,7 +144,7 @@ class TorchtunePostTrainingImpl: return data[0] if data else None @webmethod(route="/post-training/job/status") - async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: + async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None: job = self._scheduler.get_job(job_uuid) match job.status: @@ -175,6 +175,6 @@ class TorchtunePostTrainingImpl: self._scheduler.cancel(job_uuid) @webmethod(route="/post-training/job/artifacts") - async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: + async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None: job = self._scheduler.get_job(job_uuid) return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job)) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 5cf15824d..1239523cd 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -11,7 +11,7 @@ import time from datetime import datetime, timezone from functools import partial from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import torch from torch import nn @@ -80,10 +80,10 @@ class LoraFinetuningSingleDevice: config: TorchtunePostTrainingConfig, job_uuid: str, training_config: TrainingConfig, - hyperparam_search_config: Dict[str, Any], - logger_config: Dict[str, Any], + hyperparam_search_config: dict[str, Any], + logger_config: dict[str, Any], model: str, - checkpoint_dir: Optional[str], + checkpoint_dir: str | None, algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None, datasetio_api: DatasetIO, datasets_api: Datasets, @@ -156,7 +156,7 @@ class LoraFinetuningSingleDevice: self.datasets_api = datasets_api 
async def load_checkpoint(self): - def get_checkpoint_files(checkpoint_dir: str) -> List[str]: + def get_checkpoint_files(checkpoint_dir: str) -> list[str]: try: # List all files in the given directory files = os.listdir(checkpoint_dir) @@ -250,8 +250,8 @@ class LoraFinetuningSingleDevice: self, enable_activation_checkpointing: bool, enable_activation_offloading: bool, - base_model_state_dict: Dict[str, Any], - lora_weights_state_dict: Optional[Dict[str, Any]] = None, + base_model_state_dict: dict[str, Any], + lora_weights_state_dict: dict[str, Any] | None = None, ) -> nn.Module: self._lora_rank = self.algorithm_config.rank self._lora_alpha = self.algorithm_config.alpha @@ -335,7 +335,7 @@ class LoraFinetuningSingleDevice: tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int, - ) -> Tuple[DistributedSampler, DataLoader]: + ) -> tuple[DistributedSampler, DataLoader]: async def fetch_rows(dataset_id: str): return await self.datasetio_api.iterrows( dataset_id=dataset_id, @@ -430,7 +430,7 @@ class LoraFinetuningSingleDevice: checkpoint_format=self._checkpoint_format, ) - async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") # run model @@ -452,7 +452,7 @@ class LoraFinetuningSingleDevice: return loss - async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]: + async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]: """ The core training loop. """ @@ -464,7 +464,7 @@ class LoraFinetuningSingleDevice: # training artifacts checkpoints = [] - memory_stats: Dict[str, Any] = {} + memory_stats: dict[str, Any] = {} # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -565,7 +565,7 @@ class LoraFinetuningSingleDevice: return (memory_stats, checkpoints) - async def validation(self) -> Tuple[float, float]: + async def validation(self) -> tuple[float, float]: total_loss = 0.0 total_tokens = 0 log.info("Starting validation...") diff --git a/llama_stack/providers/inline/safety/code_scanner/__init__.py b/llama_stack/providers/inline/safety/code_scanner/__init__.py index 62975a963..68e32b747 100644 --- a/llama_stack/providers/inline/safety/code_scanner/__init__.py +++ b/llama_stack/providers/inline/safety/code_scanner/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from .config import CodeScannerConfig -async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]): +async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]): from .code_scanner import MetaReferenceCodeScannerSafetyImpl impl = MetaReferenceCodeScannerSafetyImpl(config, deps) diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 606d11d2c..be05ee436 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import logging -from typing import Any, Dict, List +from typing import Any from llama_stack.apis.inference import Message from llama_stack.apis.safety import ( @@ -48,8 +48,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): async def run_shield( self, shield_id: str, - messages: List[Message], - params: Dict[str, Any] = None, + messages: list[Message], + params: dict[str, Any] | None = None, ) -> RunShieldResponse: shield = await self.shield_store.get_shield(shield_id) if not shield: diff --git a/llama_stack/providers/inline/safety/code_scanner/config.py b/llama_stack/providers/inline/safety/code_scanner/config.py index 1d880ee9c..66eb8e368 100644 --- a/llama_stack/providers/inline/safety/code_scanner/config.py +++ b/llama_stack/providers/inline/safety/code_scanner/config.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel class CodeScannerConfig(BaseModel): @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/inline/safety/llama_guard/__init__.py b/llama_stack/providers/inline/safety/llama_guard/__init__.py index a4263b169..8865cc344 100644 --- a/llama_stack/providers/inline/safety/llama_guard/__init__.py +++ b/llama_stack/providers/inline/safety/llama_guard/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from .config import LlamaGuardConfig -async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]): +async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]): from .llama_guard import LlamaGuardSafetyImpl assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}" diff --git a/llama_stack/providers/inline/safety/llama_guard/config.py b/llama_stack/providers/inline/safety/llama_guard/config.py index 53849ab33..412e7218d 100644 --- a/llama_stack/providers/inline/safety/llama_guard/config.py +++ b/llama_stack/providers/inline/safety/llama_guard/config.py @@ -4,16 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree.
-from typing import Any, Dict, List +from typing import Any from pydantic import BaseModel class LlamaGuardConfig(BaseModel): - excluded_categories: List[str] = [] + excluded_categories: list[str] = [] @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { "excluded_categories": [], } diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 2ab16f986..937301c2e 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -6,7 +6,7 @@ import re from string import Template -from typing import Any, Dict, List, Optional +from typing import Any from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.inference import ( @@ -149,8 +149,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): async def run_shield( self, shield_id: str, - messages: List[Message], - params: Dict[str, Any] = None, + messages: list[Message], + params: dict[str, Any] | None = None, ) -> RunShieldResponse: shield = await self.shield_store.get_shield(shield_id) if not shield: @@ -177,7 +177,7 @@ class LlamaGuardShield: self, model: str, inference_api: Inference, - excluded_categories: Optional[List[str]] = None, + excluded_categories: list[str] | None = None, ): if excluded_categories is None: excluded_categories = [] @@ -193,7 +193,7 @@ class LlamaGuardShield: self.inference_api = inference_api self.excluded_categories = excluded_categories - def check_unsafe_response(self, response: str) -> Optional[str]: + def check_unsafe_response(self, response: str) -> str | None: match = re.match(r"^unsafe\n(.*)$", response) if match: # extracts the unsafe code @@ -202,7 +202,7 @@ class LlamaGuardShield: return None - def get_safety_categories(self) -> List[str]: + def get_safety_categories(self) -> list[str]: excluded_categories = self.excluded_categories if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()): excluded_categories = [] @@ -218,7 +218,7 @@ class LlamaGuardShield: return final_categories - def validate_messages(self, messages: List[Message]) -> None: + def validate_messages(self, messages: list[Message]) -> list[Message]: if len(messages) == 0: raise ValueError("Messages must not be empty") if messages[0].role != Role.user.value: @@ -229,7 +229,7 @@ class LlamaGuardShield: return messages - async def run(self, messages: List[Message]) -> RunShieldResponse: + async def run(self, messages: list[Message]) -> RunShieldResponse: messages = self.validate_messages(messages) if self.model == CoreModelId.llama_guard_3_11b_vision.value: @@ -247,10 +247,10 @@ class LlamaGuardShield: content = content.strip() return self.get_shield_response(content) - def build_text_shield_input(self, messages: List[Message]) -> UserMessage: + def build_text_shield_input(self, messages: list[Message]) -> UserMessage: return UserMessage(content=self.build_prompt(messages)) - def build_vision_shield_input(self, messages: List[Message]) -> UserMessage: + def build_vision_shield_input(self, messages: list[Message]) -> UserMessage: conversation = [] most_recent_img = None @@ -284,7 +284,7 @@ class LlamaGuardShield: return UserMessage(content=prompt) - def build_prompt(self, messages: List[Message]) -> str: + def build_prompt(self, messages: list[Message]) -> str: categories =
self.get_safety_categories() categories_str = "\n".join(categories) conversations_str = "\n\n".join( diff --git a/llama_stack/providers/inline/safety/prompt_guard/__init__.py b/llama_stack/providers/inline/safety/prompt_guard/__init__.py index 747f34421..1761c9138 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/__init__.py +++ b/llama_stack/providers/inline/safety/prompt_guard/__init__.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any -from .config import PromptGuardConfig  # noqa: F401 +from .config import PromptGuardConfig -async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]): +async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]): from .prompt_guard import PromptGuardSafetyImpl impl = PromptGuardSafetyImpl(config, deps) diff --git a/llama_stack/providers/inline/safety/prompt_guard/config.py b/llama_stack/providers/inline/safety/prompt_guard/config.py index 76bd5978d..69ea512c5 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/config.py +++ b/llama_stack/providers/inline/safety/prompt_guard/config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict +from typing import Any from pydantic import BaseModel, field_validator @@ -26,7 +26,7 @@ class PromptGuardConfig(BaseModel): return v @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { "guard_type": "injection", } diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index fce3e3d14..56ce8285f 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import Any, Dict, List +from typing import Any import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer @@ -49,8 +49,8 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): async def run_shield( self, shield_id: str, - messages: List[Message], - params: Dict[str, Any] = None, + messages: list[Message], + params: dict[str, Any] | None = None, ) -> RunShieldResponse: shield = await self.shield_store.get_shield(shield_id) if not shield: @@ -81,7 +81,7 @@ class PromptGuardShield: self.tokenizer = AutoTokenizer.from_pretrained(model_dir) self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device) - async def run(self, messages: List[Message]) -> RunShieldResponse: + async def run(self, messages: list[Message]) -> RunShieldResponse: message = messages[-1] text = interleaved_content_as_str(message.content) diff --git a/llama_stack/providers/inline/scoring/basic/__init__.py b/llama_stack/providers/inline/scoring/basic/__init__.py index 4898b973a..d9d150b1a 100644 --- a/llama_stack/providers/inline/scoring/basic/__init__.py +++ b/llama_stack/providers/inline/scoring/basic/__init__.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree.
-from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import Api @@ -12,7 +12,7 @@ from .config import BasicScoringConfig async def get_provider_impl( config: BasicScoringConfig, - deps: Dict[Api, Any], + deps: dict[Api, Any], ): from .scoring import BasicScoringImpl diff --git a/llama_stack/providers/inline/scoring/basic/config.py b/llama_stack/providers/inline/scoring/basic/config.py index 5866be359..e9c7fb451 100644 --- a/llama_stack/providers/inline/scoring/basic/config.py +++ b/llama_stack/providers/inline/scoring/basic/config.py @@ -3,12 +3,12 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel class BasicScoringConfig(BaseModel): @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 9a45f7139..09f89be5e 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional +from typing import Any from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -66,7 +66,7 @@ class BasicScoringImpl( async def shutdown(self) -> None: ... - async def list_scoring_functions(self) -> List[ScoringFn]: + async def list_scoring_functions(self) -> list[ScoringFn]: scoring_fn_defs_list = [ fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs() ] @@ -82,7 +82,7 @@ class BasicScoringImpl( async def score_batch( self, dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + scoring_functions: dict[str, ScoringFnParams | None] | None = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) @@ -107,8 +107,8 @@ class BasicScoringImpl( async def score( self, - input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + input_rows: list[dict[str, Any]], + scoring_functions: dict[str, ScoringFnParams | None] | None = None, ) -> ScoreResponse: res = {} for scoring_fn_id in scoring_functions.keys(): diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py index f37780f3e..b29620be2 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/bfcl_scoring_fn.py @@ -6,7 +6,7 @@ import json import re -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams @@ -17,7 +17,7 @@ from ..utils.bfcl.checker import ast_checker, is_empty_output from .fn_defs.bfcl import bfcl -def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]: +def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]: contain_func_call = False error = None error_type = None @@ -52,11 +52,11
@@ def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]: } -def gen_valid(x: Dict[str, Any]) -> Dict[str, float]: +def gen_valid(x: dict[str, Any]) -> dict[str, float]: return {"valid": x["valid"]} -def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]: +def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]: # This function serves both relevance and irrelevance tests, which follow exactly opposite logic. # If `test_category` is "irrelevance", the model is expected to output no function call. # No function call means either the AST decoding fails (an error message is generated) or the decoded AST does not contain any function call (such as an empty list, `[]`). @@ -78,9 +78,9 @@ class BFCLScoringFn(RegisteredBaseScoringFn): async def score_row( self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = "bfcl", - scoring_params: Optional[ScoringFnParams] = None, + input_row: dict[str, Any], + scoring_fn_identifier: str | None = "bfcl", + scoring_params: ScoringFnParams | None = None, ) -> ScoringResultRow: test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"]) score_result = postprocess(input_row, test_category) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py index 84ca55732..b87974d08 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py @@ -6,7 +6,7 @@ import json import re -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams @@ -228,9 +228,9 @@ class DocVQAScoringFn(RegisteredBaseScoringFn): async def score_row( self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = "docvqa", - scoring_params: Optional[ScoringFnParams] = None, + input_row: dict[str, Any], + scoring_fn_identifier: str | None = "docvqa", + scoring_params: ScoringFnParams | None = None, ) -> ScoringResultRow: expected_answers = json.loads(input_row["expected_answer"]) generated_answer = input_row["generated_answer"] diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py index 0bd6bdd48..60804330f 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams @@ -26,9 +26,9 @@ class EqualityScoringFn(RegisteredBaseScoringFn): async def score_row( self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = "equality", - scoring_params: Optional[ScoringFnParams] = None, + input_row: dict[str, Any], + scoring_fn_identifier: str | None = "equality", + scoring_params: ScoringFnParams | None = None, ) -> ScoringResultRow: assert "expected_answer" in input_row, "Expected answer not found in input row." assert "generated_answer" in input_row, "Generated answer not found in input row."
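# A self-contained sketch of the exact-match rule the equality scorer above
# applies; the plain dict stands in for ScoringResultRow, and only the "score"
# field is shown for illustration.
def equality_score_sketch(input_row: dict) -> dict:
    # Score 1.0 only when the generated answer matches the reference exactly.
    expected = input_row["expected_answer"]
    generated = input_row["generated_answer"]
    return {"score": 1.0 if generated == expected else 0.0}

# equality_score_sketch({"expected_answer": "42", "generated_answer": "42"})
# -> {"score": 1.0}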
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py index 6ff856684..77f6176e6 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams @@ -28,9 +28,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn): async def score_row( self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, + input_row: dict[str, Any], + scoring_fn_identifier: str | None = None, + scoring_params: ScoringFnParams | None = None, ) -> ScoringResultRow: from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py index d6c78a9ac..d765959a8 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType @@ -28,9 +28,9 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn): async def score_row( self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, + input_row: dict[str, Any], + scoring_fn_identifier: str | None = None, + scoring_params: ScoringFnParams | None = None, ) -> ScoringResultRow: assert scoring_fn_identifier is not None, "Scoring function identifier not found." fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index 0606a9581..cb336e303 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
import re -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType @@ -28,9 +28,9 @@ class RegexParserScoringFn(RegisteredBaseScoringFn): async def score_row( self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, + input_row: dict[str, Any], + scoring_fn_identifier: str | None = None, + scoring_params: ScoringFnParams | None = None, ) -> ScoringResultRow: assert scoring_fn_identifier is not None, "Scoring function identifier not found." fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py index 71defc433..d6e10e6c9 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams @@ -26,9 +26,9 @@ class SubsetOfScoringFn(RegisteredBaseScoringFn): async def score_row( self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = "subset_of", - scoring_params: Optional[ScoringFnParams] = None, + input_row: dict[str, Any], + scoring_fn_identifier: str | None = "subset_of", + scoring_params: ScoringFnParams | None = None, ) -> ScoringResultRow: expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] diff --git a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py index 28605159f..b74c3826e 100644 --- a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +++ b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py @@ -11,8 +11,8 @@ import logging import random import re import string +from collections.abc import Iterable, Sequence from types import MappingProxyType -from typing import Dict, Iterable, List, Optional, Sequence, Union import emoji import langdetect @@ -1673,12 +1673,11 @@ def split_chinese_japanese_hindi(lines: str) -> Iterable[str]: The separator for hindi is '।' """ for line in lines.splitlines(): - for sent in re.findall( + yield from re.findall( r"[^!?。\.\!\?\!\?\.\n।]+[!?。\.\!\?\!\?\.\n।]?", line.strip(), flags=re.U, - ): - yield sent + ) def count_words_cjk(text: str) -> int: @@ -1707,7 +1706,7 @@ def count_words_cjk(text: str) -> int: return non_asian_words_cnt + asian_chars_cnt + emoji_cnt -@functools.lru_cache(maxsize=None) +@functools.cache def _get_sentence_tokenizer(): return nltk.data.load("nltk:tokenizers/punkt/english.pickle") @@ -1719,8 +1718,8 @@ def count_sentences(text): return len(tokenized_sentences) -def get_langid(text: str, lid_path: Optional[str] = None) -> str: - line_langs: List[str] = [] +def get_langid(text: str, lid_path: str | None = None) -> str: + line_langs: list[str] = [] lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4] for line in lines: @@ -1741,7 +1740,7 @@ def generate_keywords(num_keywords): """Library of instructions""" 
-_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]] +_InstructionArgsDtype = dict[str, int | str | Sequence[str]] | None _LANGUAGES = LANGUAGE_CODES diff --git a/llama_stack/providers/inline/scoring/basic/utils/math_utils.py b/llama_stack/providers/inline/scoring/basic/utils/math_utils.py index e11fc625b..6840aad14 100644 --- a/llama_stack/providers/inline/scoring/basic/utils/math_utils.py +++ b/llama_stack/providers/inline/scoring/basic/utils/math_utils.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import re -from typing import Sequence +from collections.abc import Sequence from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit @@ -323,7 +323,7 @@ def _fix_a_slash_b(string: str) -> str: try: ia = int(a) ib = int(b) - assert string == "{}/{}".format(ia, ib) + assert string == f"{ia}/{ib}" new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}" return new_string except (ValueError, AssertionError): diff --git a/llama_stack/providers/inline/scoring/braintrust/__init__.py b/llama_stack/providers/inline/scoring/braintrust/__init__.py index f1b0112d9..8ea6e9b96 100644 --- a/llama_stack/providers/inline/scoring/braintrust/__init__.py +++ b/llama_stack/providers/inline/scoring/braintrust/__init__.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel): async def get_provider_impl( config: BraintrustScoringConfig, - deps: Dict[Api, Any], + deps: dict[Api, Any], ): from .braintrust import BraintrustScoringImpl diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index 3fae83340..d6655d657 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import os -from typing import Any, Dict, List, Optional +from typing import Any from autoevals.llm import Factuality from autoevals.ragas import ( @@ -132,7 +132,7 @@ class BraintrustScoringImpl( async def shutdown(self) -> None: ... 
- async def list_scoring_functions(self) -> List[ScoringFn]: + async def list_scoring_functions(self) -> list[ScoringFn]: scoring_fn_defs_list = list(self.supported_fn_defs_registry.values()) for f in scoring_fn_defs_list: assert f.identifier.startswith("braintrust"), ( @@ -159,7 +159,7 @@ class BraintrustScoringImpl( async def score_batch( self, dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]], + scoring_functions: dict[str, ScoringFnParams | None], save_results_dataset: bool = False, ) -> ScoreBatchResponse: await self.set_api_key() @@ -181,9 +181,7 @@ class BraintrustScoringImpl( results=res.results, ) - async def score_row( - self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None - ) -> ScoringResultRow: + async def score_row(self, input_row: dict[str, Any], scoring_fn_identifier: str | None = None) -> ScoringResultRow: validate_row_schema(input_row, get_valid_schemas(Api.scoring.value)) await self.set_api_key() assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" @@ -203,8 +201,8 @@ class BraintrustScoringImpl( async def score( self, - input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]], + input_rows: list[dict[str, Any]], + scoring_functions: dict[str, ScoringFnParams | None], ) -> ScoreResponse: await self.set_api_key() res = {} diff --git a/llama_stack/providers/inline/scoring/braintrust/config.py b/llama_stack/providers/inline/scoring/braintrust/config.py index d4e0d9bcd..4a80f1e4f 100644 --- a/llama_stack/providers/inline/scoring/braintrust/config.py +++ b/llama_stack/providers/inline/scoring/braintrust/config.py @@ -3,19 +3,19 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field class BraintrustScoringConfig(BaseModel): - openai_api_key: Optional[str] = Field( + openai_api_key: str | None = Field( default=None, description="The OpenAI API Key", ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs) -> dict[str, Any]: return { "openai_api_key": "${env.OPENAI_API_KEY:}", } diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py index 4a83bfe13..88bf10737 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import Api @@ -12,7 +12,7 @@ from .config import LlmAsJudgeScoringConfig async def get_provider_impl( config: LlmAsJudgeScoringConfig, - deps: Dict[Api, Any], + deps: dict[Api, Any], ): from .scoring import LlmAsJudgeScoringImpl diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/config.py b/llama_stack/providers/inline/scoring/llm_as_judge/config.py index ff63fc5e7..b150ef54c 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/config.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/config.py @@ -3,12 +3,12 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict +from typing import Any from pydantic import BaseModel class LlmAsJudgeScoringConfig(BaseModel): @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index 7f004fbb6..b705cb9b3 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional +from typing import Any from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -50,7 +50,7 @@ class LlmAsJudgeScoringImpl( async def shutdown(self) -> None: ... - async def list_scoring_functions(self) -> List[ScoringFn]: + async def list_scoring_functions(self) -> list[ScoringFn]: scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs() for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs(): @@ -66,7 +66,7 @@ class LlmAsJudgeScoringImpl( async def score_batch( self, dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + scoring_functions: dict[str, ScoringFnParams | None] | None = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) @@ -91,8 +91,8 @@ class LlmAsJudgeScoringImpl( async def score( self, - input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + input_rows: list[dict[str, Any]], + scoring_functions: dict[str, ScoringFnParams | None] | None = None, ) -> ScoreResponse: res = {} for scoring_fn_id in scoring_functions.keys(): diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index f4e8ab0aa..51cdf6c3f 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import re -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.inference.inference import Inference, UserMessage from llama_stack.apis.scoring import ScoringResultRow @@ -30,9 +30,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn): async def score_row( self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, + input_row: dict[str, Any], + scoring_fn_identifier: str | None = None, + scoring_params: ScoringFnParams | None = None, ) -> ScoringResultRow: assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py index 23468c5d0..09e97136a 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import Api @@ -13,7 +13,7 @@ from .config import TelemetryConfig, TelemetrySink __all__ = ["TelemetryConfig", "TelemetrySink"] -async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]): +async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]): from .telemetry import TelemetryAdapter impl = TelemetryAdapter(config, deps) diff --git a/llama_stack/providers/inline/telemetry/meta_reference/config.py b/llama_stack/providers/inline/telemetry/meta_reference/config.py index 54bdc083c..af53bfd9c 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/config.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List +from typing import Any from pydantic import BaseModel, Field, field_validator @@ -33,7 +33,7 @@ class TelemetryConfig(BaseModel): default="", description="The service name to use for telemetry", ) - sinks: List[TelemetrySink] = Field( + sinks: list[TelemetrySink] = Field( default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE], description="List of telemetry sinks to enable (possible values: otel, sqlite, console)", ) @@ -50,7 +50,7 @@ class TelemetryConfig(BaseModel): return v @classmethod - def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]: return { "service_name": "${env.OTEL_SERVICE_NAME:}", "sinks": "${env.TELEMETRY_SINKS:console,sqlite}", diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py index b909d32ef..ff1914c15 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py @@ -78,7 +78,7 @@ class ConsoleSpanProcessor(SpanProcessor): severity = event.attributes.get("severity", "info") message = event.attributes.get("message", event.name) - if isinstance(message, (dict, list)): + if isinstance(message, dict | list): message = json.dumps(message, indent=2) severity_colors = { diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 9b23c8229..9295d5cab 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import threading -from typing import Any, Dict, List, Optional +from typing import Any from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter @@ -60,7 +60,7 @@ def is_tracing_enabled(tracer): class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): - def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None: + def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None: self.config = config self.datasetio_api = deps.get(Api.datasetio) self.meter = None @@ -231,10 +231,10 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): async def query_traces( self, - attribute_filters: Optional[List[QueryCondition]] = None, - limit: Optional[int] = 100, - offset: Optional[int] = 0, - order_by: Optional[List[str]] = None, + attribute_filters: list[QueryCondition] | None = None, + limit: int | None = 100, + offset: int | None = 0, + order_by: list[str] | None = None, ) -> QueryTracesResponse: return QueryTracesResponse( data=await self.trace_store.query_traces( @@ -254,8 +254,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): async def get_span_tree( self, span_id: str, - attributes_to_return: Optional[List[str]] = None, - max_depth: Optional[int] = None, + attributes_to_return: list[str] | None = None, + max_depth: int | None = None, ) -> QuerySpanTreeResponse: return QuerySpanTreeResponse( data=await self.trace_store.get_span_tree( diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py index 8317ce793..d91005c6c 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from .config import CodeInterpreterToolConfig __all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"] -async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]): +async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: dict[str, Any]): from .code_interpreter import CodeInterpreterToolRuntimeImpl impl = CodeInterpreterToolRuntimeImpl(config) diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py index 6106cf741..6c9765b55 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py @@ -18,7 +18,6 @@ from dataclasses import dataclass from datetime import datetime from io import BytesIO from pathlib import Path -from typing import List from PIL import Image @@ -45,7 +44,7 @@ except: """ -def generate_bwrap_command(bind_dirs: List[str]) -> str: +def generate_bwrap_command(bind_dirs: list[str]) -> str: """ Generate the bwrap command string for binding all directories in the current directory read-only. 
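# A rough sketch of the kind of command string generate_bwrap_command builds,
# assuming each directory is bind-mounted read-only onto its own path
# (--ro-bind is a standard bwrap flag; the real helper may add further
# isolation flags).
def bwrap_command_sketch(bind_dirs: list[str]) -> str:
    parts = ["bwrap"]
    for d in bind_dirs:
        # Mount the directory read-only at the same path inside the sandbox.
        parts.extend(["--ro-bind", d, d])
    return " ".join(parts)

# bwrap_command_sketch(["/data"]) -> "bwrap --ro-bind /data /data"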
@@ -71,7 +70,7 @@ class CodeExecutionContext: @dataclass class CodeExecutionRequest: - scripts: List[str] + scripts: list[str] only_last_cell_stdouterr: bool = True only_last_cell_fail: bool = True seed: int = 0 diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 10ac2fcc6..041104040 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -9,7 +9,7 @@ import asyncio import logging import os import tempfile -from typing import Any, Dict, Optional +from typing import Any from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -46,7 +46,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): return async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[ @@ -64,7 +64,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): ] ) - async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: script = kwargs["code"] # Use environment variable to control bwrap usage force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes") diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py index 7de1ec453..caf51d573 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel class CodeInterpreterToolConfig(BaseModel): @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py index d6f539a39..fabddbf0b 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py @@ -15,7 +15,7 @@ def get_code_env_prefix() -> str: global CODE_ENV_PREFIX if CODE_ENV_PREFIX is None: - with open(CODE_ENV_PREFIX_FILE, "r") as f: + with open(CODE_ENV_PREFIX_FILE) as f: CODE_ENV_PREFIX = f.read() return CODE_ENV_PREFIX diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py index 0ef3c35e9..f9a6e5c55 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
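# The `open(CODE_ENV_PREFIX_FILE, "r")` -> `open(CODE_ENV_PREFIX_FILE)` cleanup in
# utils.py above relies on "r" being open()'s default mode; both forms return a
# text-mode file object with the platform default encoding. Sketch with a
# hypothetical temp path, illustration only:
from pathlib import Path

path = Path("/tmp/code_env_prefix_demo.py")  # hypothetical file for illustration
path.write_text("print('hello')\n")
with open(path) as f:  # same as open(path, "r")
    contents = f.read()
assert contents == "print('hello')\n"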
-from typing import Any, Dict +from typing import Any from llama_stack.providers.datatypes import Api from .config import RagToolRuntimeConfig -async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]): +async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]): from .memory import MemoryToolRuntimeImpl impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) diff --git a/llama_stack/providers/inline/tool_runtime/rag/config.py b/llama_stack/providers/inline/tool_runtime/rag/config.py index c75c3fc51..43ba78e65 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/config.py +++ b/llama_stack/providers/inline/tool_runtime/rag/config.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel class RagToolRuntimeConfig(BaseModel): @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 8d4689e5d..df0257718 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -8,7 +8,7 @@ import asyncio import logging import secrets import string -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import TypeAdapter @@ -74,7 +74,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): async def insert( self, - documents: List[RAGDocument], + documents: list[RAGDocument], vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: @@ -101,8 +101,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): async def query( self, content: InterleavedContent, - vector_db_ids: List[str], - query_config: Optional[RAGQueryConfig] = None, + vector_db_ids: list[str], + query_config: RAGQueryConfig | None = None, ) -> RAGQueryResult: if not vector_db_ids: return RAGQueryResult(content=None) @@ -123,7 +123,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): ) for vector_db_id in vector_db_ids ] - results: List[QueryChunksResponse] = await asyncio.gather(*tasks) + results: list[QueryChunksResponse] = await asyncio.gather(*tasks) chunks = [c for r in results for c in r.chunks] scores = [s for r in results for s in r.scores] @@ -168,7 +168,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): ) async def list_runtime_tools( - self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: # Parameters are not listed since these methods are not yet invoked automatically # by the LLM. 
The method is only implemented so things like /tools can list without @@ -193,7 +193,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): ] ) - async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: vector_db_ids = kwargs.get("vector_db_ids", []) query_config = kwargs.get("query_config") if query_config: diff --git a/llama_stack/providers/inline/vector_io/chroma/__init__.py b/llama_stack/providers/inline/vector_io/chroma/__init__.py index f39188b46..2e0efb8a1 100644 --- a/llama_stack/providers/inline/vector_io/chroma/__init__.py +++ b/llama_stack/providers/inline/vector_io/chroma/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.providers.datatypes import Api from .config import ChromaVectorIOConfig -async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]): +async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]): from llama_stack.providers.remote.vector_io.chroma.chroma import ( ChromaVectorIOAdapter, ) diff --git a/llama_stack/providers/inline/vector_io/chroma/config.py b/llama_stack/providers/inline/vector_io/chroma/config.py index 1e333fe92..81e2f289e 100644 --- a/llama_stack/providers/inline/vector_io/chroma/config.py +++ b/llama_stack/providers/inline/vector_io/chroma/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel): db_path: str @classmethod - def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]: return {"db_path": db_path} diff --git a/llama_stack/providers/inline/vector_io/faiss/__init__.py b/llama_stack/providers/inline/vector_io/faiss/__init__.py index fc8ce70b4..68a1dee66 100644 --- a/llama_stack/providers/inline/vector_io/faiss/__init__.py +++ b/llama_stack/providers/inline/vector_io/faiss/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.providers.datatypes import Api from .config import FaissVectorIOConfig -async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]): +async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]): from .faiss import FaissVectorIOAdapter assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}" diff --git a/llama_stack/providers/inline/vector_io/faiss/config.py b/llama_stack/providers/inline/vector_io/faiss/config.py index fa6e5bede..cbcbb1762 100644 --- a/llama_stack/providers/inline/vector_io/faiss/config.py +++ b/llama_stack/providers/inline/vector_io/faiss/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
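# Sketch of the fan-out pattern used by MemoryToolRuntimeImpl.query above: one
# query task per vector DB, gathered concurrently, then chunks and scores
# flattened. QueryChunksResponse here is a simplified stand-in, not the real type.
import asyncio
from dataclasses import dataclass

@dataclass
class QueryChunksResponse:  # stand-in for llama_stack's response model
    chunks: list
    scores: list

async def query_one(db_id: str) -> QueryChunksResponse:
    return QueryChunksResponse(chunks=[f"{db_id}-chunk"], scores=[0.5])

async def query_all(vector_db_ids: list[str]) -> QueryChunksResponse:
    tasks = [query_one(db_id) for db_id in vector_db_ids]
    results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
    chunks = [c for r in results for c in r.chunks]
    scores = [s for r in results for s in r.scores]
    return QueryChunksResponse(chunks=chunks, scores=scores)

print(asyncio.run(query_all(["db1", "db2"])))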
-from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel): kvstore: KVStoreConfig @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { "kvstore": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 20c795650..5d5b1da35 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -9,7 +9,7 @@ import base64 import io import json import logging -from typing import Any, Dict, List, Optional +from typing import Any import faiss import numpy as np @@ -84,7 +84,7 @@ class FaissIndex(EmbeddingIndex): await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}") - async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): # Add dimension check embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0] if embedding_dim != self.index.d: @@ -159,7 +159,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): inference_api=self.inference_api, ) - async def list_vector_dbs(self) -> List[VectorDB]: + async def list_vector_dbs(self) -> list[VectorDB]: return [i.vector_db for i in self.cache.values()] async def unregister_vector_db(self, vector_db_id: str) -> None: @@ -176,8 +176,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def insert_chunks( self, vector_db_id: str, - chunks: List[Chunk], - ttl_seconds: Optional[int] = None, + chunks: list[Chunk], + ttl_seconds: int | None = None, ) -> None: index = self.cache.get(vector_db_id) if index is None: @@ -189,7 +189,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self, vector_db_id: str, query: InterleavedContent, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> QueryChunksResponse: index = self.cache.get(vector_db_id) if index is None: diff --git a/llama_stack/providers/inline/vector_io/milvus/__init__.py b/llama_stack/providers/inline/vector_io/milvus/__init__.py index d88a3b005..fe3a1f7f9 100644 --- a/llama_stack/providers/inline/vector_io/milvus/__init__.py +++ b/llama_stack/providers/inline/vector_io/milvus/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.providers.datatypes import Api from .config import MilvusVectorIOConfig -async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]): +async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]): from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter impl = MilvusVectorIOAdapter(config, deps[Api.inference]) diff --git a/llama_stack/providers/inline/vector_io/milvus/config.py b/llama_stack/providers/inline/vector_io/milvus/config.py index 0e11d8c7c..eb22b5276 100644 --- a/llama_stack/providers/inline/vector_io/milvus/config.py +++ b/llama_stack/providers/inline/vector_io/milvus/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
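# The dimension check in FaissIndex.add_chunks above guards against embeddings
# whose width differs from the index. The same shape logic, reproduced with
# plain numpy (no faiss dependency) so it can run standalone:
import numpy as np

def check_embedding_dim(embeddings: np.ndarray, index_dim: int) -> None:
    embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
    if embedding_dim != index_dim:
        raise ValueError(f"Embedding dimension mismatch: expected {index_dim}, got {embedding_dim}")

check_embedding_dim(np.zeros((4, 384)), 384)  # ok: batch of 4 vectors, width 384
check_embedding_dim(np.zeros(384), 384)       # ok: single flat vector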
-from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -16,5 +16,5 @@ class MilvusVectorIOConfig(BaseModel): db_path: str @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {"db_path": "${env.MILVUS_DB_PATH}"} diff --git a/llama_stack/providers/inline/vector_io/qdrant/__init__.py b/llama_stack/providers/inline/vector_io/qdrant/__init__.py index 8f0b91c61..ee33b3797 100644 --- a/llama_stack/providers/inline/vector_io/qdrant/__init__.py +++ b/llama_stack/providers/inline/vector_io/qdrant/__init__.py @@ -4,14 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict - from llama_stack.providers.datatypes import Api, ProviderSpec from .config import QdrantVectorIOConfig -async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]): +async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter impl = QdrantVectorIOAdapter(config, deps[Api.inference]) diff --git a/llama_stack/providers/inline/vector_io/qdrant/config.py b/llama_stack/providers/inline/vector_io/qdrant/config.py index 282e951b0..283724b41 100644 --- a/llama_stack/providers/inline/vector_io/qdrant/config.py +++ b/llama_stack/providers/inline/vector_io/qdrant/config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -17,7 +17,7 @@ class QdrantVectorIOConfig(BaseModel): path: str @classmethod - def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: return { "path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db", } diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py index 2380eb0ef..6db176eda 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.providers.datatypes import Api from .config import SQLiteVectorIOConfig -async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]): +async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]): from .sqlite_vec import SQLiteVecVectorIOAdapter assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}" diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py index 906c19689..cb806cb39 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
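# Several sample_run_config() values above use the "${env.VAR}" /
# "${env.VAR:default}" placeholder convention. The stack's real resolver lives
# elsewhere, so the sketch below is only an assumption about that syntax:
import os
import re

def resolve_env_template(value: str) -> str:
    def sub(match: re.Match) -> str:
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)
    return re.sub(r"\$\{env\.([^}]*)\}", sub, value)

os.environ["SQLITE_STORE_DIR"] = "/data"
print(resolve_env_template("${env.SQLITE_STORE_DIR:~/.llama}/sqlite_vec.db"))  # /data/sqlite_vec.db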
-from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -13,7 +13,7 @@ class SQLiteVectorIOConfig(BaseModel): db_path: str @classmethod - def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: return { "db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db", } diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 5f7671138..ab4384021 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -10,7 +10,7 @@ import logging import sqlite3 import struct import uuid -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np import sqlite_vec @@ -25,7 +25,7 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect logger = logging.getLogger(__name__) -def serialize_vector(vector: List[float]) -> bytes: +def serialize_vector(vector: list[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" return struct.pack(f"{len(vector)}f", *vector) @@ -98,7 +98,7 @@ class SQLiteVecIndex(EmbeddingIndex): await asyncio.to_thread(_drop_tables) - async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, batch_size: int = 500): """ Add new chunks along with their embeddings using batch inserts. For each chunk, we insert its JSON into the metadata table and then insert its @@ -209,7 +209,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__(self, config, inference_api: Inference) -> None: self.config = config self.inference_api = inference_api - self.cache: Dict[str, VectorDBWithIndex] = {} + self.cache: dict[str, VectorDBWithIndex] = {} async def initialize(self) -> None: def _setup_connection(): @@ -264,7 +264,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) - async def list_vector_dbs(self) -> List[VectorDB]: + async def list_vector_dbs(self) -> list[VectorDB]: return [v.vector_db for v in self.cache.values()] async def unregister_vector_db(self, vector_db_id: str) -> None: @@ -286,7 +286,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): await asyncio.to_thread(_delete_vector_db_from_registry) - async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None: + async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: if vector_db_id not in self.cache: raise ValueError(f"Vector DB {vector_db_id} not found. 
Found: {list(self.cache.keys())}") # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api @@ -294,7 +294,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): await self.cache[vector_db_id].insert_chunks(chunks) async def query_chunks( - self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None + self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None ) -> QueryChunksResponse: if vector_db_id not in self.cache: raise ValueError(f"Vector DB {vector_db_id} not found") @@ -303,5 +303,5 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def generate_chunk_id(document_id: str, chunk_text: str) -> str: """Generate a unique chunk ID using a hash of document ID and chunk text.""" - hash_input = f"{document_id}:{chunk_text}".encode("utf-8") + hash_input = f"{document_id}:{chunk_text}".encode() return str(uuid.UUID(hashlib.md5(hash_input).hexdigest())) diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 3ed59304d..e0801a8d1 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from llama_stack.providers.datatypes import ( Api, @@ -14,7 +13,7 @@ from llama_stack.providers.datatypes import ( from llama_stack.providers.utils.kvstore import kvstore_dependencies -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.agents, diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 7db136136..152cc9cb9 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from llama_stack.providers.datatypes import ( AdapterSpec, @@ -15,7 +14,7 @@ from llama_stack.providers.datatypes import ( ) -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.datasetio, diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index 9604d5da4..c9c29bbe0 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.eval, diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index a05ec25b1..b0abc1818 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
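# serialize_vector() and generate_chunk_id() from sqlite_vec.py above, as a
# self-contained round-trip; the deserializer is added here for illustration
# and is not part of the diff:
import hashlib
import struct
import uuid

def serialize_vector(vector: list[float]) -> bytes:
    return struct.pack(f"{len(vector)}f", *vector)

def deserialize_vector(data: bytes) -> list[float]:
    return list(struct.unpack(f"{len(data) // 4}f", data))  # 4 bytes per float32

def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    hash_input = f"{document_id}:{chunk_text}".encode()  # .encode() defaults to UTF-8
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))

assert deserialize_vector(serialize_vector([1.0, 2.0])) == [1.0, 2.0]
print(generate_chunk_id("doc-1", "hello"))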
-from typing import List from llama_stack.providers.datatypes import ( AdapterSpec, @@ -29,7 +28,7 @@ META_REFERENCE_DEPS = [ ] -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.inference, diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index 4d10fcf3b..35567c07d 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.post_training, diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 54dc51034..c209da092 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from llama_stack.providers.datatypes import ( AdapterSpec, @@ -15,7 +14,7 @@ from llama_stack.providers.datatypes import ( ) -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.safety, diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index ca09be984..7980d6a13 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.scoring, diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index fc249f3e2..14da06126 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from llama_stack.providers.datatypes import ( Api, @@ -13,7 +12,7 @@ from llama_stack.providers.datatypes import ( ) -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.telemetry, diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 95ea2dcf9..3140626f9 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
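# Context for the registry signature changes above: builtin generics such as
# list[ProviderSpec] are valid at runtime on Python >= 3.9 (PEP 585), and
# `X | None` unions on Python >= 3.10 (PEP 604). With
# `from __future__ import annotations` the syntax also parses on older
# versions, but only inside annotations that are never evaluated at runtime.
import sys

assert sys.version_info >= (3, 10), "X | None at runtime requires Python 3.10+"

def available_names() -> list[str]:   # PEP 585 builtin generic
    return ["inline::meta-reference"]

def lookup(name: str) -> str | None:  # PEP 604 union
    return name if name in available_names() else None

print(lookup("inline::meta-reference"))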
-from typing import List from llama_stack.providers.datatypes import ( AdapterSpec, @@ -15,7 +14,7 @@ from llama_stack.providers.datatypes import ( ) -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.tool_runtime, diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 93031763d..d888c8420 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from llama_stack.providers.datatypes import ( AdapterSpec, @@ -15,7 +14,7 @@ from llama_stack.providers.datatypes import ( ) -def available_providers() -> List[ProviderSpec]: +def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( api=Api.vector_io, diff --git a/llama_stack/providers/remote/datasetio/huggingface/config.py b/llama_stack/providers/remote/datasetio/huggingface/config.py index c06996b6f..38f933728 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/config.py +++ b/llama_stack/providers/remote/datasetio/huggingface/config.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -17,7 +17,7 @@ class HuggingfaceDatasetIOConfig(BaseModel): kvstore: KVStoreConfig @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { "kvstore": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index 7a17e5e42..baecf45e9 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict, List, Optional +from typing import Any from urllib.parse import parse_qs, urlparse import datasets as hf_datasets @@ -70,8 +70,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): async def iterrows( self, dataset_id: str, - start_index: Optional[int] = None, - limit: Optional[int] = None, + start_index: int | None = None, + limit: int | None = None, ) -> PaginatedResponse: dataset_def = self.dataset_infos[dataset_id] path, params = parse_hf_params(dataset_def) @@ -80,7 +80,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): records = [loaded_dataset[i] for i in range(len(loaded_dataset))] return paginate_records(records, start_index, limit) - async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: dataset_def = self.dataset_infos[dataset_id] path, params = parse_hf_params(dataset_def) loaded_dataset = hf_datasets.load_dataset(path, **params) diff --git a/llama_stack/providers/remote/datasetio/nvidia/config.py b/llama_stack/providers/remote/datasetio/nvidia/config.py index 7f3dbdfbd..e616ce25c 100644 --- a/llama_stack/providers/remote/datasetio/nvidia/config.py +++ b/llama_stack/providers/remote/datasetio/nvidia/config.py @@ -6,7 +6,7 @@ import os import warnings -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -14,17 +14,17 @@ from pydantic import BaseModel, Field class NvidiaDatasetIOConfig(BaseModel): """Configuration for NVIDIA DatasetIO implementation.""" - api_key: Optional[str] = Field( + api_key: str | None = Field( default_factory=lambda: os.getenv("NVIDIA_API_KEY"), description="The NVIDIA API key.", ) - dataset_namespace: Optional[str] = Field( + dataset_namespace: str | None = Field( default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"), description="The NVIDIA dataset namespace.", ) - project_id: Optional[str] = Field( + project_id: str | None = Field( default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"), description="The NVIDIA project ID.", ) @@ -52,7 +52,7 @@ class NvidiaDatasetIOConfig(BaseModel): ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs) -> dict[str, Any]: return { "api_key": "${env.NVIDIA_API_KEY:}", "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}", diff --git a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py index 83efe3991..6a9e2bb58 100644 --- a/llama_stack/providers/remote/datasetio/nvidia/datasetio.py +++ b/llama_stack/providers/remote/datasetio/nvidia/datasetio.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
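# iterrows() above delegates to a paginate_records(records, start_index, limit)
# helper whose exact semantics live elsewhere in the stack; the version below
# is an assumed, simplified slice-based sketch for illustration only:
from typing import Any

def paginate_records_sketch(
    records: list[dict[str, Any]],
    start_index: int | None = None,
    limit: int | None = None,
) -> list[dict[str, Any]]:
    start = start_index or 0
    end = len(records) if limit is None else start + limit
    return records[start:end]

rows = [{"i": i} for i in range(10)]
assert paginate_records_sketch(rows, start_index=2, limit=3) == [{"i": 2}, {"i": 3}, {"i": 4}]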
-from typing import Any, Dict, List, Optional +from typing import Any import aiohttp @@ -27,11 +27,11 @@ class NvidiaDatasetIOAdapter: self, method: str, path: str, - headers: Optional[Dict[str, Any]] = None, - params: Optional[Dict[str, Any]] = None, - json: Optional[Dict[str, Any]] = None, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + json: dict[str, Any] | None = None, **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Helper method to make HTTP requests to the Customizer API.""" url = f"{self.config.datasets_url}{path}" request_headers = self.headers.copy() @@ -82,11 +82,11 @@ class NvidiaDatasetIOAdapter: async def update_dataset( self, dataset_id: str, - dataset_schema: Dict[str, ParamType], + dataset_schema: dict[str, ParamType], url: URL, - provider_dataset_id: Optional[str] = None, - provider_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, + provider_dataset_id: str | None = None, + provider_id: str | None = None, + metadata: dict[str, Any] | None = None, ) -> None: raise NotImplementedError("Not implemented") @@ -103,10 +103,10 @@ class NvidiaDatasetIOAdapter: async def iterrows( self, dataset_id: str, - start_index: Optional[int] = None, - limit: Optional[int] = None, + start_index: int | None = None, + limit: int | None = None, ) -> PaginatedResponse: raise NotImplementedError("Not implemented") - async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: + async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: raise NotImplementedError("Not implemented") diff --git a/llama_stack/providers/remote/eval/nvidia/__init__.py b/llama_stack/providers/remote/eval/nvidia/__init__.py index 8abbec9b2..55e3754f3 100644 --- a/llama_stack/providers/remote/eval/nvidia/__init__.py +++ b/llama_stack/providers/remote/eval/nvidia/__init__.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from llama_stack.distribution.datatypes import Api @@ -12,7 +12,7 @@ from .config import NVIDIAEvalConfig async def get_adapter_impl( config: NVIDIAEvalConfig, - deps: Dict[Api, Any], + deps: dict[Api, Any], ): from .eval import NVIDIAEvalImpl diff --git a/llama_stack/providers/remote/eval/nvidia/config.py b/llama_stack/providers/remote/eval/nvidia/config.py index b660fcd68..5c8f9ff76 100644 --- a/llama_stack/providers/remote/eval/nvidia/config.py +++ b/llama_stack/providers/remote/eval/nvidia/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import os -from typing import Any, Dict +from typing import Any from pydantic import BaseModel, Field @@ -23,7 +23,7 @@ class NVIDIAEvalConfig(BaseModel): ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs) -> dict[str, Any]: return { "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}", } diff --git a/llama_stack/providers/remote/eval/nvidia/eval.py b/llama_stack/providers/remote/eval/nvidia/eval.py index e1a3b5355..3572de0ef 100644 --- a/llama_stack/providers/remote/eval/nvidia/eval.py +++ b/llama_stack/providers/remote/eval/nvidia/eval.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
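# The NVIDIA configs above read environment variables through
# Field(default_factory=...) rather than a plain default, so the lookup happens
# when the model is constructed, not at import time. Minimal reproduction:
import os

from pydantic import BaseModel, Field

class SketchConfig(BaseModel):  # illustrative model, not the real config class
    dataset_namespace: str | None = Field(
        default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
        description="The NVIDIA dataset namespace.",
    )

os.environ["NVIDIA_DATASET_NAMESPACE"] = "team-a"
assert SketchConfig().dataset_namespace == "team-a"  # picked up at construction time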
-from typing import Any, Dict, List +from typing import Any import requests @@ -101,8 +101,8 @@ class NVIDIAEvalImpl( async def evaluate_rows( self, benchmark_id: str, - input_rows: List[Dict[str, Any]], - scoring_functions: List[str], + input_rows: list[dict[str, Any]], + scoring_functions: list[str], benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/anthropic/__init__.py b/llama_stack/providers/remote/inference/anthropic/__init__.py index 3075f856e..8b420a5a0 100644 --- a/llama_stack/providers/remote/inference/anthropic/__init__.py +++ b/llama_stack/providers/remote/inference/anthropic/__init__.py @@ -4,15 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional - from pydantic import BaseModel from .config import AnthropicConfig class AnthropicProviderDataValidator(BaseModel): - anthropic_api_key: Optional[str] = None + anthropic_api_key: str | None = None async def get_adapter_impl(config: AnthropicConfig, _deps): diff --git a/llama_stack/providers/remote/inference/anthropic/config.py b/llama_stack/providers/remote/inference/anthropic/config.py index 0e9469602..10da0025e 100644 --- a/llama_stack/providers/remote/inference/anthropic/config.py +++ b/llama_stack/providers/remote/inference/anthropic/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class AnthropicProviderDataValidator(BaseModel): - anthropic_api_key: Optional[str] = Field( + anthropic_api_key: str | None = Field( default=None, description="API key for Anthropic models", ) @@ -20,13 +20,13 @@ class AnthropicProviderDataValidator(BaseModel): @json_schema_type class AnthropicConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="API key for Anthropic models", ) @classmethod - def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]: return { "api_key": api_key, } diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index f8dbcf31a..0404a578f 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import json -from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator, AsyncIterator from botocore.client import BaseClient @@ -79,26 +79,26 @@ class BedrockInferenceAdapter( self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: raise NotImplementedError() async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + messages: list[Message], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, + ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: if sampling_params is None: sampling_params = SamplingParams() model = await self.model_store.get_model(model_id) @@ -151,7 +151,7 @@ class BedrockInferenceAdapter( async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict: + async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict: bedrock_model = request.model sampling_params = request.sampling_params @@ -176,10 +176,10 @@ class BedrockInferenceAdapter( async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embeddings = [] diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 3156601be..685375346 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
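# chat_completion above returns ChatCompletionResponse | AsyncIterator[...] and
# picks the branch from the `stream` flag. A stripped-down sketch of that shape;
# the response types here are placeholder strings, not the real models:
import asyncio
from collections.abc import AsyncIterator

async def _stream_chunks() -> AsyncIterator[str]:
    for chunk in ("hel", "lo"):
        yield chunk

async def chat_completion_sketch(stream: bool | None = False) -> str | AsyncIterator[str]:
    if stream:
        return _stream_chunks()  # caller iterates with `async for`
    return "hello"

async def main() -> None:
    print(await chat_completion_sketch(stream=False))
    async for chunk in await chat_completion_sketch(stream=True):
        print(chunk)

asyncio.run(main())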
-from typing import AsyncGenerator, List, Optional, Union +from collections.abc import AsyncGenerator from cerebras.cloud.sdk import AsyncCerebras @@ -79,10 +79,10 @@ class CerebrasInferenceAdapter( self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -120,15 +120,15 @@ class CerebrasInferenceAdapter( async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -166,7 +166,7 @@ class CerebrasInferenceAdapter( async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: + async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy): raise ValueError("`top_k` not supported by Cerebras") @@ -188,9 +188,9 @@ class CerebrasInferenceAdapter( async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py index 81682c980..81312ec76 100644 --- a/llama_stack/providers/remote/inference/cerebras/config.py +++ b/llama_stack/providers/remote/inference/cerebras/config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import os -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field, SecretStr @@ -20,13 +20,13 @@ class CerebrasImplConfig(BaseModel): default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL), description="Base URL for the Cerebras API", ) - api_key: Optional[SecretStr] = Field( + api_key: SecretStr | None = Field( default=os.environ.get("CEREBRAS_API_KEY"), description="Cerebras API Key", ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs) -> dict[str, Any]: return { "base_url": DEFAULT_BASE_URL, "api_key": "${env.CEREBRAS_API_KEY}", diff --git a/llama_stack/providers/remote/inference/cerebras_openai_compat/config.py b/llama_stack/providers/remote/inference/cerebras_openai_compat/config.py index 149c0a202..cb8daff6a 100644 --- a/llama_stack/providers/remote/inference/cerebras_openai_compat/config.py +++ b/llama_stack/providers/remote/inference/cerebras_openai_compat/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class CerebrasProviderDataValidator(BaseModel): - cerebras_api_key: Optional[str] = Field( + cerebras_api_key: str | None = Field( default=None, description="API key for Cerebras models", ) @@ -20,7 +20,7 @@ class CerebrasProviderDataValidator(BaseModel): @json_schema_type class CerebrasCompatConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="The Cerebras API key", ) @@ -31,7 +31,7 @@ class CerebrasCompatConfig(BaseModel): ) @classmethod - def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]: return { "openai_compat_api_base": "https://api.cerebras.ai/v1", "api_key": api_key, diff --git a/llama_stack/providers/remote/inference/databricks/config.py b/llama_stack/providers/remote/inference/databricks/config.py index 1d51125cb..5710dcef3 100644 --- a/llama_stack/providers/remote/inference/databricks/config.py +++ b/llama_stack/providers/remote/inference/databricks/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel, Field @@ -28,7 +28,7 @@ class DatabricksImplConfig(BaseModel): url: str = "${env.DATABRICKS_URL}", api_token: str = "${env.DATABRICKS_API_TOKEN}", **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return { "url": url, "api_token": api_token, diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 27d96eb7d..5c36eac3e 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
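# api_key above becomes SecretStr | None, so reprs and logs mask the value and
# the raw string must be requested explicitly. Quick demonstration with an
# obviously fake key:
from pydantic import BaseModel, SecretStr

class KeyHolder(BaseModel):
    api_key: SecretStr | None = None

holder = KeyHolder(api_key="sk-not-a-real-key")
print(holder.api_key)                     # **********  (masked)
print(holder.api_key.get_secret_value())  # sk-not-a-real-key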
-from typing import AsyncGenerator, List, Optional +from collections.abc import AsyncGenerator from openai import OpenAI @@ -78,25 +78,25 @@ class DatabricksInferenceAdapter( self, model: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: raise NotImplementedError() async def chat_completion( self, model: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -146,9 +146,9 @@ class DatabricksInferenceAdapter( async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index c21ce4a40..072d558f4 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
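# The import moves above (AsyncGenerator, AsyncIterator from collections.abc
# instead of typing) track the deprecation of the typing aliases since
# Python 3.9; the collections.abc versions subscript the same way, and every
# async generator is itself an AsyncIterator:
from collections.abc import AsyncGenerator, AsyncIterator

async def tokens() -> AsyncGenerator[str, None]:
    yield "token"

def as_iterator(gen: AsyncGenerator[str, None]) -> AsyncIterator[str]:
    return gen  # valid: AsyncGenerator is a subtype of AsyncIterator

print(as_iterator(tokens()))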
-from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field, SecretStr @@ -17,13 +17,13 @@ class FireworksImplConfig(BaseModel): default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", ) - api_key: Optional[SecretStr] = Field( + api_key: SecretStr | None = Field( default=None, description="The Fireworks.ai API Key", ) @classmethod - def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]: return { "url": "https://api.fireworks.ai/inference/v1", "api_key": api_key, diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 58678a9cc..b6d3984c6 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Any from fireworks.client import Fireworks from openai import AsyncOpenAI @@ -105,10 +106,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -146,9 +147,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv def _build_options( self, - sampling_params: Optional[SamplingParams], + sampling_params: SamplingParams | None, fmt: ResponseFormat, - logprobs: Optional[LogProbConfig], + logprobs: LogProbConfig | None, ) -> dict: options = get_sampling_options(sampling_params) options.setdefault("max_tokens", 512) @@ -177,15 +178,15 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -229,7 +230,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv async for chunk in 
process_chat_completion_stream_response(stream, request): yield chunk - async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: + async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: input_dict = {} media_present = request_has_media(request) @@ -263,10 +264,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) @@ -288,24 +289,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv async def openai_completion( self, model: str, - prompt: Union[str, List[str], List[int], List[List[int]]], - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - guided_choice: Optional[List[str]] = None, - prompt_logprobs: Optional[int] = None, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, ) -> OpenAICompletion: model_obj = await self.model_store.get_model(model) @@ -338,29 +339,29 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[OpenAIResponseFormatParam] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - 
top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: model_obj = await self.model_store.get_model(model) # Divert Llama Models through Llama Stack inference APIs because diff --git a/llama_stack/providers/remote/inference/fireworks_openai_compat/config.py b/llama_stack/providers/remote/inference/fireworks_openai_compat/config.py index 0263d348a..bf38cdd2b 100644 --- a/llama_stack/providers/remote/inference/fireworks_openai_compat/config.py +++ b/llama_stack/providers/remote/inference/fireworks_openai_compat/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class FireworksProviderDataValidator(BaseModel): - fireworks_api_key: Optional[str] = Field( + fireworks_api_key: str | None = Field( default=None, description="API key for Fireworks models", ) @@ -20,7 +20,7 @@ class FireworksProviderDataValidator(BaseModel): @json_schema_type class FireworksCompatConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="The Fireworks API key", ) @@ -31,7 +31,7 @@ class FireworksCompatConfig(BaseModel): ) @classmethod - def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]: return { "openai_compat_api_base": "https://api.fireworks.ai/inference/v1", "api_key": api_key, diff --git a/llama_stack/providers/remote/inference/gemini/__init__.py b/llama_stack/providers/remote/inference/gemini/__init__.py index dd972f21c..9d35da893 100644 --- a/llama_stack/providers/remote/inference/gemini/__init__.py +++ b/llama_stack/providers/remote/inference/gemini/__init__.py @@ -4,15 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
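The change repeated throughout these hunks is purely a spelling change: on Python 3.10+, `Optional[X]` and `X | None` build the same annotation, and Pydantic treats them identically. A minimal sketch of the new style (the `ExampleConfig` model below is a hypothetical stand-in, not code from this repo):

```python
from pydantic import BaseModel, Field

class ExampleConfig(BaseModel):
    # Equivalent to the old spelling: api_key: Optional[str] = Field(default=None, ...)
    api_key: str | None = Field(default=None, description="An API key")

print(ExampleConfig().api_key)               # None
print(ExampleConfig(api_key="abc").api_key)  # abc
```

Note that the `| None` annotation alone does not supply a default; the explicit `default=None` (or `default_factory`) kept on every migrated field is still what makes the field optional at construction time.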
-from typing import Optional - from pydantic import BaseModel from .config import GeminiConfig class GeminiProviderDataValidator(BaseModel): - gemini_api_key: Optional[str] = None + gemini_api_key: str | None = None async def get_adapter_impl(config: GeminiConfig, _deps): diff --git a/llama_stack/providers/remote/inference/gemini/config.py b/llama_stack/providers/remote/inference/gemini/config.py index 30c8d9913..63ef4de01 100644 --- a/llama_stack/providers/remote/inference/gemini/config.py +++ b/llama_stack/providers/remote/inference/gemini/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class GeminiProviderDataValidator(BaseModel): - gemini_api_key: Optional[str] = Field( + gemini_api_key: str | None = Field( default=None, description="API key for Gemini models", ) @@ -20,13 +20,13 @@ class GeminiProviderDataValidator(BaseModel): @json_schema_type class GeminiConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="API key for Gemini models", ) @classmethod - def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]: return { "api_key": api_key, } diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py index 8a1204b0b..fe060507a 100644 --- a/llama_stack/providers/remote/inference/groq/config.py +++ b/llama_stack/providers/remote/inference/groq/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class GroqProviderDataValidator(BaseModel): - groq_api_key: Optional[str] = Field( + groq_api_key: str | None = Field( default=None, description="API key for Groq models", ) @@ -20,7 +20,7 @@ class GroqProviderDataValidator(BaseModel): @json_schema_type class GroqConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( # The Groq client library loads the GROQ_API_KEY environment variable by default default=None, description="The Groq API key", @@ -32,7 +32,7 @@ class GroqConfig(BaseModel): ) @classmethod - def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]: return { "url": "https://api.groq.com", "api_key": api_key, diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index f3f14e9af..27d7d7961 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, AsyncIterator, Dict, List, Optional, Union +from collections.abc import AsyncIterator +from typing import Any from openai import AsyncOpenAI @@ -59,29 +60,29 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin): async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[OpenAIResponseFormatParam] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: model_obj = await self.model_store.get_model(model) # Groq does not support json_schema response format, so we need to convert it to json_object diff --git a/llama_stack/providers/remote/inference/groq_openai_compat/config.py b/llama_stack/providers/remote/inference/groq_openai_compat/config.py index 4b90b4576..481f740f9 100644 --- a/llama_stack/providers/remote/inference/groq_openai_compat/config.py +++ b/llama_stack/providers/remote/inference/groq_openai_compat/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
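These hunks also move `AsyncIterator` and `AsyncGenerator` imports from `typing` to `collections.abc`; the `typing` aliases have been deprecated since Python 3.9. A self-contained sketch of the annotation convention the adapters now follow (the function is illustrative):

```python
import asyncio
from collections.abc import AsyncIterator

async def stream_chunks(n: int) -> AsyncIterator[str]:
    # An async generator function satisfies the AsyncIterator protocol,
    # so collections.abc.AsyncIterator is the natural return annotation.
    for i in range(n):
        yield f"chunk-{i}"

async def main() -> None:
    async for chunk in stream_chunks(3):
        print(chunk)

asyncio.run(main())
```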
-from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class GroqProviderDataValidator(BaseModel): - groq_api_key: Optional[str] = Field( + groq_api_key: str | None = Field( default=None, description="API key for Groq models", ) @@ -20,7 +20,7 @@ class GroqProviderDataValidator(BaseModel): @json_schema_type class GroqCompatConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="The Groq API key", ) @@ -31,7 +31,7 @@ class GroqCompatConfig(BaseModel): ) @classmethod - def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]: return { "openai_compat_api_base": "https://api.groq.com/openai/v1", "api_key": api_key, diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/config.py b/llama_stack/providers/remote/inference/llama_openai_compat/config.py index e984ec803..57bc7240d 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/config.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class LlamaProviderDataValidator(BaseModel): - llama_api_key: Optional[str] = Field( + llama_api_key: str | None = Field( default=None, description="API key for api.llama models", ) @@ -20,7 +20,7 @@ class LlamaProviderDataValidator(BaseModel): @json_schema_type class LlamaCompatConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="The Llama API key", ) @@ -31,7 +31,7 @@ class LlamaCompatConfig(BaseModel): ) @classmethod - def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]: return { "openai_compat_api_base": "https://api.llama.com/compat/v1/", "api_key": api_key, diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py index 8f80408d4..4c449edc2 100644 --- a/llama_stack/providers/remote/inference/nvidia/config.py +++ b/llama_stack/providers/remote/inference/nvidia/config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
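The `sample_run_config` classmethods keep the same contract after the `Dict[str, Any]` → `dict[str, Any]` change: they return a plain dict whose values may contain `${env.*}` placeholders to be resolved when the run config is rendered. An illustrative check against the Groq compat hunk above (assuming the module path from the diff header):

```python
from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig

sample = GroqCompatConfig.sample_run_config()
assert sample == {
    "openai_compat_api_base": "https://api.groq.com/openai/v1",
    "api_key": "${env.GROQ_API_KEY}",  # placeholder, substituted from the environment at startup
}
```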
import os -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field, SecretStr @@ -39,7 +39,7 @@ class NVIDIAConfig(BaseModel): default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"), description="A base url for accessing the NVIDIA NIM", ) - api_key: Optional[SecretStr] = Field( + api_key: SecretStr | None = Field( default_factory=lambda: os.getenv("NVIDIA_API_KEY"), description="The NVIDIA API key, only needed if using the hosted service", ) @@ -53,7 +53,7 @@ ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs) -> dict[str, Any]: return { "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}", "api_key": "${env.NVIDIA_API_KEY:}", diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 4a62ad6cb..333486fe4 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -6,8 +6,9 @@ import logging import warnings +from collections.abc import AsyncIterator from functools import lru_cache -from typing import Any, AsyncIterator, Dict, List, Optional, Union +from typing import Any from openai import APIConnectionError, AsyncOpenAI, BadRequestError @@ -141,11 +142,11 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]: if sampling_params is None: sampling_params = SamplingParams() if content_has_media(content): @@ -182,20 +183,20 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: if any(content_has_media(content) for content in contents): raise NotImplementedError("Media is not supported") # - # Llama Stack: contents = List[str] | List[InterleavedContentItem] + # Llama Stack: contents = list[str] | list[InterleavedContentItem] # -> - # OpenAI: input = str | List[str] + # OpenAI: input = str | list[str] # - # we can ignore str and always pass List[str] to OpenAI + # we can ignore str and always pass list[str] to OpenAI # flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents] input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents] @@ -231,25 +232,25 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): raise ValueError(f"Failed to get embeddings: {e}") from e # - # OpenAI: 
CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...) + # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...) # -> - # Llama Stack: EmbeddingsResponse(embeddings=List[List[float]]) + # Llama Stack: EmbeddingsResponse(embeddings=list[list[float]]) # return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + messages: list[Message], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, + ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: if sampling_params is None: sampling_params = SamplingParams() if tool_prompt_format: @@ -286,24 +287,24 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): async def openai_completion( self, model: str, - prompt: Union[str, List[str], List[int], List[List[int]]], - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - guided_choice: Optional[List[str]] = None, - prompt_logprobs: Optional[int] = None, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, ) -> OpenAICompletion: provider_model_id = await self._get_provider_model_id(model) @@ -335,29 +336,29 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: 
Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[OpenAIResponseFormatParam] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: provider_model_id = await self._get_provider_model_id(model) params = await prepare_openai_completion_params( diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index 3f2769b26..0b0d7fcf3 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import warnings -from typing import Any, AsyncGenerator, Dict, List, Optional +from collections.abc import AsyncGenerator +from typing import Any from openai import AsyncStream from openai.types.chat.chat_completion import ( @@ -64,7 +65,7 @@ async def convert_chat_completion_request( ) nvext = {} - payload: Dict[str, Any] = dict( + payload: dict[str, Any] = dict( model=request.model, messages=[await convert_message_to_openai_dict_new(message) for message in request.messages], stream=request.stream, @@ -137,7 +138,7 @@ def convert_completion_request( # logprobs.top_k -> logprobs nvext = {} - payload: Dict[str, Any] = dict( + payload: dict[str, Any] = dict( model=request.model, prompt=request.content, stream=request.stream, @@ -176,8 +177,8 @@ def convert_completion_request( def _convert_openai_completion_logprobs( - logprobs: Optional[OpenAICompletionLogprobs], -) -> Optional[List[TokenLogProbs]]: + logprobs: OpenAICompletionLogprobs | None, +) -> list[TokenLogProbs] | None: """ Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs. """ diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 7d3f3f27e..74019999e 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
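The comment rewrite in the NVIDIA embeddings hunk above documents a real adaptation step: Llama Stack hands the adapter `list[str] | list[InterleavedContentItem]`, while the OpenAI embeddings API wants `str | list[str]`, so the adapter always normalizes to `list[str]`. A standalone sketch of that normalization, with a hypothetical stub in place of `TextContentItem`:

```python
class TextContentItemStub:
    """Hypothetical stand-in for llama_stack's TextContentItem."""

    def __init__(self, text: str) -> None:
        self.text = text

def to_openai_input(contents: list) -> list[str]:
    # Mixed str / content-item lists are flattened to plain strings,
    # mirroring the flat_contents/input steps in the adapter.
    return [c.text if isinstance(c, TextContentItemStub) else c for c in contents]

print(to_openai_input(["hello", TextContentItemStub("world")]))  # ['hello', 'world']
```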
import logging -from typing import Tuple import httpx @@ -18,7 +17,7 @@ def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: return "integrate.api.nvidia.com" in config.url -async def _get_health(url: str) -> Tuple[bool, bool]: +async def _get_health(url: str) -> tuple[bool, bool]: """ Query {url}/v1/health/{live,ready} to check if the server is running and ready diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py index a5a4d48ab..0e4aef0e1 100644 --- a/llama_stack/providers/remote/inference/ollama/config.py +++ b/llama_stack/providers/remote/inference/ollama/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -15,5 +15,5 @@ class OllamaImplConfig(BaseModel): url: str = DEFAULT_OLLAMA_URL @classmethod - def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs) -> dict[str, Any]: return {"url": url} diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 37c187181..0cf63097b 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -5,7 +5,8 @@ # the root directory of this source tree. -from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Any import httpx from ollama import AsyncClient # type: ignore[attr-defined] @@ -130,10 +131,10 @@ class OllamaInferenceAdapter( self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]: if sampling_params is None: sampling_params = SamplingParams() @@ -188,15 +189,15 @@ class OllamaInferenceAdapter( async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]: if sampling_params is None: sampling_params = SamplingParams() @@ -216,7 +217,7 @@ class OllamaInferenceAdapter( else: return await 
self._nonstream_chat_completion(request) - async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: + async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: sampling_options = get_sampling_options(request.sampling_params) # This is needed since the Ollama API expects num_predict to be set # for early truncation instead of max_tokens. @@ -314,10 +315,10 @@ class OllamaInferenceAdapter( async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: model = await self._get_model(model_id) @@ -365,24 +366,24 @@ class OllamaInferenceAdapter( async def openai_completion( self, model: str, - prompt: Union[str, List[str], List[int], List[List[int]]], - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - guided_choice: Optional[List[str]] = None, - prompt_logprobs: Optional[int] = None, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, ) -> OpenAICompletion: if not isinstance(prompt, str): raise ValueError("Ollama does not support non-string prompts for completion") @@ -416,29 +417,29 @@ class OllamaInferenceAdapter( async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[OpenAIResponseFormatParam] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: 
Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: model_obj = await self._get_model(model) # ollama still makes tool calls even when tool_choice is "none" @@ -480,27 +481,27 @@ class OllamaInferenceAdapter( async def batch_completion( self, model_id: str, - content_batch: List[InterleavedContent], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - logprobs: Optional[LogProbConfig] = None, + content_batch: list[InterleavedContent], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ): raise NotImplementedError("Batch completion is not supported for Ollama") async def batch_chat_completion( self, model_id: str, - messages_batch: List[List[Message]], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_config: Optional[ToolConfig] = None, - response_format: Optional[ResponseFormat] = None, - logprobs: Optional[LogProbConfig] = None, + messages_batch: list[list[Message]], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_config: ToolConfig | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ): raise NotImplementedError("Batch chat completion is not supported for Ollama") -async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]: +async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def _convert_content(content) -> dict: if isinstance(content, ImageContentItem): return { diff --git a/llama_stack/providers/remote/inference/openai/__init__.py b/llama_stack/providers/remote/inference/openai/__init__.py index 000a03d33..c245dbe10 100644 --- a/llama_stack/providers/remote/inference/openai/__init__.py +++ b/llama_stack/providers/remote/inference/openai/__init__.py @@ -4,15 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
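Alongside the union rewrites, these hunks rely on PEP 585: since Python 3.9 the builtin containers are subscriptable in annotations, so the `List`, `Dict`, and `Tuple` imports can be dropped outright, as in `_get_health(url: str) -> tuple[bool, bool]` and `convert_message_to_openai_dict_for_ollama(...) -> list[dict]` above. A minimal standalone sketch (names and values are illustrative):

```python
def health_summary(checks: dict[str, bool]) -> tuple[bool, bool]:
    # Same result shape as _get_health: (server is live, server is ready).
    return checks.get("live", False), checks.get("ready", False)

print(health_summary({"live": True}))  # (True, False)
```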
-from typing import Optional - from pydantic import BaseModel from .config import OpenAIConfig class OpenAIProviderDataValidator(BaseModel): - openai_api_key: Optional[str] = None + openai_api_key: str | None = None async def get_adapter_impl(config: OpenAIConfig, _deps): diff --git a/llama_stack/providers/remote/inference/openai/config.py b/llama_stack/providers/remote/inference/openai/config.py index 2b0cc2c10..17fb98831 100644 --- a/llama_stack/providers/remote/inference/openai/config.py +++ b/llama_stack/providers/remote/inference/openai/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class OpenAIProviderDataValidator(BaseModel): - openai_api_key: Optional[str] = Field( + openai_api_key: str | None = Field( default=None, description="API key for OpenAI models", ) @@ -20,13 +20,13 @@ @json_schema_type class OpenAIConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="API key for OpenAI models", ) @classmethod - def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]: return { "api_key": api_key, } diff --git a/llama_stack/providers/remote/inference/passthrough/config.py b/llama_stack/providers/remote/inference/passthrough/config.py index 46325e428..ce41495ce 100644 --- a/llama_stack/providers/remote/inference/passthrough/config.py +++ b/llama_stack/providers/remote/inference/passthrough/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field, SecretStr @@ -18,13 +18,13 @@ class PassthroughImplConfig(BaseModel): description="The URL for the passthrough endpoint", ) - api_key: Optional[SecretStr] = Field( + api_key: SecretStr | None = Field( default=None, description="API Key for the passthrough endpoint", ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs) -> dict[str, Any]: return { "url": "${env.PASSTHROUGH_URL}", "api_key": "${env.PASSTHROUGH_API_KEY}", diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index af05320b0..78ee52641 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Any from llama_stack_client import AsyncLlamaStackClient @@ -93,10 +94,10 @@ class PassthroughInferenceAdapter(Inference): self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -123,15 +124,15 @@ class PassthroughInferenceAdapter(Inference): async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -165,7 +166,7 @@ class PassthroughInferenceAdapter(Inference): else: return await self._nonstream_chat_completion(json_params) - async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse: + async def _nonstream_chat_completion(self, json_params: dict[str, Any]) -> ChatCompletionResponse: client = self._get_client() response = await client.inference.chat_completion(**json_params) @@ -178,7 +179,7 @@ class PassthroughInferenceAdapter(Inference): logprobs=response.logprobs, ) - async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator: + async def _stream_chat_completion(self, json_params: dict[str, Any]) -> AsyncGenerator: client = self._get_client() stream_response = await client.inference.chat_completion(**json_params) @@ -193,10 +194,10 @@ class PassthroughInferenceAdapter(Inference): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[InterleavedContent], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: client = self._get_client() model = await self.model_store.get_model(model_id) @@ -212,24 +213,24 @@ class PassthroughInferenceAdapter(Inference): async def openai_completion( self, model: str, - prompt: Union[str, List[str], List[int], List[List[int]]], - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - 
max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - guided_choice: Optional[List[str]] = None, - prompt_logprobs: Optional[int] = None, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, ) -> OpenAICompletion: client = self._get_client() model_obj = await self.model_store.get_model(model) @@ -261,29 +262,29 @@ class PassthroughInferenceAdapter(Inference): async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[OpenAIResponseFormatParam] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: client = self._get_client() model_obj = await self.model_store.get_model(model) @@ -315,7 +316,7 @@ class PassthroughInferenceAdapter(Inference): return await client.inference.openai_chat_completion(**params) - def cast_value_to_json_dict(self, 
request_params: Dict[str, Any]) -> Dict[str, Any]: + def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]: json_params = {} for key, value in request_params.items(): json_input = convert_pydantic_to_json_value(value) diff --git a/llama_stack/providers/remote/inference/runpod/config.py b/llama_stack/providers/remote/inference/runpod/config.py index 377a7fe6a..e3913dc35 100644 --- a/llama_stack/providers/remote/inference/runpod/config.py +++ b/llama_stack/providers/remote/inference/runpod/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -13,17 +13,17 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class RunpodImplConfig(BaseModel): - url: Optional[str] = Field( + url: str | None = Field( default=None, description="The URL for the Runpod model serving endpoint", ) - api_token: Optional[str] = Field( + api_token: str | None = Field( default=None, description="The API token", ) @classmethod - def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: return { "url": "${env.RUNPOD_URL:}", "api_token": "${env.RUNPOD_API_TOKEN:}", diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 72cbead9b..2706aa15e 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from openai import OpenAI diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py index a30c29b74..8ca11de78 100644 --- a/llama_stack/providers/remote/inference/sambanova/config.py +++ b/llama_stack/providers/remote/inference/sambanova/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -17,13 +17,13 @@ class SambaNovaImplConfig(BaseModel): default="https://api.sambanova.ai/v1", description="The URL for the SambaNova AI server", ) - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="The SambaNova.ai API Key", ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs) -> dict[str, Any]: return { "url": "https://api.sambanova.ai/v1", "api_key": "${env.SAMBANOVA_API_KEY}", diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 1665e72b8..3db95dcb4 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
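Signatures such as `_get_params(self, request: ChatCompletionRequest | CompletionRequest)` work because on Python 3.10+ `X | Y` evaluates to a real runtime object, not just static-checker syntax. A sketch with hypothetical stub classes standing in for the request types:

```python
import types
import typing

class ChatCompletionRequestStub: ...
class CompletionRequestStub: ...

RequestT = ChatCompletionRequestStub | CompletionRequestStub

assert isinstance(RequestT, types.UnionType)  # the union itself is a runtime value
assert typing.get_args(RequestT) == (ChatCompletionRequestStub, CompletionRequestStub)
assert isinstance(ChatCompletionRequestStub(), RequestT)  # isinstance accepts X | Y directly
```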
import json -from typing import AsyncGenerator, List, Optional +from collections.abc import AsyncGenerator from openai import OpenAI @@ -77,25 +77,25 @@ class SambaNovaInferenceAdapter( self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: raise NotImplementedError() async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, - stream: Optional[bool] = False, - tool_config: Optional[ToolConfig] = None, - logprobs: Optional[LogProbConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = ToolPromptFormat.json, + stream: bool | None = False, + tool_config: ToolConfig | None = None, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -146,10 +146,10 @@ class SambaNovaInferenceAdapter( async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() @@ -186,7 +186,7 @@ class SambaNovaInferenceAdapter( return params - async def convert_to_sambanova_messages(self, messages: List[Message]) -> List[dict]: + async def convert_to_sambanova_messages(self, messages: list[Message]) -> list[dict]: conversation = [] for message in messages: content = {} @@ -244,7 +244,7 @@ class SambaNovaInferenceAdapter( return content - def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]: + def convert_to_sambanova_tool(self, tools: list[ToolDefinition]) -> list[dict]: if tools is None: return tools @@ -292,7 +292,7 @@ class SambaNovaInferenceAdapter( def convert_to_sambanova_tool_calls( self, tool_calls, - ) -> List[ToolCall]: + ) -> list[ToolCall]: if not tool_calls: return [] diff --git a/llama_stack/providers/remote/inference/sambanova_openai_compat/config.py b/llama_stack/providers/remote/inference/sambanova_openai_compat/config.py index b792cb6e7..072fa85d1 100644 --- a/llama_stack/providers/remote/inference/sambanova_openai_compat/config.py +++ b/llama_stack/providers/remote/inference/sambanova_openai_compat/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,7 +12,7 @@ from llama_stack.schema_utils import json_schema_type class SambaNovaProviderDataValidator(BaseModel): - sambanova_api_key: Optional[str] = Field( + sambanova_api_key: str | None = Field( default=None, description="API key for SambaNova models", ) @@ -20,7 +20,7 @@ class SambaNovaProviderDataValidator(BaseModel): @json_schema_type class SambaNovaCompatConfig(BaseModel): - api_key: Optional[str] = Field( + api_key: str | None = Field( default=None, description="The SambaNova API key", ) @@ -31,7 +31,7 @@ class SambaNovaCompatConfig(BaseModel): ) @classmethod - def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: return { "openai_compat_api_base": "https://api.sambanova.ai/v1", "api_key": api_key, diff --git a/llama_stack/providers/remote/inference/tgi/__init__.py b/llama_stack/providers/remote/inference/tgi/__init__.py index 834e51324..51614f1a6 100644 --- a/llama_stack/providers/remote/inference/tgi/__init__.py +++ b/llama_stack/providers/remote/inference/tgi/__init__.py @@ -4,13 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Union - from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig async def get_adapter_impl( - config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig], + config: InferenceAPIImplConfig | InferenceEndpointImplConfig | TGIImplConfig, _deps, ): from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter diff --git a/llama_stack/providers/remote/inference/tgi/config.py b/llama_stack/providers/remote/inference/tgi/config.py index 6ad663662..3d632c9d8 100644 --- a/llama_stack/providers/remote/inference/tgi/config.py +++ b/llama_stack/providers/remote/inference/tgi/config.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional from pydantic import BaseModel, Field, SecretStr @@ -29,7 +28,7 @@ class InferenceEndpointImplConfig(BaseModel): endpoint_name: str = Field( description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.", ) - api_token: Optional[SecretStr] = Field( + api_token: SecretStr | None = Field( default=None, description="Your Hugging Face user access token (will default to locally saved token if not provided)", ) @@ -52,7 +51,7 @@ class InferenceAPIImplConfig(BaseModel): huggingface_repo: str = Field( description="The model ID of the model on the Hugging Face Hub (e.g. 
'meta-llama/Meta-Llama-3.1-70B-Instruct')", ) - api_token: Optional[SecretStr] = Field( + api_token: SecretStr | None = Field( default=None, description="Your Hugging Face user access token (will default to locally saved token if not provided)", ) diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 4ee386a15..8f6666462 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -6,7 +6,7 @@ import logging -from typing import AsyncGenerator, List, Optional +from collections.abc import AsyncGenerator from huggingface_hub import AsyncInferenceClient, HfApi @@ -105,10 +105,10 @@ class _HfAdapter( self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -134,7 +134,7 @@ class _HfAdapter( def _build_options( self, - sampling_params: Optional[SamplingParams] = None, + sampling_params: SamplingParams | None = None, fmt: ResponseFormat = None, ): options = get_sampling_options(sampling_params) @@ -209,15 +209,15 @@ class _HfAdapter( async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, + messages: list[Message], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, ) -> AsyncGenerator: if sampling_params is None: sampling_params = SamplingParams() @@ -284,10 +284,10 @@ class _HfAdapter( async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index fa7c45c9f..5c7f60519 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
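The TGI, Fireworks, NVIDIA, and passthrough configs all keep `SecretStr` for key material; only the optional wrapper changes spelling. `SecretStr` masks the value when printed or logged, and the raw string must be requested explicitly. A hypothetical mini-model to illustrate the behavior:

```python
from pydantic import BaseModel, Field, SecretStr

class TokenHolder(BaseModel):
    # Same shape as the migrated configs: SecretStr | None with an explicit default.
    api_token: SecretStr | None = Field(default=None)

cfg = TokenHolder(api_token="hf_example_token")
print(cfg.api_token)                     # ********** (masked)
print(cfg.api_token.get_secret_value())  # hf_example_token
```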
diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py
index fa7c45c9f..5c7f60519 100644
--- a/llama_stack/providers/remote/inference/together/config.py
+++ b/llama_stack/providers/remote/inference/together/config.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field, SecretStr
@@ -17,13 +17,13 @@ class TogetherImplConfig(BaseModel):
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
     )
-    api_key: Optional[SecretStr] = Field(
+    api_key: SecretStr | None = Field(
         default=None,
         description="The Together AI API Key",
     )

     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
             "url": "https://api.together.xyz/v1",
             "api_key": "${env.TOGETHER_API_KEY:}",
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 48e41f5b0..562e6e0ff 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Any

 from openai import AsyncOpenAI
 from together import AsyncTogether
@@ -86,10 +87,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -147,8 +148,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def _build_options(
         self,
-        sampling_params: Optional[SamplingParams],
-        logprobs: Optional[LogProbConfig],
+        sampling_params: SamplingParams | None,
+        logprobs: LogProbConfig | None,
         fmt: ResponseFormat,
     ) -> dict:
         options = get_sampling_options(sampling_params)
@@ -175,15 +176,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def chat_completion(
         self,
         model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -224,7 +225,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
+    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
@@ -249,10 +250,10 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (
@@ -269,24 +270,24 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
    async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
@@ -313,29 +314,29 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIMessageParam],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
diff --git a/llama_stack/providers/remote/inference/together_openai_compat/config.py b/llama_stack/providers/remote/inference/together_openai_compat/config.py
index 120adbed9..0c6d4f748 100644
--- a/llama_stack/providers/remote/inference/together_openai_compat/config.py
+++ b/llama_stack/providers/remote/inference/together_openai_compat/config.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field
@@ -12,7 +12,7 @@
 from llama_stack.schema_utils import json_schema_type

 class TogetherProviderDataValidator(BaseModel):
-    together_api_key: Optional[str] = Field(
+    together_api_key: str | None = Field(
         default=None,
         description="API key for Together models",
     )
@@ -20,7 +20,7 @@

 @json_schema_type
 class TogetherCompatConfig(BaseModel):
-    api_key: Optional[str] = Field(
+    api_key: str | None = Field(
         default=None,
         description="The Together API key",
     )
@@ -31,7 +31,7 @@ class TogetherCompatConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]:
         return {
             "openai_compat_api_base": "https://api.together.xyz/v1",
             "api_key": api_key,
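The sample_run_config() methods above keep returning values in the "${env.VAR:default}" placeholder convention. As a rough illustration of how such a placeholder behaves at run time (llama_stack ships its own resolver; this regex-based stand-in is only an assumption about the format, not the real implementation):

import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+):?([^}]*)\}")

def resolve_env_placeholders(value: str) -> str:
    # Replace each ${env.NAME:default} with the environment value or the default.
    return _ENV_PATTERN.sub(lambda m: os.getenv(m.group(1), m.group(2)), value)

print(resolve_env_placeholders("${env.TOGETHER_API_KEY:}"))  # "" unless the variable is set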
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index 762cffde3..8530594b6 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Optional

 from pydantic import BaseModel, Field
@@ -13,7 +12,7 @@
 from llama_stack.schema_utils import json_schema_type

 @json_schema_type
 class VLLMInferenceAdapterConfig(BaseModel):
-    url: Optional[str] = Field(
+    url: str | None = Field(
         default=None,
         description="The URL for the vLLM model serving endpoint",
     )
@@ -21,7 +20,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=4096,
         description="Maximum number of tokens to generate.",
     )
-    api_token: Optional[str] = Field(
+    api_token: str | None = Field(
         default="fake",
         description="The API token",
     )
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index ac268c86c..addf2d35b 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -5,7 +5,8 @@
 # the root directory of this source tree.
 import json
 import logging
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Any

 import httpx
 from openai import AsyncOpenAI
@@ -94,7 +95,7 @@ def build_hf_repo_model_entries():

 def _convert_to_vllm_tool_calls_in_response(
     tool_calls,
-) -> List[ToolCall]:
+) -> list[ToolCall]:
     if not tool_calls:
         return []
@@ -109,7 +110,7 @@
     ]

-def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
+def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]:
     compat_tools = []

     for tool in tools:
@@ -262,10 +263,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         self._lazy_initialize_client()
         if sampling_params is None:
@@ -287,15 +288,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def chat_completion(
         self,
         model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         self._lazy_initialize_client()
         if sampling_params is None:
@@ -385,7 +386,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return model

-    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
+    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
@@ -422,10 +423,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         self._lazy_initialize_client()
         assert self.client is not None
@@ -448,29 +449,29 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         self._lazy_initialize_client()
         model_obj = await self._get_model(model)
-        extra_body: Dict[str, Any] = {}
+        extra_body: dict[str, Any] = {}
         if prompt_logprobs is not None and prompt_logprobs >= 0:
             extra_body["prompt_logprobs"] = prompt_logprobs
         if guided_choice:
@@ -501,29 +502,29 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIMessageParam],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         self._lazy_initialize_client()
         model_obj = await self._get_model(model)
         params = await prepare_openai_completion_params(
@@ -556,21 +557,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def batch_completion(
         self,
         model_id: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ):
         raise NotImplementedError("Batch completion is not supported for Ollama")

     async def batch_chat_completion(
         self,
         model_id: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_config: Optional[ToolConfig] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_config: ToolConfig | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ):
         raise NotImplementedError("Batch chat completion is not supported for Ollama")
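A note on the extra_body pattern retyped in the vLLM openai_completion hunk above: vLLM-only parameters that the OpenAI request schema does not define (guided_choice, prompt_logprobs) are collected into a plain dict[str, Any] and forwarded through the openai client's extra_body hook. A standalone sketch of that collection step, mirroring the logic in the hunk:

from typing import Any

def build_extra_body(
    guided_choice: list[str] | None = None,
    prompt_logprobs: int | None = None,
) -> dict[str, Any]:
    # Gather vLLM-specific fields that fall outside the OpenAI schema.
    extra_body: dict[str, Any] = {}
    if prompt_logprobs is not None and prompt_logprobs >= 0:
        extra_body["prompt_logprobs"] = prompt_logprobs
    if guided_choice:
        extra_body["guided_choice"] = guided_choice
    return extra_body

print(build_extra_body(guided_choice=["yes", "no"]))  # {'guided_choice': ['yes', 'no']}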
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 7ee99b7e0..5eda9c5c0 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import os
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field, SecretStr
@@ -24,11 +24,11 @@ class WatsonXConfig(BaseModel):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )
-    api_key: Optional[SecretStr] = Field(
+    api_key: SecretStr | None = Field(
         default_factory=lambda: os.getenv("WATSONX_API_KEY"),
         description="The watsonx API key, only needed of using the hosted service",
     )
-    project_id: Optional[str] = Field(
+    project_id: str | None = Field(
         default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
         description="The Project ID key, only needed of using the hosted service",
     )
@@ -38,7 +38,7 @@ class WatsonXConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
             "url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
             "api_key": "${env.WATSONX_API_KEY:}",
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index fa9cc4391..c1299e11f 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Any

 from ibm_watson_machine_learning.foundation_models import Model
 from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
@@ -78,10 +79,10 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -152,15 +153,15 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     async def chat_completion(
         self,
         model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -217,7 +218,7 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
+    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
         input_dict = {"params": {}}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
@@ -252,34 +253,34 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError("embedding is not supported for watsonx")

     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
@@ -306,29 +307,29 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIMessageParam],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
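The watsonx, Together, and vLLM adapters above all return OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk], the union that replaces the old Union[...] spelling. A toy illustration of that union-return shape, with Response and Chunk as hypothetical stand-ins for the OpenAI-compat types:

from collections.abc import AsyncIterator
from dataclasses import dataclass

@dataclass
class Response:  # stand-in for a complete chat-completion response
    text: str

@dataclass
class Chunk:  # stand-in for a streamed chunk
    delta: str

async def _stream(text: str) -> AsyncIterator[Chunk]:
    for ch in text:
        yield Chunk(delta=ch)

async def complete(text: str, stream: bool | None = None) -> Response | AsyncIterator[Chunk]:
    if stream:
        return _stream(text)  # caller consumes this with `async for`
    return Response(text=text)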
diff --git a/llama_stack/providers/remote/post_training/nvidia/config.py b/llama_stack/providers/remote/post_training/nvidia/config.py
index 7b42c8bb0..fa08b6e3f 100644
--- a/llama_stack/providers/remote/post_training/nvidia/config.py
+++ b/llama_stack/providers/remote/post_training/nvidia/config.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import os
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field
@@ -15,23 +15,23 @@
 class NvidiaPostTrainingConfig(BaseModel):
     """Configuration for NVIDIA Post Training implementation."""

-    api_key: Optional[str] = Field(
+    api_key: str | None = Field(
         default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
         description="The NVIDIA API key.",
     )

-    dataset_namespace: Optional[str] = Field(
+    dataset_namespace: str | None = Field(
         default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
         description="The NVIDIA dataset namespace.",
     )

-    project_id: Optional[str] = Field(
+    project_id: str | None = Field(
         default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
         description="The NVIDIA project ID.",
     )

     # ToDO: validate this, add default value
-    customizer_url: Optional[str] = Field(
+    customizer_url: str | None = Field(
         default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
         description="Base URL for the NeMo Customizer API",
     )
@@ -53,7 +53,7 @@ class NvidiaPostTrainingConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
             "api_key": "${env.NVIDIA_API_KEY:}",
             "dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
@@ -71,27 +71,27 @@ class SFTLoRADefaultConfig(BaseModel):
     n_epochs: int = 50

     # NeMo customizer specific parameters
-    log_every_n_steps: Optional[int] = None
+    log_every_n_steps: int | None = None
     val_check_interval: float = 0.25
     sequence_packing_enabled: bool = False
     weight_decay: float = 0.01
     lr: float = 0.0001

     # SFT specific parameters
-    hidden_dropout: Optional[float] = None
-    attention_dropout: Optional[float] = None
-    ffn_dropout: Optional[float] = None
+    hidden_dropout: float | None = None
+    attention_dropout: float | None = None
+    ffn_dropout: float | None = None

     # LoRA default parameters
     lora_adapter_dim: int = 8
-    lora_adapter_dropout: Optional[float] = None
+    lora_adapter_dropout: float | None = None
     lora_alpha: int = 16

     # Data config
     batch_size: int = 8

     @classmethod
-    def sample_config(cls) -> Dict[str, Any]:
+    def sample_config(cls) -> dict[str, Any]:
         """Return a sample configuration for NVIDIA training."""
         return {
             "n_epochs": 50,
diff --git a/llama_stack/providers/remote/post_training/nvidia/models.py b/llama_stack/providers/remote/post_training/nvidia/models.py
index 1b31b4dbe..6a28f8af8 100644
--- a/llama_stack/providers/remote/post_training/nvidia/models.py
+++ b/llama_stack/providers/remote/post_training/nvidia/models.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List

 from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
@@ -24,5 +23,5 @@ _MODEL_ENTRIES = [
 ]

-def get_model_entries() -> List[ProviderModelEntry]:
+def get_model_entries() -> list[ProviderModelEntry]:
     return _MODEL_ENTRIES
diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py
index c74fb2a24..409818cb3 100644
--- a/llama_stack/providers/remote/post_training/nvidia/post_training.py
+++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import warnings
 from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Literal

 import aiohttp
 from pydantic import BaseModel, ConfigDict
@@ -50,7 +50,7 @@ class NvidiaPostTrainingJob(PostTrainingJob):

 class ListNvidiaPostTrainingJobs(BaseModel):
-    data: List[NvidiaPostTrainingJob]
+    data: list[NvidiaPostTrainingJob]

 class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
@@ -83,11 +83,11 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         self,
         method: str,
         path: str,
-        headers: Optional[Dict[str, Any]] = None,
-        params: Optional[Dict[str, Any]] = None,
-        json: Optional[Dict[str, Any]] = None,
+        headers: dict[str, Any] | None = None,
+        params: dict[str, Any] | None = None,
+        json: dict[str, Any] | None = None,
         **kwargs,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Helper method to make HTTP requests to the Customizer API."""
         url = f"{self.customizer_url}{path}"
         request_headers = self.headers.copy()
@@ -109,9 +109,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
     async def get_training_jobs(
         self,
-        page: Optional[int] = 1,
-        page_size: Optional[int] = 10,
-        sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
+        page: int | None = 1,
+        page_size: int | None = 10,
+        sort: Literal["created_at", "-created_at"] | None = "created_at",
     ) -> ListNvidiaPostTrainingJobs:
         """Get all customization jobs.
         Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
@@ -207,12 +207,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        training_config: Dict[str, Any],
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_config: dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
         model: str,
-        checkpoint_dir: Optional[str],
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        checkpoint_dir: str | None,
+        algorithm_config: AlgorithmConfig | None = None,
     ) -> NvidiaPostTrainingJob:
         """
         Fine-tunes a model on a dataset.
@@ -423,8 +423,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
     ) -> PostTrainingJob:
         """Optimize a model based on preference data."""
         raise NotImplementedError("Preference optimization is not implemented yet")
diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py
index ac47966af..d6e1016b2 100644
--- a/llama_stack/providers/remote/post_training/nvidia/utils.py
+++ b/llama_stack/providers/remote/post_training/nvidia/utils.py
@@ -6,7 +6,7 @@

 import logging
 import warnings
-from typing import Any, Dict, Set, Tuple
+from typing import Any

 from pydantic import BaseModel
@@ -18,7 +18,7 @@
 from .config import NvidiaPostTrainingConfig

 logger = logging.getLogger(__name__)

-def warn_unsupported_params(config_dict: Any, supported_keys: Set[str], config_name: str) -> None:
+def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
     keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
     unsupported_params = [k for k in keys if k not in supported_keys]
     if unsupported_params:
@@ -28,7 +28,7 @@

 def validate_training_params(
-    training_config: Dict[str, Any], supported_keys: Set[str], config_name: str = "TrainingConfig"
+    training_config: dict[str, Any], supported_keys: set[str], config_name: str = "TrainingConfig"
 ) -> None:
     """
     Validates training parameters against supported keys.
@@ -57,7 +57,7 @@

 # ToDo: implement post health checks for customizer are enabled
-async def _get_health(url: str) -> Tuple[bool, bool]: ...
+async def _get_health(url: str) -> tuple[bool, bool]: ...

 async def check_health(config: NvidiaPostTrainingConfig) -> None: ...
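For context on the utils.py hunk: warn_unsupported_params() compares a config's keys against a set[str] of supported names and warns about the rest. An illustrative call, where only the signature comes from the diff and the keys themselves are invented for the demo:

# Hypothetical inputs for warn_unsupported_params(config_dict, supported_keys, config_name).
supported_keys: set[str] = {"n_epochs", "batch_size", "lr"}
training_config: dict[str, object] = {"n_epochs": 50, "batch_size": 8, "fsdp": True}

# The helper would warn about keys outside the supported set, i.e. the same
# comprehension the diff shows:
unsupported = [k for k in training_config if k not in supported_keys]
print(unsupported)  # ['fsdp']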
diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py
index 2f960eead..c43b51073 100644
--- a/llama_stack/providers/remote/safety/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py
@@ -6,7 +6,7 @@

 import json
 import logging
-from typing import Any, Dict, List
+from typing import Any

 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import (
@@ -53,7 +53,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
         )

     async def run_shield(
-        self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
+        self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
diff --git a/llama_stack/providers/remote/safety/nvidia/config.py b/llama_stack/providers/remote/safety/nvidia/config.py
index 3df80ed4f..4ca703a4d 100644
--- a/llama_stack/providers/remote/safety/nvidia/config.py
+++ b/llama_stack/providers/remote/safety/nvidia/config.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import os
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field
@@ -27,10 +27,10 @@ class NVIDIASafetyConfig(BaseModel):
         default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
         description="The url for accessing the guardrails service",
     )
-    config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
+    config_id: str | None = Field(default="self-check", description="Config ID to use from the config store")

     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
             "guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
             "config_id": "self-check",
diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py
index 13bc212a1..411badb1c 100644
--- a/llama_stack/providers/remote/safety/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import logging
-from typing import Any, List, Optional
+from typing import Any

 import requests
@@ -41,7 +41,7 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         raise ValueError("Shield model not provided.")

     async def run_shield(
-        self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
+        self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
     ) -> RunShieldResponse:
         """
         Run a safety shield check against the provided messages.
@@ -112,7 +112,7 @@ class NeMoGuardrails:
         response.raise_for_status()
         return response.json()

-    async def run(self, messages: List[Message]) -> RunShieldResponse:
+    async def run(self, messages: list[Message]) -> RunShieldResponse:
         """
         Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
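The bedrock run_shield hunk above also corrects an annotation/default mismatch that was present in both the old and the originally proposed new line: a parameter that defaults to None must be typed dict[str, Any] | None, or strict type checkers reject the signature and every call site that omits it. A minimal sketch of the corrected shape (run_shield here is a standalone toy, not the adapter method):

from typing import Any

def run_shield(shield_id: str, params: dict[str, Any] | None = None) -> None:
    # Normalize the None default to an empty mapping before use.
    params = params or {}
    print(shield_id, params)

run_shield("llama-guard")                      # omitting params is now well-typed
run_shield("llama-guard", {"threshold": 0.5})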
diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
index b34c9fd9d..18bec463f 100644
--- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
+++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import json
-from typing import Any, Dict, Optional
+from typing import Any

 import httpx
@@ -50,7 +50,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
         return provider_data.bing_search_api_key

     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         return ListToolDefsResponse(
             data=[
@@ -68,7 +68,7 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
             ]
         )

-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         api_key = self._get_api_key()
         headers = {
             "Ocp-Apim-Subscription-Key": api_key,
diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/config.py b/llama_stack/providers/remote/tool_runtime/bing_search/config.py
index 4f089439f..30269dbc1 100644
--- a/llama_stack/providers/remote/tool_runtime/bing_search/config.py
+++ b/llama_stack/providers/remote/tool_runtime/bing_search/config.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel
@@ -12,11 +12,11 @@
 from pydantic import BaseModel

 class BingSearchToolConfig(BaseModel):
     """Configuration for Bing Search Tool Runtime"""

-    api_key: Optional[str] = None
+    api_key: str | None = None
     top_k: int = 3

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "api_key": "${env.BING_API_KEY:}",
         }
diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
index 41f3ce823..355cb98b6 100644
--- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 import httpx
@@ -49,7 +49,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
         return provider_data.brave_search_api_key

     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         return ListToolDefsResponse(
             data=[
@@ -68,7 +68,7 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
             ]
         )

-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         api_key = self._get_api_key()
         url = "https://api.search.brave.com/res/v1/web/search"
         headers = {
diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/config.py b/llama_stack/providers/remote/tool_runtime/brave_search/config.py
index ab6053609..37ba21304 100644
--- a/llama_stack/providers/remote/tool_runtime/brave_search/config.py
+++ b/llama_stack/providers/remote/tool_runtime/brave_search/config.py
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field

 class BraveSearchToolConfig(BaseModel):
-    api_key: Optional[str] = Field(
+    api_key: str | None = Field(
         default=None,
         description="The Brave Search API Key",
     )
@@ -20,7 +20,7 @@ class BraveSearchToolConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
             "api_key": "${env.BRAVE_SEARCH_API_KEY:}",
             "max_results": 3,
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py
index 30ac407bc..d509074fc 100644
--- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel

 class ModelContextProtocolConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
index 676917225..142730e89 100644
--- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any
 from urllib.parse import urlparse

 from mcp import ClientSession
@@ -31,7 +31,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         pass

     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")
@@ -63,7 +63,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         )
         return ListToolDefsResponse(data=tools)

-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
         if tool.metadata is None or tool.metadata.get("endpoint") is None:
             raise ValueError(f"Tool {tool_name} does not have metadata")
diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/config.py b/llama_stack/providers/remote/tool_runtime/tavily_search/config.py
index 945430bb1..c9b18d30d 100644
--- a/llama_stack/providers/remote/tool_runtime/tavily_search/config.py
+++ b/llama_stack/providers/remote/tool_runtime/tavily_search/config.py
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field

 class TavilySearchToolConfig(BaseModel):
-    api_key: Optional[str] = Field(
+    api_key: str | None = Field(
         default=None,
         description="The Tavily Search API Key",
     )
@@ -20,7 +20,7 @@ class TavilySearchToolConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
             "api_key": "${env.TAVILY_SEARCH_API_KEY:}",
             "max_results": 3,
diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
index 719d6be14..9d6fcd951 100644
--- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
+++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import json
-from typing import Any, Dict, Optional
+from typing import Any

 import httpx
@@ -49,7 +49,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
         return provider_data.tavily_search_api_key

     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         return ListToolDefsResponse(
             data=[
@@ -67,7 +67,7 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
             ]
         )

-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         api_key = self._get_api_key()
         async with httpx.AsyncClient() as client:
             response = await client.post(
diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py
index 8ea49c7b5..aefc86bd6 100644
--- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py
+++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel
@@ -12,10 +12,10 @@
 from pydantic import BaseModel

 class WolframAlphaToolConfig(BaseModel):
     """Configuration for WolframAlpha Tool Runtime"""

-    api_key: Optional[str] = None
+    api_key: str | None = None

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "api_key": "${env.WOLFRAM_ALPHA_API_KEY:}",
         }
diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
index b3e0e120c..a3724e4b4 100644
--- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import json
-from typing import Any, Dict, Optional
+from typing import Any

 import httpx
@@ -50,7 +50,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
         return provider_data.wolfram_alpha_api_key

     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         return ListToolDefsResponse(
             data=[
@@ -68,7 +68,7 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
             ]
         )

-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         api_key = self._get_api_key()
         params = {
             "input": kwargs["query"],
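All five tool runtimes above share the same two-method convention: list_runtime_tools() advertises tool definitions, then invoke_tool() is called by name with a plain dict[str, Any] of arguments. A hedged, standalone sketch of that calling shape (ToolInvocationResult here is a stand-in for the llama_stack type, and the tool name is invented):

from dataclasses import dataclass
from typing import Any

@dataclass
class ToolInvocationResult:  # stand-in for llama_stack's ToolInvocationResult
    content: str
    error_code: int | None = None

async def invoke_tool(tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
    # Dispatch on the tool name; unknown tools report an error code.
    if tool_name != "web_search":
        return ToolInvocationResult(content="", error_code=1)
    return ToolInvocationResult(content=f"results for {kwargs['query']}")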
diff --git a/llama_stack/providers/remote/vector_io/chroma/__init__.py b/llama_stack/providers/remote/vector_io/chroma/__init__.py
index 8646b04d6..ebbc62b1c 100644
--- a/llama_stack/providers/remote/vector_io/chroma/__init__.py
+++ b/llama_stack/providers/remote/vector_io/chroma/__init__.py
@@ -4,14 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict
-
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import ChromaVectorIOConfig

-async def get_adapter_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .chroma import ChromaVectorIOAdapter

     impl = ChromaVectorIOAdapter(config, deps[Api.inference])
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 3bf3a7740..5381a48ef 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -6,7 +6,7 @@
 import asyncio
 import json
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any
 from urllib.parse import urlparse

 import chromadb
@@ -27,7 +27,7 @@
 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig

 log = logging.getLogger(__name__)

-ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient]
+ChromaClientType = chromadb.AsyncHttpClient | chromadb.PersistentClient

 # this is a helper to allow us to use async and non-async chroma clients interchangeably
@@ -42,7 +42,7 @@ class ChromaIndex(EmbeddingIndex):
         self.client = client
         self.collection = collection

-    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
@@ -89,7 +89,7 @@ class ChromaIndex(EmbeddingIndex):

 class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(
         self,
-        config: Union[RemoteChromaVectorIOConfig, InlineChromaVectorIOConfig],
+        config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
         inference_api: Api.inference,
     ) -> None:
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
@@ -137,8 +137,8 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     async def insert_chunks(
         self,
         vector_db_id: str,
-        chunks: List[Chunk],
-        ttl_seconds: Optional[int] = None,
+        chunks: list[Chunk],
+        ttl_seconds: int | None = None,
     ) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
@@ -148,7 +148,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self,
         vector_db_id: str,
         query: InterleavedContent,
-        params: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
diff --git a/llama_stack/providers/remote/vector_io/chroma/config.py b/llama_stack/providers/remote/vector_io/chroma/config.py
index 3e2463252..4e893fab4 100644
--- a/llama_stack/providers/remote/vector_io/chroma/config.py
+++ b/llama_stack/providers/remote/vector_io/chroma/config.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel
@@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
     url: str

     @classmethod
-    def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, url: str = "${env.CHROMADB_URL}", **kwargs: Any) -> dict[str, Any]:
         return {"url": url}
diff --git a/llama_stack/providers/remote/vector_io/milvus/__init__.py b/llama_stack/providers/remote/vector_io/milvus/__init__.py
index 84cb1d748..92dbfda2e 100644
--- a/llama_stack/providers/remote/vector_io/milvus/__init__.py
+++ b/llama_stack/providers/remote/vector_io/milvus/__init__.py
@@ -4,14 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict
-
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import MilvusVectorIOConfig

-async def get_adapter_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .milvus import MilvusVectorIOAdapter

     assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py
index 17da6b23d..3d25e9c49 100644
--- a/llama_stack/providers/remote/vector_io/milvus/config.py
+++ b/llama_stack/providers/remote/vector_io/milvus/config.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel
@@ -14,9 +14,9 @@
 from llama_stack.schema_utils import json_schema_type

 @json_schema_type
 class MilvusVectorIOConfig(BaseModel):
     uri: str
-    token: Optional[str] = None
+    token: str | None = None
     consistency_level: str = "Strong"

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 1949d293d..c98417b56 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -9,7 +9,7 @@
 import hashlib
 import logging
 import os
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

 from numpy.typing import NDArray
 from pymilvus import MilvusClient
@@ -39,7 +39,7 @@ class MilvusIndex(EmbeddingIndex):
         if await asyncio.to_thread(self.client.has_collection, self.collection_name):
             await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)

-    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
@@ -89,7 +89,7 @@ class MilvusIndex(EmbeddingIndex):

 class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(
-        self, config: Union[RemoteMilvusVectorIOConfig, InlineMilvusVectorIOConfig], inference_api: Api.inference
+        self, config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig, inference_api: Api.inference
     ) -> None:
         self.config = config
         self.cache = {}
@@ -124,7 +124,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

         self.cache[vector_db.identifier] = index

-    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]
@@ -148,8 +148,8 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     async def insert_chunks(
         self,
         vector_db_id: str,
-        chunks: List[Chunk],
-        ttl_seconds: Optional[int] = None,
+        chunks: list[Chunk],
+        ttl_seconds: int | None = None,
     ) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
@@ -161,7 +161,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self,
         vector_db_id: str,
         query: InterleavedContent,
-        params: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
@@ -172,7 +172,7 @@

 def generate_chunk_id(document_id: str, chunk_text: str) -> str:
     """Generate a unique chunk ID using a hash of document ID and chunk text."""
-    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
+    hash_input = f"{document_id}:{chunk_text}".encode()
     return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
diff --git a/llama_stack/providers/remote/vector_io/pgvector/__init__.py b/llama_stack/providers/remote/vector_io/pgvector/__init__.py
index 089d890b7..9f528db74 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/__init__.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/__init__.py
@@ -4,14 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict
-
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import PGVectorVectorIOConfig

-async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .pgvector import PGVectorVectorIOAdapter

     impl = PGVectorVectorIOAdapter(config, deps[Api.inference])
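One behavioral note on the milvus.py generate_chunk_id hunk above: str.encode() with no argument already defaults to UTF-8, so dropping the explicit "utf-8" changes nothing at runtime. A self-contained copy of that chunk-id scheme for illustration:

import hashlib
import uuid

def generate_chunk_id(document_id: str, chunk_text: str) -> str:
    hash_input = f"{document_id}:{chunk_text}".encode()  # encode() defaults to UTF-8
    # An MD5 hex digest is 32 hex characters, exactly what uuid.UUID accepts.
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))

assert "doc:x".encode() == "doc:x".encode("utf-8")
print(generate_chunk_id("doc-1", "hello world"))  # stable, UUID-shaped id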
-from typing import Any, Dict +from typing import Any from pydantic import BaseModel, Field @@ -28,5 +28,5 @@ class PGVectorVectorIOConfig(BaseModel): user: str = "${env.PGVECTOR_USER}", password: str = "${env.PGVECTOR_PASSWORD}", **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return {"host": host, "port": port, "db": db, "user": user, "password": password} diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 7c683e126..94546c6cf 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import psycopg2 from numpy.typing import NDArray @@ -33,7 +33,7 @@ def check_extension_version(cur): return result[0] if result else None -def upsert_models(conn, keys_models: List[Tuple[str, BaseModel]]): +def upsert_models(conn, keys_models: list[tuple[str, BaseModel]]): with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: query = sql.SQL( """ @@ -74,7 +74,7 @@ class PGVectorIndex(EmbeddingIndex): """ ) - async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) @@ -180,8 +180,8 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def insert_chunks( self, vector_db_id: str, - chunks: List[Chunk], - ttl_seconds: Optional[int] = None, + chunks: list[Chunk], + ttl_seconds: int | None = None, ) -> None: index = await self._get_and_cache_vector_db_index(vector_db_id) await index.insert_chunks(chunks) @@ -190,7 +190,7 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self, vector_db_id: str, query: InterleavedContent, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> QueryChunksResponse: index = await self._get_and_cache_vector_db_index(vector_db_id) return await index.query_chunks(query, params) diff --git a/llama_stack/providers/remote/vector_io/qdrant/__init__.py b/llama_stack/providers/remote/vector_io/qdrant/__init__.py index f5bb7f84c..029de285f 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/__init__.py +++ b/llama_stack/providers/remote/vector_io/qdrant/__init__.py @@ -4,14 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict - from llama_stack.providers.datatypes import Api, ProviderSpec from .config import QdrantVectorIOConfig -async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]): +async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): from .qdrant import QdrantVectorIOAdapter impl = QdrantVectorIOAdapter(config, deps[Api.inference]) diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index 6d7eebe23..314d3f5f1 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
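A note on the syntax this PR standardizes on: built-in generics such as `dict[str, Any]` are valid in annotations from Python 3.9 (PEP 585), while `X | None` unions are evaluated at runtime and require Python 3.10 (PEP 604), so this refactor implies a Python >= 3.10 floor. A quick before/after sketch (toy names, not the real configs):

```python
from typing import Any

# Before (typing generics, extra imports needed):
#   from typing import Dict, Optional
#   token: Optional[str] = None
#   def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]: ...

# After (built-in generics and union syntax, as used throughout this diff):
token: str | None = None  # evaluated at runtime, so requires Python >= 3.10


def sample_run_config(**kwargs: Any) -> dict[str, Any]:
    return {"token": token}
```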
-from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel @@ -13,19 +13,19 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class QdrantVectorIOConfig(BaseModel): - location: Optional[str] = None - url: Optional[str] = None - port: Optional[int] = 6333 + location: str | None = None + url: str | None = None + port: int | None = 6333 grpc_port: int = 6334 prefer_grpc: bool = False - https: Optional[bool] = None - api_key: Optional[str] = None - prefix: Optional[str] = None - timeout: Optional[int] = None - host: Optional[str] = None + https: bool | None = None + api_key: str | None = None + prefix: str | None = None + timeout: int | None = None + host: str | None = None @classmethod - def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: return { "api_key": "${env.QDRANT_API_KEY}", } diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 9e7788dc0..514a6c70d 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -6,7 +6,7 @@ import logging import uuid -from typing import Any, Dict, List, Optional, Union +from typing import Any from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models @@ -44,7 +44,7 @@ class QdrantIndex(EmbeddingIndex): self.client = client self.collection_name = collection_name - async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) @@ -101,7 +101,7 @@ class QdrantIndex(EmbeddingIndex): class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( - self, config: Union[RemoteQdrantVectorIOConfig, InlineQdrantVectorIOConfig], inference_api: Api.inference + self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference ) -> None: self.config = config self.client: AsyncQdrantClient = None @@ -131,7 +131,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): await self.cache[vector_db_id].index.delete() del self.cache[vector_db_id] - async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]: + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: if vector_db_id in self.cache: return self.cache[vector_db_id] @@ -150,8 +150,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def insert_chunks( self, vector_db_id: str, - chunks: List[Chunk], - ttl_seconds: Optional[int] = None, + chunks: list[Chunk], + ttl_seconds: int | None = None, ) -> None: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: @@ -163,7 +163,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self, vector_db_id: str, query: InterleavedContent, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> QueryChunksResponse: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: diff --git a/llama_stack/providers/remote/vector_io/weaviate/__init__.py b/llama_stack/providers/remote/vector_io/weaviate/__init__.py index c93c628d8..22e116c22 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/__init__.py +++ 
b/llama_stack/providers/remote/vector_io/weaviate/__init__.py @@ -4,14 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict - from llama_stack.providers.datatypes import Api, ProviderSpec -from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig # noqa: F401 +from .config import WeaviateVectorIOConfig -async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: Dict[Api, ProviderSpec]): +async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]): from .weaviate import WeaviateVectorIOAdapter impl = WeaviateVectorIOAdapter(config, deps[Api.inference]) diff --git a/llama_stack/providers/remote/vector_io/weaviate/config.py b/llama_stack/providers/remote/vector_io/weaviate/config.py index cc587f252..a8c6e3e2c 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/config.py +++ b/llama_stack/providers/remote/vector_io/weaviate/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict +from typing import Any from pydantic import BaseModel @@ -16,5 +16,5 @@ class WeaviateRequestProviderData(BaseModel): class WeaviateVectorIOConfig(BaseModel): @classmethod - def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]: + def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 52aa2f3a3..308d2eb3d 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
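One behavioral nuance in the weaviate `__init__.py` change above: `WeaviateRequestProviderData` was previously re-exported from the package (the `# noqa: F401` marked it as an intentional unused import). After this change, callers must import it from the config module directly. A sketch, assuming the package is importable:

```python
# Before this diff, both import paths worked:
# from llama_stack.providers.remote.vector_io.weaviate import WeaviateRequestProviderData

# After this diff, only the explicit module path resolves:
from llama_stack.providers.remote.vector_io.weaviate.config import (
    WeaviateRequestProviderData,
)
```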
import json import logging -from typing import Any, Dict, List, Optional +from typing import Any import weaviate import weaviate.classes as wvc @@ -33,7 +33,7 @@ class WeaviateIndex(EmbeddingIndex): self.client = client self.collection_name = collection_name - async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) @@ -80,7 +80,7 @@ class WeaviateIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) - async def delete(self, chunk_ids: List[str]) -> None: + async def delete(self, chunk_ids: list[str]) -> None: collection = self.client.collections.get(self.collection_name) collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids)) @@ -144,7 +144,7 @@ class WeaviateVectorIOAdapter( self.inference_api, ) - async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]: + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: if vector_db_id in self.cache: return self.cache[vector_db_id] @@ -167,8 +167,8 @@ class WeaviateVectorIOAdapter( async def insert_chunks( self, vector_db_id: str, - chunks: List[Chunk], - ttl_seconds: Optional[int] = None, + chunks: list[Chunk], + ttl_seconds: int | None = None, ) -> None: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: @@ -180,7 +180,7 @@ class WeaviateVectorIOAdapter( self, vector_db_id: str, query: InterleavedContent, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> QueryChunksResponse: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index d3e715b7e..cd86af0d6 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -7,7 +7,7 @@ import os from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import pytest import yaml @@ -23,36 +23,36 @@ from .report import Report class ProviderFixture(BaseModel): - providers: List[Provider] - provider_data: Optional[Dict[str, Any]] = None + providers: list[Provider] + provider_data: dict[str, Any] | None = None class TestScenario(BaseModel): # provider fixtures can be either a mark or a dictionary of api -> providers - provider_fixtures: Dict[str, str] = Field(default_factory=dict) - fixture_combo_id: Optional[str] = None + provider_fixtures: dict[str, str] = Field(default_factory=dict) + fixture_combo_id: str | None = None class APITestConfig(BaseModel): - scenarios: List[TestScenario] = Field(default_factory=list) - inference_models: List[str] = Field(default_factory=list) + scenarios: list[TestScenario] = Field(default_factory=list) + inference_models: list[str] = Field(default_factory=list) # test name format should be :: - tests: List[str] = Field(default_factory=list) + tests: list[str] = Field(default_factory=list) class MemoryApiTestConfig(APITestConfig): - embedding_model: Optional[str] = Field(default_factory=None) + embedding_model: str | None = Field(default_factory=None) class AgentsApiTestConfig(APITestConfig): - safety_shield: Optional[str] = Field(default_factory=None) + safety_shield: str | None = Field(default_factory=None) class TestConfig(BaseModel): - 
inference: Optional[APITestConfig] = None - agents: Optional[AgentsApiTestConfig] = None - memory: Optional[MemoryApiTestConfig] = None + inference: APITestConfig | None = None + agents: AgentsApiTestConfig | None = None + memory: MemoryApiTestConfig | None = None def get_test_config_from_config_file(metafunc_config): @@ -65,7 +65,7 @@ def get_test_config_from_config_file(metafunc_config): raise ValueError( f"Test config {config_file} was specified but not found. Please make sure it exists in the llama_stack/providers/tests directory." ) - with open(config_file_path, "r") as config_file: + with open(config_file_path) as config_file: config = yaml.safe_load(config_file) return TestConfig(**config) @@ -188,18 +188,18 @@ def pytest_addoption(parser): ) -def make_provider_id(providers: Dict[str, str]) -> str: +def make_provider_id(providers: dict[str, str]) -> str: return ":".join(f"{api}={provider}" for api, provider in sorted(providers.items())) -def get_provider_marks(providers: Dict[str, str]) -> List[Any]: +def get_provider_marks(providers: dict[str, str]) -> list[Any]: marks = [] for provider in providers.values(): marks.append(getattr(pytest.mark, provider)) return marks -def get_provider_fixture_overrides(config, available_fixtures: Dict[str, List[str]]) -> Optional[List[pytest.param]]: +def get_provider_fixture_overrides(config, available_fixtures: dict[str, list[str]]) -> list[pytest.param] | None: provider_str = config.getoption("--providers") if not provider_str: return None @@ -214,7 +214,7 @@ def get_provider_fixture_overrides(config, available_fixtures: Dict[str, List[st ] -def parse_fixture_string(provider_str: str, available_fixtures: Dict[str, List[str]]) -> Dict[str, str]: +def parse_fixture_string(provider_str: str, available_fixtures: dict[str, list[str]]) -> dict[str, str]: """Parse provider string of format 'api1=provider1,api2=provider2'""" if not provider_str: return {} diff --git a/llama_stack/providers/utils/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py index 95019666b..b25617d76 100644 --- a/llama_stack/providers/utils/bedrock/config.py +++ b/llama_stack/providers/utils/bedrock/config.py @@ -3,54 +3,53 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional from pydantic import BaseModel, Field class BedrockBaseConfig(BaseModel): - aws_access_key_id: Optional[str] = Field( + aws_access_key_id: str | None = Field( default=None, description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", ) - aws_secret_access_key: Optional[str] = Field( + aws_secret_access_key: str | None = Field( default=None, description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", ) - aws_session_token: Optional[str] = Field( + aws_session_token: str | None = Field( default=None, description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", ) - region_name: Optional[str] = Field( + region_name: str | None = Field( default=None, description="The default AWS Region to use, for example, us-west-1 or us-west-2." 
"Default use environment variable: AWS_DEFAULT_REGION", ) - profile_name: Optional[str] = Field( + profile_name: str | None = Field( default=None, description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE", ) - total_max_attempts: Optional[int] = Field( + total_max_attempts: int | None = Field( default=None, description="An integer representing the maximum number of attempts that will be made for a single request, " "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", ) - retry_mode: Optional[str] = Field( + retry_mode: str | None = Field( default=None, description="A string representing the type of retries Boto3 will perform." "Default use environment variable: AWS_RETRY_MODE", ) - connect_timeout: Optional[float] = Field( + connect_timeout: float | None = Field( default=60, description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " "The default is 60 seconds.", ) - read_timeout: Optional[float] = Field( + read_timeout: float | None = Field( default=60, description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." "The default is 60 seconds.", ) - session_ttl: Optional[int] = Field( + session_ttl: int | None = Field( default=3600, description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).", ) diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py index eb9d9dd60..28a243863 100644 --- a/llama_stack/providers/utils/common/data_schema_validator.py +++ b/llama_stack/providers/utils/common/data_schema_validator.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from enum import Enum -from typing import Any, Dict, List +from typing import Any from llama_stack.apis.common.type_system import ( ChatCompletionInputType, @@ -85,16 +85,16 @@ def get_valid_schemas(api_str: str): def validate_dataset_schema( - dataset_schema: Dict[str, Any], - expected_schemas: List[Dict[str, Any]], + dataset_schema: dict[str, Any], + expected_schemas: list[dict[str, Any]], ): if dataset_schema not in expected_schemas: raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}") def validate_row_schema( - input_row: Dict[str, Any], - expected_schemas: List[Dict[str, Any]], + input_row: dict[str, Any], + expected_schemas: list[dict[str, Any]], ): for schema in expected_schemas: if all(key in input_row for key in schema): diff --git a/llama_stack/providers/utils/datasetio/pagination.py b/llama_stack/providers/utils/datasetio/pagination.py index 1b693f8f5..033022491 100644 --- a/llama_stack/providers/utils/datasetio/pagination.py +++ b/llama_stack/providers/utils/datasetio/pagination.py @@ -4,13 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict, List +from typing import Any from llama_stack.apis.common.responses import PaginatedResponse def paginate_records( - records: List[Dict[str, Any]], + records: list[dict[str, Any]], start_index: int | None = None, limit: int | None = None, ) -> PaginatedResponse: diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py index e36be9404..66269d173 100644 --- a/llama_stack/providers/utils/inference/__init__.py +++ b/llama_stack/providers/utils/inference/__init__.py @@ -4,8 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List - from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_types import * # noqa: F403 @@ -22,7 +20,7 @@ def is_supported_safety_model(model: Model) -> bool: ] -def supported_inference_models() -> List[Model]: +def supported_inference_models() -> list[Model]: return [ m for m in all_registered_models() diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 8b14c7502..7c8144c62 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from sentence_transformers import SentenceTransformer @@ -31,10 +31,10 @@ class SentenceTransformerEmbeddingMixin: async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 098891e37..c3c2ab61f 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
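The `paginate_records` signature above pairs an optional `start_index` with an optional `limit`. A plausible slice-based implementation is sketched below under assumed semantics: the real function returns a `PaginatedResponse`, whose fields are not shown in this diff, so a plain dict stands in for it and the `has_more` key is a guess.

```python
from typing import Any


def paginate_records_sketch(
    records: list[dict[str, Any]],
    start_index: int | None = None,
    limit: int | None = None,
) -> dict[str, Any]:
    start = start_index or 0
    end = None if limit is None else start + limit
    page = records[start:end]
    # "has_more" approximates a response shape; it is not the actual PaginatedResponse.
    return {"data": page, "has_more": end is not None and end < len(records)}


rows = [{"id": i} for i in range(5)]
assert paginate_records_sketch(rows, start_index=1, limit=2)["data"] == [{"id": 1}, {"id": 2}]
```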
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Any import litellm @@ -64,7 +65,7 @@ class LiteLLMOpenAIMixin( def __init__( self, model_entries, - api_key_from_config: Optional[str], + api_key_from_config: str | None, provider_data_api_key_field: str, openai_compat_api_base: str | None = None, ): @@ -97,26 +98,26 @@ class LiteLLMOpenAIMixin( self, model_id: str, content: InterleavedContent, - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, ) -> AsyncGenerator: raise NotImplementedError("LiteLLM does not support completion requests") async def chat_completion( self, model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + messages: list[Message], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_choice: ToolChoice | None = ToolChoice.auto, + tool_prompt_format: ToolPromptFormat | None = None, + response_format: ResponseFormat | None = None, + stream: bool | None = False, + logprobs: LogProbConfig | None = None, + tool_config: ToolConfig | None = None, + ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: if sampling_params is None: sampling_params = SamplingParams() @@ -243,10 +244,10 @@ class LiteLLMOpenAIMixin( async def embeddings( self, model_id: str, - contents: List[str] | List[InterleavedContentItem], - text_truncation: Optional[TextTruncation] = TextTruncation.none, - output_dimension: Optional[int] = None, - task_type: Optional[EmbeddingTaskType] = None, + contents: list[str] | list[InterleavedContentItem], + text_truncation: TextTruncation | None = TextTruncation.none, + output_dimension: int | None = None, + task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) @@ -261,24 +262,24 @@ class LiteLLMOpenAIMixin( async def openai_completion( self, model: str, - prompt: Union[str, List[str], List[int], List[List[int]]], - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - guided_choice: Optional[List[str]] = None, - prompt_logprobs: Optional[int] = None, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + 
frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, ) -> OpenAICompletion: model_obj = await self.model_store.get_model(model) params = await prepare_openai_completion_params( @@ -309,29 +310,29 @@ class LiteLLMOpenAIMixin( async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[OpenAIResponseFormatParam] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: model_obj = await self.model_store.get_model(model) params = await prepare_openai_completion_params( model=self.get_litellm_model_name(model_obj.provider_resource_id), @@ -365,21 +366,21 @@ class LiteLLMOpenAIMixin( async def batch_completion( self, model_id: str, - content_batch: List[InterleavedContent], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - logprobs: Optional[LogProbConfig] = None, + content_batch: list[InterleavedContent], + sampling_params: SamplingParams | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ): raise NotImplementedError("Batch completion is not supported for OpenAI Compat") async def batch_chat_completion( self, model_id: str, - messages_batch: List[List[Message]], - 
sampling_params: Optional[SamplingParams] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_config: Optional[ToolConfig] = None, - response_format: Optional[ResponseFormat] = None, - logprobs: Optional[LogProbConfig] = None, + messages_batch: list[list[Message]], + sampling_params: SamplingParams | None = None, + tools: list[ToolDefinition] | None = None, + tool_config: ToolConfig | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ): raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat") diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index c5199b0a8..d707e36c2 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel, Field @@ -20,13 +20,13 @@ from llama_stack.providers.utils.inference import ( # more closer to the Model class. class ProviderModelEntry(BaseModel): provider_model_id: str - aliases: List[str] = Field(default_factory=list) - llama_model: Optional[str] = None + aliases: list[str] = Field(default_factory=list) + llama_model: str | None = None model_type: ModelType = ModelType.llm - metadata: Dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) -def get_huggingface_repo(model_descriptor: str) -> Optional[str]: +def get_huggingface_repo(model_descriptor: str) -> str | None: for model in all_registered_models(): if model.descriptor() == model_descriptor: return model.huggingface_repo @@ -34,7 +34,7 @@ def get_huggingface_repo(model_descriptor: str) -> Optional[str]: def build_hf_repo_model_entry( - provider_model_id: str, model_descriptor: str, additional_aliases: Optional[List[str]] = None + provider_model_id: str, model_descriptor: str, additional_aliases: list[str] | None = None ) -> ProviderModelEntry: aliases = [ get_huggingface_repo(model_descriptor), @@ -58,7 +58,7 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider class ModelRegistryHelper(ModelsProtocolPrivate): - def __init__(self, model_entries: List[ProviderModelEntry]): + def __init__(self, model_entries: list[ProviderModelEntry]): self.alias_to_provider_id_map = {} self.provider_id_to_llama_model_map = {} for entry in model_entries: @@ -72,11 +72,11 @@ class ModelRegistryHelper(ModelsProtocolPrivate): self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model - def get_provider_model_id(self, identifier: str) -> Optional[str]: + def get_provider_model_id(self, identifier: str) -> str | None: return self.alias_to_provider_id_map.get(identifier, None) # TODO: why keep a separate llama model mapping? 
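The alias machinery in `ModelRegistryHelper.__init__` above reduces to two plain dicts. A condensed, self-contained sketch of the same bookkeeping follows; the alias loop is elided in the hunk above, so its exact form here is an assumption consistent with the visible lines:

```python
class Entry:
    def __init__(self, provider_model_id: str, aliases: list[str], llama_model: str | None = None):
        self.provider_model_id = provider_model_id
        self.aliases = aliases
        self.llama_model = llama_model


def build_maps(entries: list[Entry]) -> tuple[dict[str, str], dict[str, str]]:
    alias_to_provider_id: dict[str, str] = {}
    provider_id_to_llama: dict[str, str] = {}
    for entry in entries:
        for alias in entry.aliases:  # assumed: each alias resolves to the provider ID
            alias_to_provider_id[alias] = entry.provider_model_id
        alias_to_provider_id[entry.provider_model_id] = entry.provider_model_id
        if entry.llama_model:
            alias_to_provider_id[entry.llama_model] = entry.provider_model_id
            provider_id_to_llama[entry.provider_model_id] = entry.llama_model
    return alias_to_provider_id, provider_id_to_llama


aliases, llama_map = build_maps([Entry("acme/llama-3-8b", ["llama3-8b"], "Llama-3-8B")])
assert aliases["llama3-8b"] == "acme/llama-3-8b"
assert llama_map["acme/llama-3-8b"] == "Llama-3-8B"
```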
- def get_llama_model(self, provider_model_id: str) -> Optional[str]: + def get_llama_model(self, provider_model_id: str) -> str | None: return self.provider_id_to_llama_model_map.get(provider_model_id, None) async def register_model(self, model: Model) -> Model: diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index d9dfba110..f90245c08 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -8,16 +8,9 @@ import logging import time import uuid import warnings +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable from typing import ( Any, - AsyncGenerator, - AsyncIterator, - Awaitable, - Dict, - Iterable, - List, - Optional, - Union, ) from openai import AsyncStream @@ -141,24 +134,24 @@ class OpenAICompatCompletionChoiceDelta(BaseModel): class OpenAICompatLogprobs(BaseModel): - text_offset: Optional[List[int]] = None + text_offset: list[int] | None = None - token_logprobs: Optional[List[float]] = None + token_logprobs: list[float] | None = None - tokens: Optional[List[str]] = None + tokens: list[str] | None = None - top_logprobs: Optional[List[Dict[str, float]]] = None + top_logprobs: list[dict[str, float]] | None = None class OpenAICompatCompletionChoice(BaseModel): - finish_reason: Optional[str] = None - text: Optional[str] = None - delta: Optional[OpenAICompatCompletionChoiceDelta] = None - logprobs: Optional[OpenAICompatLogprobs] = None + finish_reason: str | None = None + text: str | None = None + delta: OpenAICompatCompletionChoiceDelta | None = None + logprobs: OpenAICompatLogprobs | None = None class OpenAICompatCompletionResponse(BaseModel): - choices: List[OpenAICompatCompletionChoice] + choices: list[OpenAICompatCompletionChoice] def get_sampling_strategy_options(params: SamplingParams) -> dict: @@ -217,8 +210,8 @@ def get_stop_reason(finish_reason: str) -> StopReason: def convert_openai_completion_logprobs( - logprobs: Optional[OpenAICompatLogprobs], -) -> Optional[List[TokenLogProbs]]: + logprobs: OpenAICompatLogprobs | None, +) -> list[TokenLogProbs] | None: if not logprobs: return None if hasattr(logprobs, "top_logprobs"): @@ -235,7 +228,7 @@ def convert_openai_completion_logprobs( return None -def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]): +def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): if logprobs is None: return None if isinstance(logprobs, float): @@ -562,7 +555,7 @@ class UnparseableToolCall(BaseModel): async def convert_message_to_openai_dict_new( - message: Message | Dict, + message: Message | dict, ) -> OpenAIChatCompletionMessage: """ Convert a Message to an OpenAI API-compatible dictionary. @@ -591,14 +584,10 @@ async def convert_message_to_openai_dict_new( # List[...] -> List[...] 
async def _convert_message_content( content: InterleavedContent, - ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]: + ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: async def impl( content_: InterleavedContent, - ) -> Union[ - str, - OpenAIChatCompletionContentPartParam, - List[OpenAIChatCompletionContentPartParam], - ]: + ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: # Llama Stack and OpenAI spec match for str and text input if isinstance(content_, str): return content_ @@ -670,7 +659,7 @@ async def convert_message_to_openai_dict_new( def convert_tool_call( tool_call: ChatCompletionMessageToolCall, -) -> Union[ToolCall, UnparseableToolCall]: +) -> ToolCall | UnparseableToolCall: """ Convert a ChatCompletionMessageToolCall tool call to either a ToolCall or UnparseableToolCall. Returns an UnparseableToolCall @@ -846,7 +835,7 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig: +def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: tool_config = ToolConfig() if tool_choice: try: @@ -857,7 +846,7 @@ def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[st return tool_config -def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]: +def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: lls_tools = [] if not tools: return lls_tools @@ -903,8 +892,8 @@ def _convert_openai_request_response_format( def _convert_openai_tool_calls( - tool_calls: List[OpenAIChatCompletionMessageToolCall], -) -> List[ToolCall]: + tool_calls: list[OpenAIChatCompletionMessageToolCall], +) -> list[ToolCall]: """ Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. @@ -940,7 +929,7 @@ def _convert_openai_tool_calls( def _convert_openai_logprobs( logprobs: OpenAIChoiceLogprobs, -) -> Optional[List[TokenLogProbs]]: +) -> list[TokenLogProbs] | None: """ Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. @@ -973,9 +962,9 @@ def _convert_openai_logprobs( def _convert_openai_sampling_params( - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, + max_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, ) -> SamplingParams: sampling_params = SamplingParams() @@ -998,8 +987,8 @@ def _convert_openai_sampling_params( def openai_messages_to_messages( - messages: List[OpenAIChatCompletionMessage], -) -> List[Message]: + messages: list[OpenAIChatCompletionMessage], +) -> list[Message]: """ Convert a list of OpenAIChatCompletionMessage into a list of Message. 
""" @@ -1027,7 +1016,7 @@ def openai_messages_to_messages( return converted_messages -def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]): +def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam]): if isinstance(content, str): return content elif isinstance(content, list): @@ -1273,24 +1262,24 @@ class OpenAICompletionToLlamaStackMixin: async def openai_completion( self, model: str, - prompt: Union[str, List[str], List[int], List[List[int]]], - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - guided_choice: Optional[List[str]] = None, - prompt_logprobs: Optional[int] = None, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, ) -> OpenAICompletion: if stream: raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions") @@ -1342,29 +1331,29 @@ class OpenAIChatCompletionToLlamaStackMixin: async def openai_chat_completion( self, model: str, - messages: List[OpenAIChatCompletionMessage], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[OpenAIResponseFormatParam] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + messages: list[OpenAIChatCompletionMessage], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: 
float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: messages = openai_messages_to_messages(messages) response_format = _convert_openai_request_response_format(response_format) sampling_params = _convert_openai_sampling_params( @@ -1403,7 +1392,7 @@ class OpenAIChatCompletionToLlamaStackMixin: async def _process_stream_response( self, model: str, - outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], + outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], ): id = f"chatcmpl-{uuid.uuid4()}" for outstanding_response in outstanding_responses: @@ -1466,7 +1455,7 @@ class OpenAIChatCompletionToLlamaStackMixin: i = i + 1 async def _process_non_stream_response( - self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]] + self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] ) -> OpenAIChatCompletion: choices = [] for outstanding_response in outstanding_responses: diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 657dc4b86..d53b51537 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -9,7 +9,6 @@ import base64 import io import json import re -from typing import List, Optional, Tuple, Union import httpx from PIL import Image as PIL_Image @@ -63,7 +62,7 @@ log = get_logger(name=__name__, category="inference") class ChatCompletionRequestWithRawContent(ChatCompletionRequest): - messages: List[RawMessage] + messages: list[RawMessage] class CompletionRequestWithRawContent(CompletionRequest): @@ -93,8 +92,8 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> s async def convert_request_to_raw( - request: Union[ChatCompletionRequest, CompletionRequest], -) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]: + request: ChatCompletionRequest | CompletionRequest, +) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent: if isinstance(request, ChatCompletionRequest): messages = [] for m in request.messages: @@ -170,18 +169,18 @@ def content_has_media(content: InterleavedContent): return _has_media_content(content) -def messages_have_media(messages: List[Message]): +def messages_have_media(messages: list[Message]): return any(content_has_media(m.content) for m in messages) -def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]): +def request_has_media(request: ChatCompletionRequest | CompletionRequest): if isinstance(request, ChatCompletionRequest): return messages_have_media(request.messages) else: return content_has_media(request.content) -async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]: +async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]: image = media.image if image.url and image.url.uri.startswith("http"): async with httpx.AsyncClient() as client: @@ -228,7 +227,7 @@ async def 
completion_request_to_prompt(request: CompletionRequest) -> str: async def completion_request_to_prompt_model_input_info( request: CompletionRequest, -) -> Tuple[str, int]: +) -> tuple[str, int]: content = augment_content_with_response_format_prompt(request.response_format, request.content) request.content = content request = await convert_request_to_raw(request) @@ -265,7 +264,7 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam async def chat_completion_request_to_model_input_info( request: ChatCompletionRequest, llama_model: str -) -> Tuple[str, int]: +) -> tuple[str, int]: messages = chat_completion_request_to_messages(request, llama_model) request.messages = messages request = await convert_request_to_raw(request) @@ -284,7 +283,7 @@ async def chat_completion_request_to_model_input_info( def chat_completion_request_to_messages( request: ChatCompletionRequest, llama_model: str, -) -> List[Message]: +) -> list[Message]: """Reads chat completion request and augments the messages to handle tools. For eg. for llama_3_1, add system message with the appropriate tools or add user messsage for custom tools, etc. @@ -323,7 +322,7 @@ def chat_completion_request_to_messages( return messages -def response_format_prompt(fmt: Optional[ResponseFormat]): +def response_format_prompt(fmt: ResponseFormat | None): if not fmt: return None @@ -337,7 +336,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]): def augment_messages_for_tools_llama_3_1( request: ChatCompletionRequest, -) -> List[Message]: +) -> list[Message]: existing_messages = request.messages existing_system_message = None if existing_messages[0].role == Role.system.value: @@ -406,7 +405,7 @@ def augment_messages_for_tools_llama_3_1( def augment_messages_for_tools_llama( request: ChatCompletionRequest, custom_tool_prompt_generator, -) -> List[Message]: +) -> list[Message]: existing_messages = request.messages existing_system_message = None if existing_messages[0].role == Role.system.value: @@ -457,7 +456,7 @@ def augment_messages_for_tools_llama( return messages -def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str: +def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str: if tool_choice == ToolChoice.auto: return "" elif tool_choice == ToolChoice.required: diff --git a/llama_stack/providers/utils/kvstore/api.py b/llama_stack/providers/utils/kvstore/api.py index 84b1730e1..8efde4ea9 100644 --- a/llama_stack/providers/utils/kvstore/api.py +++ b/llama_stack/providers/utils/kvstore/api.py @@ -5,15 +5,15 @@ # the root directory of this source tree. from datetime import datetime -from typing import List, Optional, Protocol +from typing import Protocol class KVStore(Protocol): # TODO: make the value type bytes instead of str - async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: ... + async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ... - async def get(self, key: str) -> Optional[str]: ... + async def get(self, key: str) -> str | None: ... async def delete(self, key: str) -> None: ... - async def range(self, start_key: str, end_key: str) -> List[str]: ... + async def range(self, start_key: str, end_key: str) -> list[str]: ... 
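The `KVStore` protocol above is small enough to satisfy with a plain dict. A minimal conforming implementation using the modernized annotations (ignoring `expiration`, as the in-memory implementation later in this diff also does):

```python
from datetime import datetime


class DictKVStore:
    """Dict-backed store structurally matching the KVStore protocol above."""

    def __init__(self) -> None:
        self._store: dict[str, str] = {}

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        self._store[key] = value  # expiration ignored in this sketch

    async def get(self, key: str) -> str | None:
        return self._store.get(key)

    async def delete(self, key: str) -> None:
        self._store.pop(key, None)

    async def range(self, start_key: str, end_key: str) -> list[str]:
        # Half-open lexicographic scan, matching InmemoryKVStoreImpl below.
        return [v for k, v in self._store.items() if start_key <= k < end_key]
```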
diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index 4f85982be..e9aac6e8c 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -6,10 +6,9 @@ import re from enum import Enum -from typing import Literal, Optional, Union +from typing import Annotated, Literal from pydantic import BaseModel, Field, field_validator -from typing_extensions import Annotated from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR @@ -22,7 +21,7 @@ class KVStoreType(Enum): class CommonConfig(BaseModel): - namespace: Optional[str] = Field( + namespace: str | None = Field( default=None, description="All keys will be prefixed with this namespace", ) @@ -69,7 +68,7 @@ class PostgresKVStoreConfig(CommonConfig): port: int = 5432 db: str = "llamastack" user: str - password: Optional[str] = None + password: str | None = None table_name: str = "llamastack_kvstore" @classmethod @@ -108,7 +107,7 @@ class MongoDBKVStoreConfig(CommonConfig): port: int = 27017 db: str = "llamastack" user: str = None - password: Optional[str] = None + password: str | None = None collection_name: str = "llamastack_kvstore" @classmethod @@ -126,6 +125,6 @@ class MongoDBKVStoreConfig(CommonConfig): KVStoreConfig = Annotated[ - Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig], + RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type", default=KVStoreType.sqlite.value), ] diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index 6bc175260..0eb969b65 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
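The `KVStoreConfig` rewrite above relies on Pydantic accepting `A | B | C` inside `Annotated[..., Field(discriminator=...)]` exactly as it accepts `Union[A, B, C]`. A self-contained sketch of the pattern with toy models (not the real configs):

```python
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class SqliteConfig(BaseModel):
    type: Literal["sqlite"] = "sqlite"
    db_path: str


class RedisConfig(BaseModel):
    type: Literal["redis"] = "redis"
    host: str = "localhost"


StoreConfig = Annotated[SqliteConfig | RedisConfig, Field(discriminator="type")]

# The discriminator selects the right model from the tag alone:
cfg = TypeAdapter(StoreConfig).validate_python({"type": "redis", "host": "cache"})
assert isinstance(cfg, RedisConfig)
```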
-from typing import List, Optional from .api import KVStore from .config import KVStoreConfig, KVStoreType @@ -21,13 +20,13 @@ class InmemoryKVStoreImpl(KVStore): async def initialize(self) -> None: pass - async def get(self, key: str) -> Optional[str]: + async def get(self, key: str) -> str | None: return self._store.get(key) async def set(self, key: str, value: str) -> None: self._store[key] = value - async def range(self, start_key: str, end_key: str) -> List[str]: + async def range(self, start_key: str, end_key: str) -> list[str]: return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key] diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index c1581dc8d..330a079bf 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -6,7 +6,6 @@ import logging from datetime import datetime -from typing import List, Optional from pymongo import AsyncMongoClient @@ -43,12 +42,12 @@ class MongoDBKVStoreImpl(KVStore): return key return f"{self.config.namespace}:{key}" - async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: + async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: key = self._namespaced_key(key) update_query = {"$set": {"value": value, "expiration": expiration}} await self.collection.update_one({"key": key}, update_query, upsert=True) - async def get(self, key: str) -> Optional[str]: + async def get(self, key: str) -> str | None: key = self._namespaced_key(key) query = {"key": key} result = await self.collection.find_one(query, {"value": 1, "_id": 0}) @@ -58,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore): key = self._namespaced_key(key) await self.collection.delete_one({"key": key}) - async def range(self, start_key: str, end_key: str) -> List[str]: + async def range(self, start_key: str, end_key: str) -> list[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) query = { diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index 097d36066..6bfbc4f81 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -6,7 +6,6 @@ import logging from datetime import datetime -from typing import List, Optional import psycopg2 from psycopg2.extras import DictCursor @@ -54,7 +53,7 @@ class PostgresKVStoreImpl(KVStore): return key return f"{self.config.namespace}:{key}" - async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: + async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: key = self._namespaced_key(key) self.cursor.execute( f""" @@ -66,7 +65,7 @@ class PostgresKVStoreImpl(KVStore): (key, value, expiration), ) - async def get(self, key: str) -> Optional[str]: + async def get(self, key: str) -> str | None: key = self._namespaced_key(key) self.cursor.execute( f""" @@ -86,7 +85,7 @@ class PostgresKVStoreImpl(KVStore): (key,), ) - async def range(self, start_key: str, end_key: str) -> List[str]: + async def range(self, start_key: str, end_key: str) -> list[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) diff --git a/llama_stack/providers/utils/kvstore/redis/redis.py b/llama_stack/providers/utils/kvstore/redis/redis.py index a390ea866..d95de0291 100644 --- 
a/llama_stack/providers/utils/kvstore/redis/redis.py +++ b/llama_stack/providers/utils/kvstore/redis/redis.py @@ -5,7 +5,6 @@ # the root directory of this source tree. from datetime import datetime -from typing import List, Optional from redis.asyncio import Redis @@ -25,13 +24,13 @@ class RedisKVStoreImpl(KVStore): return key return f"{self.config.namespace}:{key}" - async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: + async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: key = self._namespaced_key(key) await self.redis.set(key, value) if expiration: await self.redis.expireat(key, expiration) - async def get(self, key: str) -> Optional[str]: + async def get(self, key: str) -> str | None: key = self._namespaced_key(key) value = await self.redis.get(key) if value is None: @@ -43,7 +42,7 @@ class RedisKVStoreImpl(KVStore): key = self._namespaced_key(key) await self.redis.delete(key) - async def range(self, start_key: str, end_key: str) -> List[str]: + async def range(self, start_key: str, end_key: str) -> list[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) cursor = 0 diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index bc0488aac..faca77887 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -6,7 +6,6 @@ import os from datetime import datetime -from typing import List, Optional import aiosqlite @@ -33,7 +32,7 @@ class SqliteKVStoreImpl(KVStore): ) await db.commit() - async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: + async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: async with aiosqlite.connect(self.db_path) as db: await db.execute( f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", @@ -41,7 +40,7 @@ class SqliteKVStoreImpl(KVStore): ) await db.commit() - async def get(self, key: str) -> Optional[str]: + async def get(self, key: str) -> str | None: async with aiosqlite.connect(self.db_path) as db: async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor: row = await cursor.fetchone() @@ -55,7 +54,7 @@ class SqliteKVStoreImpl(KVStore): await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await db.commit() - async def range(self, start_key: str, end_key: str) -> List[str]: + async def range(self, start_key: str, end_key: str) -> list[str]: async with aiosqlite.connect(self.db_path) as db: async with db.execute( f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? 
AND key <= ?", diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index ba4403ea1..f4834969a 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -9,7 +9,7 @@ import logging import re from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any from urllib.parse import unquote import httpx @@ -94,7 +94,7 @@ def content_from_data(data_url: str) -> str: return "" -def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent: +def concat_interleaved_content(content: list[InterleavedContent]) -> InterleavedContent: """concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list""" ret = [] @@ -141,7 +141,7 @@ async def content_from_doc(doc: RAGDocument) -> str: return interleaved_content_as_str(doc.content) -def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> List[Chunk]: +def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]: tokenizer = Tokenizer.get_instance() tokens = tokenizer.encode(text, bos=False, eos=False) @@ -165,7 +165,7 @@ def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap class EmbeddingIndex(ABC): @abstractmethod - async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): raise NotImplementedError() @abstractmethod @@ -185,7 +185,7 @@ class VectorDBWithIndex: async def insert_chunks( self, - chunks: List[Chunk], + chunks: list[Chunk], ) -> None: embeddings_response = await self.inference_api.embeddings( self.vector_db.embedding_model, [x.content for x in chunks] @@ -197,7 +197,7 @@ class VectorDBWithIndex: async def query_chunks( self, query: InterleavedContent, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> QueryChunksResponse: if params is None: params = {} diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py index d4cffe605..845ab1f02 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -8,9 +8,10 @@ import abc import asyncio import functools import threading +from collections.abc import Callable, Coroutine, Iterable from datetime import datetime, timezone from enum import Enum -from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias +from typing import Any, TypeAlias from pydantic import BaseModel @@ -38,7 +39,7 @@ class JobArtifact(BaseModel): name: str # TODO: uri should be a reference to /files API; revisit when /files is implemented uri: str | None = None - metadata: Dict[str, Any] + metadata: dict[str, Any] JobHandler = Callable[ @@ -46,7 +47,7 @@ JobHandler = Callable[ ] -LogMessage: TypeAlias = Tuple[datetime, str] +LogMessage: TypeAlias = tuple[datetime, str] _COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed} @@ -60,7 +61,7 @@ class Job: self._handler = handler self._artifacts: list[JobArtifact] = [] self._logs: list[LogMessage] = [] - self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)] + self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)] @property def handler(self) -> JobHandler: diff --git 
diff --git a/llama_stack/providers/utils/scoring/aggregation_utils.py b/llama_stack/providers/utils/scoring/aggregation_utils.py
index 7254c9433..cff9a112f 100644
--- a/llama_stack/providers/utils/scoring/aggregation_utils.py
+++ b/llama_stack/providers/utils/scoring/aggregation_utils.py
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import statistics
-from typing import Any, Dict, List
+from typing import Any
 
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import AggregationFunctionType
 
 
-def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+def aggregate_accuracy(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
     num_correct = sum(result["score"] for result in scoring_results)
     avg_score = num_correct / len(scoring_results)
 
@@ -21,14 +21,14 @@
     }
 
 
-def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+def aggregate_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
     return {
         "average": sum(result["score"] for result in scoring_results if result["score"] is not None)
         / len([_ for _ in scoring_results if _["score"] is not None]),
     }
 
 
-def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+def aggregate_weighted_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
     return {
         "weighted_average": sum(
             result["score"] * result["weight"]
@@ -40,14 +40,14 @@
 
 
 def aggregate_categorical_count(
-    scoring_results: List[ScoringResultRow],
-) -> Dict[str, Any]:
+    scoring_results: list[ScoringResultRow],
+) -> dict[str, Any]:
     scores = [str(r["score"]) for r in scoring_results]
     unique_scores = sorted(set(scores))
     return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
 
 
-def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+def aggregate_median(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
     scores = [r["score"] for r in scoring_results if r["score"] is not None]
     median = statistics.median(scores) if scores else None
     return {"median": median}
@@ -64,8 +64,8 @@ AGGREGATION_FUNCTIONS = {
 
 
 def aggregate_metrics(
-    scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
-) -> Dict[str, Any]:
+    scoring_results: list[ScoringResultRow], metrics: list[AggregationFunctionType]
+) -> dict[str, Any]:
     agg_results = {}
     for metric in metrics:
         if metric not in AGGREGATION_FUNCTIONS:
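A usage sketch for the aggregators above, assuming llama-stack is installed; the rows are plain dicts, and `"score"` is the only key these particular implementations read:

```python
from llama_stack.providers.utils.scoring.aggregation_utils import (
    aggregate_categorical_count,
    aggregate_median,
)

# Dict-shaped rows, as the implementations above assume
rows = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]
print(aggregate_median(rows))             # {'median': 1.0}
print(aggregate_categorical_count(rows))  # {'categorical_count': {'0.0': 1, '1.0': 2}}
```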
diff --git a/llama_stack/providers/utils/scoring/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py
index 834deb7e1..2fae177b7 100644
--- a/llama_stack/providers/utils/scoring/base_scoring_fn.py
+++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFn
 
@@ -28,28 +28,28 @@ class BaseScoringFn(ABC):
     @abstractmethod
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         raise NotImplementedError()
 
     @abstractmethod
     async def aggregate(
         self,
-        scoring_results: List[ScoringResultRow],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> Dict[str, Any]:
+        scoring_results: list[ScoringResultRow],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
+    ) -> dict[str, Any]:
         raise NotImplementedError()
 
     @abstractmethod
     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> List[ScoringResultRow]:
+        input_rows: list[dict[str, Any]],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
+    ) -> list[ScoringResultRow]:
         raise NotImplementedError()
 
 
@@ -65,7 +65,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
     def __str__(self) -> str:
         return self.__class__.__name__
 
-    def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
+    def get_supported_scoring_fn_defs(self) -> list[ScoringFn]:
         return list(self.supported_fn_defs_registry.values())
 
     def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
@@ -81,18 +81,18 @@ class RegisteredBaseScoringFn(BaseScoringFn):
     @abstractmethod
     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
    ) -> ScoringResultRow:
         raise NotImplementedError()
 
     async def aggregate(
         self,
-        scoring_results: List[ScoringResultRow],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> Dict[str, Any]:
+        scoring_results: list[ScoringResultRow],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
+    ) -> dict[str, Any]:
         params = self.supported_fn_defs_registry[scoring_fn_identifier].params
         if scoring_params is not None:
             if params is None:
@@ -107,8 +107,8 @@ class RegisteredBaseScoringFn(BaseScoringFn):
 
     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> List[ScoringResultRow]:
+        input_rows: list[dict[str, Any]],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
+    ) -> list[ScoringResultRow]:
         return [await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows]
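For orientation, a self-contained stand-in with the same `score_row`/`score` shape as `BaseScoringFn` above; the class, row keys, and logic here are illustrative, not the real llama_stack types:

```python
import asyncio
from typing import Any

class ExactMatchScorer:
    # Illustrative: mirrors the abstract interface's method shapes only.
    async def score_row(self, input_row: dict[str, Any]) -> dict[str, Any]:
        matched = input_row["generated_answer"] == input_row["expected_answer"]
        return {"score": 1.0 if matched else 0.0}

    async def score(self, input_rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
        return [await self.score_row(row) for row in input_rows]

rows = [{"generated_answer": "4", "expected_answer": "4"}]
print(asyncio.run(ExactMatchScorer().score(rows)))  # [{'score': 1.0}]
```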
diff --git a/llama_stack/providers/utils/scoring/basic_scoring_utils.py b/llama_stack/providers/utils/scoring/basic_scoring_utils.py
index 91abfdb2e..7372a521c 100644
--- a/llama_stack/providers/utils/scoring/basic_scoring_utils.py
+++ b/llama_stack/providers/utils/scoring/basic_scoring_utils.py
@@ -5,8 +5,8 @@
 # the root directory of this source tree.
 import contextlib
 import signal
+from collections.abc import Iterator
 from types import FrameType
-from typing import Iterator, Optional
 
 
 class TimeoutError(Exception):
@@ -15,7 +15,7 @@ class TimeoutError(Exception):
 
 @contextlib.contextmanager
 def time_limit(seconds: float) -> Iterator[None]:
-    def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
+    def signal_handler(signum: int, frame: FrameType | None) -> None:
         raise TimeoutError("Timed out!")
 
     signal.setitimer(signal.ITIMER_REAL, seconds)
diff --git a/llama_stack/providers/utils/telemetry/dataset_mixin.py b/llama_stack/providers/utils/telemetry/dataset_mixin.py
index 34c612133..fe729a244 100644
--- a/llama_stack/providers/utils/telemetry/dataset_mixin.py
+++ b/llama_stack/providers/utils/telemetry/dataset_mixin.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional
 
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
 
@@ -17,10 +16,10 @@ class TelemetryDatasetMixin:
 
     async def save_spans_to_dataset(
         self,
-        attribute_filters: List[QueryCondition],
-        attributes_to_save: List[str],
+        attribute_filters: list[QueryCondition],
+        attributes_to_save: list[str],
         dataset_id: str,
-        max_depth: Optional[int] = None,
+        max_depth: int | None = None,
     ) -> None:
         if self.datasetio_api is None:
             raise RuntimeError("DatasetIO API not available")
@@ -48,9 +47,9 @@ class TelemetryDatasetMixin:
 
     async def query_spans(
         self,
-        attribute_filters: List[QueryCondition],
-        attributes_to_return: List[str],
-        max_depth: Optional[int] = None,
+        attribute_filters: list[QueryCondition],
+        attributes_to_return: list[str],
+        max_depth: int | None = None,
     ) -> QuerySpansResponse:
         traces = await self.query_traces(attribute_filters=attribute_filters)
         spans = []
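`time_limit` keeps its SIGALRM implementation; only the annotations moved. A usage sketch, with the module path taken from the diff header above — note that SIGALRM-based timers only work on Unix, and only on the main thread:

```python
# Guard a potentially slow scoring step with the utility shown above.
from llama_stack.providers.utils.scoring.basic_scoring_utils import TimeoutError, time_limit

try:
    with time_limit(2.0):
        while True:  # stand-in for a slow regex/eval step
            pass
except TimeoutError:
    print("scoring step timed out")
```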
diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py
index 3248f3fa7..af1145fe7 100644
--- a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py
+++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py
@@ -6,7 +6,7 @@
 
 import json
 from datetime import datetime
-from typing import Dict, List, Optional, Protocol
+from typing import Protocol
 
 import aiosqlite
 
@@ -16,18 +16,18 @@ from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithStatus, Trace
 class TraceStore(Protocol):
     async def query_traces(
         self,
-        attribute_filters: Optional[List[QueryCondition]] = None,
-        limit: Optional[int] = 100,
-        offset: Optional[int] = 0,
-        order_by: Optional[List[str]] = None,
-    ) -> List[Trace]: ...
+        attribute_filters: list[QueryCondition] | None = None,
+        limit: int | None = 100,
+        offset: int | None = 0,
+        order_by: list[str] | None = None,
+    ) -> list[Trace]: ...
 
     async def get_span_tree(
         self,
         span_id: str,
-        attributes_to_return: Optional[List[str]] = None,
-        max_depth: Optional[int] = None,
-    ) -> Dict[str, SpanWithStatus]: ...
+        attributes_to_return: list[str] | None = None,
+        max_depth: int | None = None,
+    ) -> dict[str, SpanWithStatus]: ...
 
 
 class SQLiteTraceStore(TraceStore):
@@ -36,11 +36,11 @@ class SQLiteTraceStore(TraceStore):
 
     async def query_traces(
         self,
-        attribute_filters: Optional[List[QueryCondition]] = None,
-        limit: Optional[int] = 100,
-        offset: Optional[int] = 0,
-        order_by: Optional[List[str]] = None,
-    ) -> List[Trace]:
+        attribute_filters: list[QueryCondition] | None = None,
+        limit: int | None = 100,
+        offset: int | None = 0,
+        order_by: list[str] | None = None,
+    ) -> list[Trace]:
         def build_where_clause() -> tuple[str, list]:
             if not attribute_filters:
                 return "", []
@@ -112,9 +112,9 @@ class SQLiteTraceStore(TraceStore):
     async def get_span_tree(
         self,
         span_id: str,
-        attributes_to_return: Optional[List[str]] = None,
-        max_depth: Optional[int] = None,
-    ) -> Dict[str, SpanWithStatus]:
+        attributes_to_return: list[str] | None = None,
+        max_depth: int | None = None,
+    ) -> dict[str, SpanWithStatus]:
         # Build the attributes selection
         attributes_select = "s.attributes"
         if attributes_to_return:
diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py
index 525ade74d..eb6d8b331 100644
--- a/llama_stack/providers/utils/telemetry/trace_protocol.py
+++ b/llama_stack/providers/utils/telemetry/trace_protocol.py
@@ -7,8 +7,9 @@
 import asyncio
 import inspect
 import json
+from collections.abc import AsyncGenerator, Callable
 from functools import wraps
-from typing import Any, AsyncGenerator, Callable, Type, TypeVar
+from typing import Any, TypeVar
 
 from pydantic import BaseModel
 
@@ -25,13 +26,13 @@ def _prepare_for_json(value: Any) -> str:
     """Serialize a single value into JSON-compatible format."""
     if value is None:
         return ""
-    elif isinstance(value, (str, int, float, bool)):
+    elif isinstance(value, str | int | float | bool):
         return value
     elif hasattr(value, "_name_"):
         return value._name_
     elif isinstance(value, BaseModel):
         return json.loads(value.model_dump_json())
-    elif isinstance(value, (list, tuple, set)):
+    elif isinstance(value, list | tuple | set):
         return [_prepare_for_json(item) for item in value]
     elif isinstance(value, dict):
         return {str(k): _prepare_for_json(v) for k, v in value.items()}
@@ -43,7 +44,7 @@ def _prepare_for_json(value: Any) -> str:
         return str(value)
 
 
-def trace_protocol(cls: Type[T]) -> Type[T]:
+def trace_protocol(cls: type[T]) -> type[T]:
     """
     A class decorator that automatically traces all methods in a protocol/base class
     and its inheriting classes.
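The `_prepare_for_json` change relies on Python 3.10 accepting PEP 604 unions directly in `isinstance()` and `issubclass()`, so the tuple form and the union form are interchangeable:

```python
value = 42
# Equivalent on Python 3.10+: tuple of types vs. a union type
assert isinstance(value, (str, int, float, bool))
assert isinstance(value, str | int | float | bool)

# The same holds for issubclass (bool is a subclass of int)
assert issubclass(bool, int | bytes)
```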
diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py
index 3d5c717d6..0f4fdd0d8 100644
--- a/llama_stack/providers/utils/telemetry/tracing.py
+++ b/llama_stack/providers/utils/telemetry/tracing.py
@@ -10,9 +10,10 @@
 import logging
 import queue
 import random
 import threading
+from collections.abc import Callable
 from datetime import datetime, timezone
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any
 
 from llama_stack.apis.telemetry import (
     LogSeverity,
@@ -106,13 +107,13 @@ class BackgroundLogger:
 
 
 class TraceContext:
-    spans: List[Span] = []
+    spans: list[Span] = []
 
     def __init__(self, logger: BackgroundLogger, trace_id: str):
         self.logger = logger
         self.trace_id = trace_id
 
-    def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
+    def push_span(self, name: str, attributes: dict[str, Any] = None) -> Span:
         current_span = self.get_current_span()
         span = Span(
             span_id=generate_span_id(),
@@ -168,7 +169,7 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
     root_logger.addHandler(TelemetryHandler())
 
 
-async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
+async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceContext:
     global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
 
     if BACKGROUND_LOGGER is None:
@@ -246,7 +247,7 @@ class TelemetryHandler(logging.Handler):
 
 
 class SpanContextManager:
-    def __init__(self, name: str, attributes: Dict[str, Any] = None):
+    def __init__(self, name: str, attributes: dict[str, Any] = None):
         self.name = name
         self.attributes = attributes
         self.span = None
@@ -316,11 +317,11 @@ class SpanContextManager:
         return wrapper
 
 
-def span(name: str, attributes: Dict[str, Any] = None):
+def span(name: str, attributes: dict[str, Any] = None):
     return SpanContextManager(name, attributes)
 
 
-def get_current_span() -> Optional[Span]:
+def get_current_span() -> Span | None:
     global CURRENT_TRACE_CONTEXT
     if CURRENT_TRACE_CONTEXT is None:
         logger.debug("No trace context to get current span")
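One thing pyupgrade does not touch: defaults like `attributes: dict[str, Any] = None` above are implicit optionals, which strict type checkers flag because `None` is not a `dict`. If these signatures are tightened later, the spelling would presumably be the explicit union (a sketch, not part of this patch):

```python
from typing import Any

def span(name: str, attributes: dict[str, Any] | None = None):
    ...
```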
diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py
index 8143f1224..694de333e 100644
--- a/llama_stack/schema_utils.py
+++ b/llama_stack/schema_utils.py
@@ -4,37 +4,38 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, TypeVar
+from typing import Any, TypeVar
 
 from .strong_typing.schema import json_schema_type, register_schema  # noqa: F401
 
 
 @dataclass
 class WebMethod:
-    route: Optional[str] = None
+    route: str | None = None
     public: bool = False
-    request_examples: Optional[List[Any]] = None
-    response_examples: Optional[List[Any]] = None
-    method: Optional[str] = None
-    raw_bytes_request_body: Optional[bool] = False
+    request_examples: list[Any] | None = None
+    response_examples: list[Any] | None = None
+    method: str | None = None
+    raw_bytes_request_body: bool | None = False
     # A descriptive name of the corresponding span created by tracing
-    descriptive_name: Optional[str] = None
-    experimental: Optional[bool] = False
+    descriptive_name: str | None = None
+    experimental: bool | None = False
 
 
 T = TypeVar("T", bound=Callable[..., Any])
 
 
 def webmethod(
-    route: Optional[str] = None,
-    method: Optional[str] = None,
-    public: Optional[bool] = False,
-    request_examples: Optional[List[Any]] = None,
-    response_examples: Optional[List[Any]] = None,
-    raw_bytes_request_body: Optional[bool] = False,
-    descriptive_name: Optional[str] = None,
-    experimental: Optional[bool] = False,
+    route: str | None = None,
+    method: str | None = None,
+    public: bool | None = False,
+    request_examples: list[Any] | None = None,
+    response_examples: list[Any] | None = None,
+    raw_bytes_request_body: bool | None = False,
+    descriptive_name: str | None = None,
+    experimental: bool | None = False,
 ) -> Callable[[T], T]:
     """
     Decorator that supplies additional metadata to an endpoint operation function.
diff --git a/llama_stack/strong_typing/docstring.py b/llama_stack/strong_typing/docstring.py
index b038d1024..497c9ea82 100644
--- a/llama_stack/strong_typing/docstring.py
+++ b/llama_stack/strong_typing/docstring.py
@@ -11,6 +11,7 @@ Type-safe data interchange for Python data classes.
 """
 
 import builtins
+import collections.abc
 import dataclasses
 import inspect
 import re
@@ -171,6 +172,13 @@ class SupportsDoc(Protocol):
     __doc__: Optional[str]
 
 
+def _maybe_unwrap_async_iterator(t):
+    origin_type = typing.get_origin(t)
+    if origin_type is collections.abc.AsyncIterator:
+        return typing.get_args(t)[0]
+    return t
+
+
 def parse_type(typ: SupportsDoc) -> Docstring:
     """
     Parse the docstring of a type into its components.
@@ -178,6 +186,8 @@ def parse_type(typ: SupportsDoc) -> Docstring:
     :param typ: The type whose documentation string to parse.
     :returns: Components of the documentation string.
     """
+    # Use docstring from the iterator origin type for streaming apis
+    typ = _maybe_unwrap_async_iterator(typ)
     doc = get_docstring(typ)
 
     if doc is None:
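`_maybe_unwrap_async_iterator` peels `AsyncIterator[T]` down to `T`, so docstrings for streaming endpoints are read from the payload type rather than the iterator wrapper. A standalone check of the same `get_origin`/`get_args` logic (the `ChatChunk` class here is hypothetical):

```python
import collections.abc
import typing
from collections.abc import AsyncIterator

class ChatChunk:
    """Docstring that the parser should pick up for streaming APIs."""

def _maybe_unwrap_async_iterator(t):
    # Same logic as the helper above: AsyncIterator[T] -> T, anything else unchanged
    if typing.get_origin(t) is collections.abc.AsyncIterator:
        return typing.get_args(t)[0]
    return t

assert _maybe_unwrap_async_iterator(AsyncIterator[ChatChunk]) is ChatChunk
assert _maybe_unwrap_async_iterator(int) is int
```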
diff --git a/llama_stack/strong_typing/schema.py b/llama_stack/strong_typing/schema.py
index 1427c22e6..82baddc86 100644
--- a/llama_stack/strong_typing/schema.py
+++ b/llama_stack/strong_typing/schema.py
@@ -10,6 +10,7 @@ Type-safe data interchange for Python data classes.
 :see: https://github.com/hunyadi/strong_typing
 """
 
+import collections.abc
 import dataclasses
 import datetime
 import decimal
@@ -487,6 +488,9 @@ class JsonSchemaGenerator:
         elif origin_type is type:
             (concrete_type,) = typing.get_args(typ)  # unpack single tuple element
             return {"const": self.type_to_schema(concrete_type, force_expand=True)}
+        elif origin_type is collections.abc.AsyncIterator:
+            (concrete_type,) = typing.get_args(typ)
+            return self.type_to_schema(concrete_type)
 
         # dictionary of class attributes
         members = dict(inspect.getmembers(typ, lambda a: not inspect.isroutine(a)))
diff --git a/llama_stack/templates/dev/dev.py b/llama_stack/templates/dev/dev.py
index 69924acbe..af636d891 100644
--- a/llama_stack/templates/dev/dev.py
+++ b/llama_stack/templates/dev/dev.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Tuple
 
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
@@ -50,7 +49,7 @@
 )
 
 
-def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
+def get_inference_providers() -> tuple[list[Provider], list[ModelInput]]:
     # in this template, we allow each API key to be optional
     providers = [
         (
diff --git a/llama_stack/templates/llama_api/llama_api.py b/llama_stack/templates/llama_api/llama_api.py
index 5f55a5b68..20ee6d370 100644
--- a/llama_stack/templates/llama_api/llama_api.py
+++ b/llama_stack/templates/llama_api/llama_api.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Tuple
 
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
@@ -36,7 +35,7 @@
 )
 
 
-def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
+def get_inference_providers() -> tuple[list[Provider], list[ModelInput]]:
     # in this template, we allow each API key to be optional
     providers = [
         (
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index a6a906c6f..9f4943558 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List, Tuple
 
 from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
 from llama_stack.apis.models.models import ModelType
@@ -36,7 +35,7 @@
 )
 
 
-def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
+def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
     # in this template, we allow each API key to be optional
     providers = [
         (
diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py
index 92b1b534d..e4d28d904 100644
--- a/llama_stack/templates/template.py
+++ b/llama_stack/templates/template.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Literal
 
 import jinja2
 import yaml
@@ -32,8 +32,8 @@
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 
 def get_model_registry(
-    available_models: Dict[str, List[ProviderModelEntry]],
-) -> List[ModelInput]:
+    available_models: dict[str, list[ProviderModelEntry]],
+) -> list[ModelInput]:
     models = []
     for provider_id, entries in available_models.items():
         for entry in entries:
@@ -57,18 +57,18 @@ class DefaultModel(BaseModel):
 
 
 class RunConfigSettings(BaseModel):
-    provider_overrides: Dict[str, List[Provider]] = Field(default_factory=dict)
-    default_models: Optional[List[ModelInput]] = None
-    default_shields: Optional[List[ShieldInput]] = None
-    default_tool_groups: Optional[List[ToolGroupInput]] = None
-    default_datasets: Optional[List[DatasetInput]] = None
-    default_benchmarks: Optional[List[BenchmarkInput]] = None
+    provider_overrides: dict[str, list[Provider]] = Field(default_factory=dict)
+    default_models: list[ModelInput] | None = None
+    default_shields: list[ShieldInput] | None = None
+    default_tool_groups: list[ToolGroupInput] | None = None
+    default_datasets: list[DatasetInput] | None = None
+    default_benchmarks: list[BenchmarkInput] | None = None
 
     def run_config(
         self,
         name: str,
-        providers: Dict[str, List[str]],
-        container_image: Optional[str] = None,
+        providers: dict[str, list[str]],
+        container_image: str | None = None,
     ) -> StackRunConfig:
         provider_registry = get_provider_registry()
@@ -135,15 +135,15 @@ class DistributionTemplate(BaseModel):
     description: str
     distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
 
-    providers: Dict[str, List[str]]
-    run_configs: Dict[str, RunConfigSettings]
-    template_path: Optional[Path] = None
+    providers: dict[str, list[str]]
+    run_configs: dict[str, RunConfigSettings]
+    template_path: Path | None = None
 
     # Optional configuration
-    run_config_env_vars: Optional[Dict[str, Tuple[str, str]]] = None
-    container_image: Optional[str] = None
+    run_config_env_vars: dict[str, tuple[str, str]] | None = None
+    container_image: str | None = None
 
-    available_models_by_provider: Optional[Dict[str, List[ProviderModelEntry]]] = None
+    available_models_by_provider: dict[str, list[ProviderModelEntry]] | None = None
 
     def build_config(self) -> BuildConfig:
         return BuildConfig(
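Pydantic v2 resolves PEP 604 optionals and builtin generics natively, so model fields like those on `RunConfigSettings` come through unchanged. A minimal sketch with illustrative field names (not the real model):

```python
from pydantic import BaseModel, Field

class Settings(BaseModel):
    overrides: dict[str, list[str]] = Field(default_factory=dict)
    models: list[str] | None = None
    container_image: str | None = None

s = Settings(models=["m1"])
print(s.model_dump())  # {'overrides': {}, 'models': ['m1'], 'container_image': None}
```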
diff --git a/llama_stack/templates/verification/verification.py b/llama_stack/templates/verification/verification.py
index e6f74aad8..ca9210e85 100644
--- a/llama_stack/templates/verification/verification.py
+++ b/llama_stack/templates/verification/verification.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List, Tuple
 
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
@@ -51,7 +50,7 @@
 )
 
 
-def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
+def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
     # in this template, we allow each API key to be optional
     providers = [
         (
diff --git a/pyproject.toml b/pyproject.toml
index f1f65be90..05fe86ec2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -144,6 +144,7 @@ exclude = [
 
 [tool.ruff.lint]
 select = [
+    "UP",  # pyupgrade
     "B",  # flake8-bugbear
     "B9",  # flake8-bugbear subset
     "C",  # comprehensions
diff --git a/scripts/distro_codegen.py b/scripts/distro_codegen.py
index a65e2c80d..30f533883 100755
--- a/scripts/distro_codegen.py
+++ b/scripts/distro_codegen.py
@@ -10,9 +10,9 @@
 import importlib
 import json
 import subprocess
 import sys
+from collections.abc import Iterable
 from functools import partial
 from pathlib import Path
-from typing import Iterable
 
 from rich.progress import Progress, SpinnerColumn, TextColumn
 
diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index f884d440d..bd307ab19 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict
+from typing import Any
 from uuid import uuid4
 
 import pytest
@@ -37,7 +37,7 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
         return -1
 
 
-def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
+def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> dict[str, Any]:
     """
     Returns the boiling point of a liquid in Celcius or Fahrenheit
 
diff --git a/tests/integration/fixtures/recordable_mock.py b/tests/integration/fixtures/recordable_mock.py
index 632d5b3ef..140831cfe 100644
--- a/tests/integration/fixtures/recordable_mock.py
+++ b/tests/integration/fixtures/recordable_mock.py
@@ -24,7 +24,7 @@ class RecordableMock:
         # Load existing cache if available and not recording
         if self.json_path.exists():
             try:
-                with open(self.json_path, "r") as f:
+                with open(self.json_path) as f:
                     self.cache = json.load(f)
             except Exception as e:
                 print(f"Error loading cache from {self.json_path}: {e}")
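The `"UP"` selector added to `pyproject.toml` above is what automates the rest of this patch: ruff's pyupgrade rules (UP006 for `List` → `list`, UP007 for `Optional`/`Union` → `|`, UP015 for redundant `open` modes, among others) rewrite these patterns under `ruff check --fix`. An illustrative before/after, not taken from the repo:

```python
from typing import Any

# Before (flagged by UP006/UP007/UP015):
#     def load(path: str, params: Optional[Dict[str, Any]] = None) -> List[str]:
#         with open(path, "r") as f:
#             ...

# After pyupgrade-style fixes:
def load(path: str, params: dict[str, Any] | None = None) -> list[str]:
    with open(path) as f:  # "r" is the default mode
        return f.read().splitlines()
```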
diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py
index 3e22bc5a7..648ace9d6 100644
--- a/tests/integration/post_training/test_post_training.py
+++ b/tests/integration/post_training/test_post_training.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List
 
 import pytest
 
@@ -77,7 +76,7 @@ class TestPostTraining:
     async def test_get_training_jobs(self, post_training_stack):
         post_training_impl = post_training_stack
         jobs_list = await post_training_impl.get_training_jobs()
-        assert isinstance(jobs_list, List)
+        assert isinstance(jobs_list, list)
         assert jobs_list[0].job_uuid == "1234"
 
     @pytest.mark.asyncio
diff --git a/tests/integration/test_cases/test_case.py b/tests/integration/test_cases/test_case.py
index 2a3c73310..fc3bf97c8 100644
--- a/tests/integration/test_cases/test_case.py
+++ b/tests/integration/test_cases/test_case.py
@@ -20,7 +20,7 @@ class TestCase:
         # loading all test cases
         if self._jsonblob == {}:
             for api in self._apis:
-                with open(pathlib.Path(__file__).parent / f"{api}.json", "r") as f:
+                with open(pathlib.Path(__file__).parent / f"{api}.json") as f:
                     coloned = api.replace("/", ":")
                     try:
                         loaded = json.load(f)
diff --git a/tests/unit/cli/test_stack_config.py b/tests/unit/cli/test_stack_config.py
index 312f58c09..d2b6f4b08 100644
--- a/tests/unit/cli/test_stack_config.py
+++ b/tests/unit/cli/test_stack_config.py
@@ -18,11 +18,11 @@ from llama_stack.distribution.configure import (
 @pytest.fixture
 def up_to_date_config():
     return yaml.safe_load(
-        """
-        version: {version}
+        f"""
+        version: {LLAMA_STACK_RUN_CONFIG_VERSION}
         image_name: foo
         apis_to_serve: []
-        built_at: {built_at}
+        built_at: {datetime.now().isoformat()}
         providers:
           inference:
             - provider_id: provider1
@@ -42,16 +42,16 @@ def up_to_date_config():
             - provider_id: provider1
               provider_type: inline::meta-reference
              config: {{}}
-        """.format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
+        """
     )
 
 
 @pytest.fixture
 def old_config():
     return yaml.safe_load(
-        """
+        f"""
         image_name: foo
-        built_at: {built_at}
+        built_at: {datetime.now().isoformat()}
         apis_to_serve: []
         routing_table:
           inference:
@@ -82,7 +82,7 @@ def old_config():
           telemetry:
             provider_type: noop
             config: {{}}
-        """.format(built_at=datetime.now().isoformat())
+        """
    )
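The fixture conversion works because `{{}}` means a literal `{}` in f-strings exactly as it does under `.format()`, so the embedded YAML (`config: {{}}`) survives the rewrite unchanged:

```python
name = "provider1"
# Doubled braces escape to a literal "{}" in both spellings
assert "config: {{}}".format() == f"config: {{}}" == "config: {}"
print(f"provider_id: {name}, config: {{}}")  # provider_id: provider1, config: {}
```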
diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py
index a4daffb82..ae24602d7 100644
--- a/tests/unit/distribution/test_distribution.py
+++ b/tests/unit/distribution/test_distribution.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict
+from typing import Any
 from unittest.mock import patch
 
 import pytest
@@ -23,7 +23,7 @@ class SampleConfig(BaseModel):
     )
 
     @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
             "foo": "baz",
         }
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index b3172cad4..a2e3b64c2 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -10,7 +10,7 @@
 import logging
 import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Any, Dict
+from typing import Any
 from unittest.mock import AsyncMock, patch
 
 import pytest
@@ -55,7 +55,7 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
 
 
 class MockInferenceAdapterWithSleep:
-    def __init__(self, sleep_time: int, response: Dict[str, Any]):
+    def __init__(self, sleep_time: int, response: dict[str, Any]):
         self.httpd = None
 
         class DelayedRequestHandler(BaseHTTPRequestHandler):
diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py
index ab0feb1a9..7d92a5cf5 100644
--- a/tests/unit/server/test_access_control.py
+++ b/tests/unit/server/test_access_control.py
@@ -22,7 +22,7 @@ from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
 
 class AsyncMock(MagicMock):
     async def __call__(self, *args, **kwargs):
-        return super(AsyncMock, self).__call__(*args, **kwargs)
+        return super().__call__(*args, **kwargs)
 
 
 def _return_model(model):
diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py
index fcf0b3945..3af9535a0 100644
--- a/tests/unit/server/test_resolver.py
+++ b/tests/unit/server/test_resolver.py
@@ -6,7 +6,7 @@
 
 import inspect
 import sys
-from typing import Any, Dict, Protocol
+from typing import Any, Protocol
 from unittest.mock import AsyncMock, MagicMock
 
 import pytest
@@ -48,14 +48,14 @@ class SampleConfig(BaseModel):
     )
 
     @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
             "foo": "baz",
         }
 
 
 class SampleImpl:
-    def __init__(self, config: SampleConfig, deps: Dict[Api, Any], provider_spec: ProviderSpec = None):
+    def __init__(self, config: SampleConfig, deps: dict[Api, Any], provider_spec: ProviderSpec = None):
         self.__provider_id__ = "test_provider"
         self.__provider_spec__ = provider_spec
         self.__provider_config__ = config
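`super().__call__(...)` is the Python 3 zero-argument form of `super(AsyncMock, self).__call__(...)`; the compiler fills in the class cell, so behavior is identical:

```python
class Base:
    def greet(self) -> str:
        return "hello"

class Child(Base):
    def greet(self) -> str:
        # Equivalent to the legacy spelling super(Child, self).greet()
        return super().greet() + " from Child"

assert Child().greet() == "hello from Child"
```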
diff --git a/tests/verifications/generate_report.py b/tests/verifications/generate_report.py
index 54d46ef41..67ef14e90 100755
--- a/tests/verifications/generate_report.py
+++ b/tests/verifications/generate_report.py
@@ -50,7 +50,7 @@
 import subprocess
 import time
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, DefaultDict, Dict, Set, Tuple
+from typing import Any
 
 from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs
 
@@ -106,7 +106,7 @@ def run_tests(provider, keyword=None):
 
     # Check if the JSON file was created
     if temp_json_file.exists():
-        with open(temp_json_file, "r") as f:
+        with open(temp_json_file) as f:
             test_results = json.load(f)
 
         test_results["run_timestamp"] = timestamp
@@ -141,7 +141,7 @@ def run_multiple_tests(providers_to_run: list[str], keyword: str | None):
 
 def parse_results(
     result_file,
-) -> Tuple[DefaultDict[str, DefaultDict[str, Dict[str, bool]]], DefaultDict[str, Set[str]], Set[str], str]:
+) -> tuple[defaultdict[str, defaultdict[str, dict[str, bool]]], defaultdict[str, set[str]], set[str], str]:
     """Parse a single test results file.
 
     Returns:
@@ -156,13 +156,13 @@ def parse_results(
         # Return empty defaultdicts/set matching the type hint
         return defaultdict(lambda: defaultdict(dict)), defaultdict(set), set(), ""
 
-    with open(result_file, "r") as f:
+    with open(result_file) as f:
         results = json.load(f)
 
     # Initialize results dictionary with specific types
-    parsed_results: DefaultDict[str, DefaultDict[str, Dict[str, bool]]] = defaultdict(lambda: defaultdict(dict))
-    providers_in_file: DefaultDict[str, Set[str]] = defaultdict(set)
-    tests_in_file: Set[str] = set()
+    parsed_results: defaultdict[str, defaultdict[str, dict[str, bool]]] = defaultdict(lambda: defaultdict(dict))
+    providers_in_file: defaultdict[str, set[str]] = defaultdict(set)
+    tests_in_file: set[str] = set()
 
     # Extract provider from filename (e.g., "openai.json" -> "openai")
     provider: str = result_file.stem
@@ -248,10 +248,10 @@ def parse_results(
 
 
 def generate_report(
-    results_dict: Dict[str, Any],
-    providers: Dict[str, Set[str]],
-    all_tests: Set[str],
-    provider_timestamps: Dict[str, str],
+    results_dict: dict[str, Any],
+    providers: dict[str, set[str]],
+    all_tests: set[str],
+    provider_timestamps: dict[str, str],
     output_file=None,
 ):
     """Generate the markdown report.
@@ -277,8 +277,8 @@ def generate_report(
     sorted_tests = sorted(all_tests)
 
     # Calculate counts for each base test name
-    base_test_case_counts: DefaultDict[str, int] = defaultdict(int)
-    base_test_name_map: Dict[str, str] = {}
+    base_test_case_counts: defaultdict[str, int] = defaultdict(int)
+    base_test_name_map: dict[str, str] = {}
     for test_name in sorted_tests:
         match = re.match(r"^(.*?)( \([^)]+\))?$", test_name)
         if match:
diff --git a/tests/verifications/openai_api/conftest.py b/tests/verifications/openai_api/conftest.py
index 7b4c92f1c..e4f7f27a0 100644
--- a/tests/verifications/openai_api/conftest.py
+++ b/tests/verifications/openai_api/conftest.py
@@ -18,7 +18,7 @@ def pytest_generate_tests(metafunc):
 
     try:
         config_data = _load_all_verification_configs()
-    except (FileNotFoundError, IOError) as e:
+    except (OSError, FileNotFoundError) as e:
         print(f"ERROR loading verification configs: {e}")
         config_data = {"providers": {}}
diff --git a/tests/verifications/openai_api/fixtures/fixtures.py b/tests/verifications/openai_api/fixtures/fixtures.py
index 2ea73cf26..a7328e5f6 100644
--- a/tests/verifications/openai_api/fixtures/fixtures.py
+++ b/tests/verifications/openai_api/fixtures/fixtures.py
@@ -33,7 +33,7 @@ def _load_all_verification_configs():
     for config_path in yaml_files:
         provider_name = config_path.stem
         try:
-            with open(config_path, "r") as f:
+            with open(config_path) as f:
                 provider_config = yaml.safe_load(f)
                 if provider_config:
                     all_provider_configs[provider_name] = provider_config
@@ -41,7 +41,7 @@ def _load_all_verification_configs():
                     # Log warning if possible, or just skip empty files silently
                     print(f"Warning: Config file {config_path} is empty or invalid.")
         except Exception as e:
-            raise IOError(f"Error loading config file {config_path}: {e}") from e
+            raise OSError(f"Error loading config file {config_path}: {e}") from e
 
     return {"providers": all_provider_configs}
 
@@ -49,7 +49,7 @@ def _load_all_verification_configs():
 def case_id_generator(case):
     """Generate a test ID from the case's 'case_id' field, or use a default."""
     case_id = case.get("case_id")
-    if isinstance(case_id, (str, int)):
+    if isinstance(case_id, str | int):
         return re.sub(r"\\W|^(?=\\d)", "_", str(case_id))
     return None
 
@@ -77,7 +77,7 @@ def verification_config():
     """Pytest fixture to provide the loaded verification config."""
     try:
         return _load_all_verification_configs()
-    except (FileNotFoundError, IOError) as e:
+    except (OSError, FileNotFoundError) as e:
         pytest.fail(str(e))  # Fail test collection if config loading fails
 
 
diff --git a/tests/verifications/openai_api/fixtures/load.py b/tests/verifications/openai_api/fixtures/load.py
index 98580b2a1..0184ee146 100644
--- a/tests/verifications/openai_api/fixtures/load.py
+++ b/tests/verifications/openai_api/fixtures/load.py
@@ -12,5 +12,5 @@ import yaml
 def load_test_cases(name: str):
     fixture_dir = Path(__file__).parent / "test_cases"
     yaml_path = fixture_dir / f"{name}.yaml"
-    with open(yaml_path, "r") as f:
+    with open(yaml_path) as f:
         return yaml.safe_load(f)
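The exception churn in these fixtures is also alias cleanup: since Python 3.3, `IOError` is simply another name for `OSError`, and `FileNotFoundError` is one of its subclasses, so `except OSError` alone already covers every name the old tuples listed:

```python
# IOError has been an alias of OSError since Python 3.3
assert IOError is OSError
assert issubclass(FileNotFoundError, OSError)

try:
    open("/nonexistent/path")
except OSError as e:
    print(type(e).__name__)  # FileNotFoundError
```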