chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle PEP 604 union
types (`types.UnionType`) and `collections.abc.AsyncIterator`. (Both are
preferred in recent Python releases.)
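
For reviewers unfamiliar with the distinction, here is a minimal standalone sketch
(not code from this repo) of why reflection code has to accept both union spellings;
the `is_union` helper is hypothetical and simply mirrors the check added to
`ContentBuilder` in the diff below:

```python
import types
import typing

# Union[int, str] and int | str describe the same type, but typing.get_origin()
# reports a different origin object for each spelling.
assert typing.get_origin(typing.Union[int, str]) is typing.Union
assert typing.get_origin(int | str) is types.UnionType  # PEP 604, Python 3.10+

# typing.get_args() behaves the same way for both spellings.
assert typing.get_args(typing.Union[int, str]) == (int, str)
assert typing.get_args(int | str) == (int, str)


def is_union(payload_type: object) -> bool:
    """Accept both spellings, mirroring the check added to ContentBuilder."""
    return typing.get_origin(payload_type) in (typing.Union, types.UnionType)


assert is_union(typing.Union[int, str]) and is_union(int | str)
assert not is_union(list[int])
```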

Note to reviewers: almost all changes here were generated automatically
by pyupgrade, and some additional unused imports were cleaned up. The
only changes worth noting are under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where reflection code was updated
to deal with the "newer" types.
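
As a representative example of the mechanical rewrites pyupgrade applies throughout
the diff below (an illustrative snippet only; the function is hypothetical, not taken
from the code base):

```python
# Before: typing-module generics and Optional/Union spellings.
# from typing import Dict, List, Optional, Union
#
# def get_tags(ids: List[str], extra: Optional[Dict[str, str]] = None) -> Union[str, None]:
#     ...

# After: PEP 585 builtin generics and PEP 604 unions; no typing import needed.
def get_tags(ids: list[str], extra: dict[str, str] | None = None) -> str | None:
    ...
```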

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Author: Ihar Hrachyshka, 2025-05-01 17:23:50 -04:00 (committed by GitHub)
Parent: ffe3d0b2cd
Commit: 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions


@@ -6,6 +6,7 @@
 import hashlib
 import ipaddress
+import types
 import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union
@@ -189,7 +190,7 @@ class ContentBuilder:
         else:
             return "application/json"

-        if typing.get_origin(payload_type) is typing.Union:
+        if typing.get_origin(payload_type) in (typing.Union, types.UnionType):
            media_types = []
            item_types = []
            for x in typing.get_args(payload_type):


@@ -4,20 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
 from datetime import datetime
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    AsyncIterator,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Union,
-    runtime_checkable,
-)
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, ConfigDict, Field
@@ -79,8 +69,8 @@ class StepCommon(BaseModel):
     turn_id: str
     step_id: str
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None

 class StepType(Enum):
@@ -120,8 +110,8 @@ class ToolExecutionStep(StepCommon):
     """

     step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
-    tool_calls: List[ToolCall]
-    tool_responses: List[ToolResponse]
+    tool_calls: list[ToolCall]
+    tool_responses: list[ToolResponse]

 @json_schema_type
@@ -132,7 +122,7 @@ class ShieldCallStep(StepCommon):
     """

     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    violation: Optional[SafetyViolation]
+    violation: SafetyViolation | None

 @json_schema_type
@@ -150,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):

 Step = Annotated[
-    Union[
-        InferenceStep,
-        ToolExecutionStep,
-        ShieldCallStep,
-        MemoryRetrievalStep,
-    ],
+    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
     Field(discriminator="step_type"),
 ]
@@ -166,18 +151,13 @@ class Turn(BaseModel):
     turn_id: str
     session_id: str
-    input_messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
-    steps: List[Step]
+    input_messages: list[UserMessage | ToolResponseMessage]
+    steps: list[Step]
     output_message: CompletionMessage
-    output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
+    output_attachments: list[Attachment] | None = Field(default_factory=list)
     started_at: datetime
-    completed_at: Optional[datetime] = None
+    completed_at: datetime | None = None

 @json_schema_type
@@ -186,34 +166,31 @@ class Session(BaseModel):
     session_id: str
     session_name: str
-    turns: List[Turn]
+    turns: list[Turn]
     started_at: datetime

 class AgentToolGroupWithArgs(BaseModel):
     name: str
-    args: Dict[str, Any]
+    args: dict[str, Any]

-AgentToolGroup = Union[
-    str,
-    AgentToolGroupWithArgs,
-]
+AgentToolGroup = str | AgentToolGroupWithArgs
 register_schema(AgentToolGroup, name="AgentTool")

 class AgentConfigCommon(BaseModel):
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
-    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
-    tool_config: Optional[ToolConfig] = Field(default=None)
-    max_infer_iters: Optional[int] = 10
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
+    input_shields: list[str] | None = Field(default_factory=list)
+    output_shields: list[str] | None = Field(default_factory=list)
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
+    client_tools: list[ToolDef] | None = Field(default_factory=list)
+    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
+    tool_config: ToolConfig | None = Field(default=None)
+    max_infer_iters: int | None = 10

     def model_post_init(self, __context):
         if self.tool_config:
@@ -243,9 +220,9 @@ class AgentConfig(AgentConfigCommon):
     model: str
     instructions: str
-    name: Optional[str] = None
-    enable_session_persistence: Optional[bool] = False
-    response_format: Optional[ResponseFormat] = None
+    name: str | None = None
+    enable_session_persistence: bool | None = False
+    response_format: ResponseFormat | None = None

 @json_schema_type
@@ -257,16 +234,16 @@ class Agent(BaseModel):

 @json_schema_type
 class ListAgentsResponse(BaseModel):
-    data: List[Agent]
+    data: list[Agent]

 @json_schema_type
 class ListAgentSessionsResponse(BaseModel):
-    data: List[Session]
+    data: list[Session]

 class AgentConfigOverridablePerTurn(AgentConfigCommon):
-    instructions: Optional[str] = None
+    instructions: str | None = None

 class AgentTurnResponseEventType(Enum):
@@ -284,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
     event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
     step_type: StepType
     step_id: str
-    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    metadata: dict[str, Any] | None = Field(default_factory=dict)

 @json_schema_type
@@ -327,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):

 AgentTurnResponseEventPayload = Annotated[
-    Union[
-        AgentTurnResponseStepStartPayload,
-        AgentTurnResponseStepProgressPayload,
-        AgentTurnResponseStepCompletePayload,
-        AgentTurnResponseTurnStartPayload,
-        AgentTurnResponseTurnCompletePayload,
-        AgentTurnResponseTurnAwaitingInputPayload,
-    ],
+    AgentTurnResponseStepStartPayload
+    | AgentTurnResponseStepProgressPayload
+    | AgentTurnResponseStepCompletePayload
+    | AgentTurnResponseTurnStartPayload
+    | AgentTurnResponseTurnCompletePayload
+    | AgentTurnResponseTurnAwaitingInputPayload,
     Field(discriminator="event_type"),
 ]
 register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@@ -363,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     # TODO: figure out how we can simplify this and make why
     # ToolResponseMessage needs to be here (it is function call
     # execution from outside the system)
-    messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
-    documents: Optional[List[Document]] = None
-    toolgroups: Optional[List[AgentToolGroup]] = None
-    stream: Optional[bool] = False
-    tool_config: Optional[ToolConfig] = None
+    messages: list[UserMessage | ToolResponseMessage]
+    documents: list[Document] | None = None
+    toolgroups: list[AgentToolGroup] | None = None
+    stream: bool | None = False
+    tool_config: ToolConfig | None = None

 @json_schema_type
@@ -382,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: List[ToolResponse]
-    stream: Optional[bool] = False
+    tool_responses: list[ToolResponse]
+    stream: bool | None = False

 @json_schema_type
@@ -429,17 +399,12 @@ class Agents(Protocol):
         self,
         agent_id: str,
         session_id: str,
-        messages: List[
-            Union[
-                UserMessage,
-                ToolResponseMessage,
-            ]
-        ],
-        stream: Optional[bool] = False,
-        documents: Optional[List[Document]] = None,
-        toolgroups: Optional[List[AgentToolGroup]] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        messages: list[UserMessage | ToolResponseMessage],
+        stream: bool | None = False,
+        documents: list[Document] | None = None,
+        toolgroups: list[AgentToolGroup] | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Create a new turn for an agent.

         :param agent_id: The ID of the agent to create the turn for.
@@ -463,9 +428,9 @@ class Agents(Protocol):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponse],
-        stream: Optional[bool] = False,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        tool_responses: list[ToolResponse],
+        stream: bool | None = False,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Resume an agent turn with executed tool call responses.

         When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
@@ -538,7 +503,7 @@ class Agents(Protocol):
         self,
         session_id: str,
         agent_id: str,
-        turn_ids: Optional[List[str]] = None,
+        turn_ids: list[str] | None = None,
     ) -> Session:
         """Retrieve an agent session by its ID.
@@ -623,14 +588,14 @@ class Agents(Protocol):
     @webmethod(route="/openai/v1/responses", method="POST")
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        temperature: Optional[float] = None,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
-    ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
+    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
         """Create a new OpenAI response.

         :param input: Input message(s) to create the response.


@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Literal, Optional, Union
+from typing import Annotated, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.schema_utils import json_schema_type, register_schema
@@ -25,7 +24,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):

 OpenAIResponseOutputMessageContent = Annotated[
-    Union[OpenAIResponseOutputMessageContentOutputText,],
+    OpenAIResponseOutputMessageContentOutputText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
@@ -34,7 +33,7 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
 @json_schema_type
 class OpenAIResponseOutputMessage(BaseModel):
     id: str
-    content: List[OpenAIResponseOutputMessageContent]
+    content: list[OpenAIResponseOutputMessageContent]
     role: Literal["assistant"] = "assistant"
     status: str
     type: Literal["message"] = "message"
@@ -48,10 +47,7 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):

 OpenAIResponseOutput = Annotated[
-    Union[
-        OpenAIResponseOutputMessage,
-        OpenAIResponseOutputMessageWebSearchToolCall,
-    ],
+    OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@@ -60,18 +56,18 @@ register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
 @json_schema_type
 class OpenAIResponseObject(BaseModel):
     created_at: int
-    error: Optional[OpenAIResponseError] = None
+    error: OpenAIResponseError | None = None
     id: str
     model: str
     object: Literal["response"] = "response"
-    output: List[OpenAIResponseOutput]
+    output: list[OpenAIResponseOutput]
     parallel_tool_calls: bool = False
-    previous_response_id: Optional[str] = None
+    previous_response_id: str | None = None
     status: str
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    truncation: Optional[str] = None
-    user: Optional[str] = None
+    temperature: float | None = None
+    top_p: float | None = None
+    truncation: str | None = None
+    user: str | None = None

 @json_schema_type
@@ -87,10 +83,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):

 OpenAIResponseObjectStream = Annotated[
-    Union[
-        OpenAIResponseObjectStreamResponseCreated,
-        OpenAIResponseObjectStreamResponseCompleted,
-    ],
+    OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@@ -107,12 +100,12 @@ class OpenAIResponseInputMessageContentImage(BaseModel):
     detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
     type: Literal["input_image"] = "input_image"
     # TODO: handle file_id
-    image_url: Optional[str] = None
+    image_url: str | None = None

 # TODO: handle file content types
 OpenAIResponseInputMessageContent = Annotated[
-    Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
+    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@@ -120,21 +113,21 @@ register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMess
 @json_schema_type
 class OpenAIResponseInputMessage(BaseModel):
-    content: Union[str, List[OpenAIResponseInputMessageContent]]
+    content: str | list[OpenAIResponseInputMessageContent]
     role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
-    type: Optional[Literal["message"]] = "message"
+    type: Literal["message"] | None = "message"

 @json_schema_type
 class OpenAIResponseInputToolWebSearch(BaseModel):
     type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
     # TODO: add user_location

 OpenAIResponseInputTool = Annotated[
-    Union[OpenAIResponseInputToolWebSearch,],
+    OpenAIResponseInputToolWebSearch,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import (
@@ -34,22 +34,22 @@ class BatchInference(Protocol):
     async def completion(
         self,
         model: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...

     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
         self,
         model: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
         # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...


@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, Field
@@ -13,8 +13,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod

 class CommonBenchmarkFields(BaseModel):
     dataset_id: str
-    scoring_functions: List[str]
-    metadata: Dict[str, Any] = Field(
+    scoring_functions: list[str]
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Metadata for this evaluation task",
     )
@@ -35,12 +35,12 @@ class Benchmark(CommonBenchmarkFields, Resource):

 class BenchmarkInput(CommonBenchmarkFields, BaseModel):
     benchmark_id: str
-    provider_id: Optional[str] = None
-    provider_benchmark_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_benchmark_id: str | None = None

 class ListBenchmarksResponse(BaseModel):
-    data: List[Benchmark]
+    data: list[Benchmark]

 @runtime_checkable
@@ -59,8 +59,8 @@ class Benchmarks(Protocol):
         self,
         benchmark_id: str,
         dataset_id: str,
-        scoring_functions: List[str],
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        scoring_functions: list[str],
+        provider_benchmark_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
     ) -> None: ...


@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Annotated, List, Literal, Optional, Union
+from typing import Annotated, Literal

 from pydantic import BaseModel, Field, model_validator
@@ -26,9 +26,9 @@ class _URLOrData(BaseModel):
     :param data: base64 encoded image data as string
     """

-    url: Optional[URL] = None
+    url: URL | None = None
     # data is a base64 encoded string, hint with contentEncoding=base64
-    data: Optional[str] = Field(contentEncoding="base64", default=None)
+    data: str | None = Field(contentEncoding="base64", default=None)

     @model_validator(mode="before")
     @classmethod
@@ -64,13 +64,13 @@ class TextContentItem(BaseModel):

 # other modalities can be added here
 InterleavedContentItem = Annotated[
-    Union[ImageContentItem, TextContentItem],
+    ImageContentItem | TextContentItem,
     Field(discriminator="type"),
 ]
 register_schema(InterleavedContentItem, name="InterleavedContentItem")

 # accept a single "str" as a special case since it is common
-InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
+InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
 register_schema(InterleavedContent, name="InterleavedContent")
@@ -100,13 +100,13 @@ class ToolCallDelta(BaseModel):
     # you either send an in-progress tool call so the client can stream a long
     # code generation or you send the final parsed tool call at the end of the
     # stream
-    tool_call: Union[str, ToolCall]
+    tool_call: str | ToolCall
     parse_status: ToolCallParseStatus

 # streaming completions send a stream of ContentDeltas
 ContentDelta = Annotated[
-    Union[TextDelta, ImageDelta, ToolCallDelta],
+    TextDelta | ImageDelta | ToolCallDelta,
     Field(discriminator="type"),
 ]
 register_schema(ContentDelta, name="ContentDelta")


@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel
@@ -25,6 +25,6 @@ class RestAPIMethod(Enum):
 class RestAPIExecutionConfig(BaseModel):
     url: URL
     method: RestAPIMethod
-    params: Optional[Dict[str, Any]] = None
-    headers: Optional[Dict[str, Any]] = None
-    body: Optional[Dict[str, Any]] = None
+    params: dict[str, Any] | None = None
+    headers: dict[str, Any] | None = None
+    body: dict[str, Any] | None = None


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any

 from pydantic import BaseModel
@@ -19,5 +19,5 @@ class PaginatedResponse(BaseModel):
     :param has_more: Whether there are more items available after this set
     """

-    data: List[Dict[str, Any]]
+    data: list[dict[str, Any]]
     has_more: bool


@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from datetime import datetime
-from typing import Optional

 from pydantic import BaseModel
@@ -27,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric] = None
+    training_metrics: PostTrainingMetric | None = None


@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Literal, Union
+from typing import Annotated, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.schema_utils import json_schema_type, register_schema
@@ -73,18 +72,16 @@ class DialogType(BaseModel):

 ParamType = Annotated[
-    Union[
-        StringType,
-        NumberType,
-        BooleanType,
-        ArrayType,
-        ObjectType,
-        JsonType,
-        UnionType,
-        ChatCompletionInputType,
-        CompletionInputType,
-        AgentTurnInputType,
-    ],
+    StringType
+    | NumberType
+    | BooleanType
+    | ArrayType
+    | ObjectType
+    | JsonType
+    | UnionType
+    | ChatCompletionInputType
+    | CompletionInputType
+    | AgentTurnInputType,
     Field(discriminator="type"),
 ]
 register_schema(ParamType, name="ParamType")


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasets import Dataset
@@ -24,8 +24,8 @@ class DatasetIO(Protocol):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
-        limit: Optional[int] = None,
+        start_index: int | None = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         """Get a paginated list of rows from a dataset.
@@ -44,4 +44,4 @@ class DatasetIO(Protocol):
         ...

     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...


@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field
@@ -81,11 +81,11 @@ class RowsDataSource(BaseModel):
     """

     type: Literal["rows"] = "rows"
-    rows: List[Dict[str, Any]]
+    rows: list[dict[str, Any]]

 DataSource = Annotated[
-    Union[URIDataSource, RowsDataSource],
+    URIDataSource | RowsDataSource,
     Field(discriminator="type"),
 ]
 register_schema(DataSource, name="DataSource")
@@ -98,7 +98,7 @@ class CommonDatasetFields(BaseModel):
     purpose: DatasetPurpose
     source: DataSource
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this dataset",
     )
@@ -122,7 +122,7 @@ class DatasetInput(CommonDatasetFields, BaseModel):

 class ListDatasetsResponse(BaseModel):
-    data: List[Dataset]
+    data: list[Dataset]

 class Datasets(Protocol):
@@ -131,8 +131,8 @@ class Datasets(Protocol):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
-        dataset_id: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
     ) -> Dataset:
         """
         Register a new dataset.


@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Optional

 from pydantic import BaseModel
@@ -54,4 +53,4 @@ class Error(BaseModel):
     status: int
     title: str
     detail: str
-    instance: Optional[str] = None
+    instance: str | None = None


@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job
@@ -29,7 +28,7 @@ class ModelCandidate(BaseModel):
     type: Literal["model"] = "model"
     model: str
     sampling_params: SamplingParams
-    system_message: Optional[SystemMessage] = None
+    system_message: SystemMessage | None = None

 @json_schema_type
@@ -43,7 +42,7 @@ class AgentCandidate(BaseModel):
     config: AgentConfig

-EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
 register_schema(EvalCandidate, name="EvalCandidate")
@@ -57,11 +56,11 @@ class BenchmarkConfig(BaseModel):
     """

     eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
+    scoring_params: dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
-    num_examples: Optional[int] = Field(
+    num_examples: int | None = Field(
         description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
         default=None,
     )
@@ -76,9 +75,9 @@ class EvaluateResponse(BaseModel):
     :param scores: The scores from the evaluation.
     """

-    generations: List[Dict[str, Any]]
+    generations: list[dict[str, Any]]
     # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
+    scores: dict[str, ScoringResult]

 class Eval(Protocol):
@@ -101,8 +100,8 @@ class Eval(Protocol):
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         """Evaluate a list of rows on a benchmark.


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 from pydantic import BaseModel
@@ -42,7 +42,7 @@ class ListBucketResponse(BaseModel):
     :param data: List of FileResponse entries
     """

-    data: List[BucketResponse]
+    data: list[BucketResponse]

 @json_schema_type
@@ -74,7 +74,7 @@ class ListFileResponse(BaseModel):
     :param data: List of FileResponse entries
     """

-    data: List[FileResponse]
+    data: list[FileResponse]

 @runtime_checkable
@@ -102,7 +102,7 @@ class Files(Protocol):
     async def upload_content_to_session(
         self,
         upload_id: str,
-    ) -> Optional[FileResponse]:
+    ) -> FileResponse | None:
         """
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.


@@ -4,21 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    AsyncIterator,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated, TypedDict
+from typing_extensions import TypedDict

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
@@ -47,8 +44,8 @@ class GreedySamplingStrategy(BaseModel):
 @json_schema_type
 class TopPSamplingStrategy(BaseModel):
     type: Literal["top_p"] = "top_p"
-    temperature: Optional[float] = Field(..., gt=0.0)
-    top_p: Optional[float] = 0.95
+    temperature: float | None = Field(..., gt=0.0)
+    top_p: float | None = 0.95

 @json_schema_type
@@ -58,7 +55,7 @@ class TopKSamplingStrategy(BaseModel):

 SamplingStrategy = Annotated[
-    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
     Field(discriminator="type"),
 ]
 register_schema(SamplingStrategy, name="SamplingStrategy")
@@ -79,9 +76,9 @@ class SamplingParams(BaseModel):

     strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
-    max_tokens: Optional[int] = 0
-    repetition_penalty: Optional[float] = 1.0
-    stop: Optional[List[str]] = None
+    max_tokens: int | None = 0
+    repetition_penalty: float | None = 1.0
+    stop: list[str] | None = None

 class LogProbConfig(BaseModel):
@@ -90,7 +87,7 @@ class LogProbConfig(BaseModel):
     :param top_k: How many tokens (for each position) to return log probabilities for.
     """

-    top_k: Optional[int] = 0
+    top_k: int | None = 0

 class QuantizationType(Enum):
@@ -125,11 +122,11 @@ class Int4QuantizationConfig(BaseModel):
     """

     type: Literal["int4_mixed"] = "int4_mixed"
-    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
+    scheme: str | None = "int4_weight_int8_dynamic_activation"

 QuantizationConfig = Annotated[
-    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
+    Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
     Field(discriminator="type"),
 ]
@@ -145,7 +142,7 @@ class UserMessage(BaseModel):

     role: Literal["user"] = "user"
     content: InterleavedContent
-    context: Optional[InterleavedContent] = None
+    context: InterleavedContent | None = None

 @json_schema_type
@@ -190,16 +187,11 @@ class CompletionMessage(BaseModel):
     role: Literal["assistant"] = "assistant"
     content: InterleavedContent
     stop_reason: StopReason
-    tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
+    tool_calls: list[ToolCall] | None = Field(default_factory=list)

 Message = Annotated[
-    Union[
-        UserMessage,
-        SystemMessage,
-        ToolResponseMessage,
-        CompletionMessage,
-    ],
+    UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
     Field(discriminator="role"),
 ]
 register_schema(Message, name="Message")
@@ -208,9 +200,9 @@ register_schema(Message, name="Message")
 @json_schema_type
 class ToolResponse(BaseModel):
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: BuiltinTool | str
     content: InterleavedContent
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: dict[str, Any] | None = None

     @field_validator("tool_name", mode="before")
     @classmethod
@@ -243,7 +235,7 @@ class TokenLogProbs(BaseModel):
     :param logprobs_by_token: Dictionary mapping tokens to their log probabilities
     """

-    logprobs_by_token: Dict[str, float]
+    logprobs_by_token: dict[str, float]

 class ChatCompletionResponseEventType(Enum):
@@ -271,8 +263,8 @@ class ChatCompletionResponseEvent(BaseModel):

     event_type: ChatCompletionResponseEventType
     delta: ContentDelta
-    logprobs: Optional[List[TokenLogProbs]] = None
-    stop_reason: Optional[StopReason] = None
+    logprobs: list[TokenLogProbs] | None = None
+    stop_reason: StopReason | None = None

 class ResponseFormatType(Enum):
@@ -295,7 +287,7 @@ class JsonSchemaResponseFormat(BaseModel):
     """

     type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
-    json_schema: Dict[str, Any]
+    json_schema: dict[str, Any]

 @json_schema_type
@@ -307,11 +299,11 @@ class GrammarResponseFormat(BaseModel):
     """

     type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
-    bnf: Dict[str, Any]
+    bnf: dict[str, Any]

 ResponseFormat = Annotated[
-    Union[JsonSchemaResponseFormat, GrammarResponseFormat],
+    JsonSchemaResponseFormat | GrammarResponseFormat,
     Field(discriminator="type"),
 ]
 register_schema(ResponseFormat, name="ResponseFormat")
@@ -321,10 +313,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedContent
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None

 @json_schema_type
@@ -338,7 +330,7 @@ class CompletionResponse(MetricResponseMixin):

     content: str
     stop_reason: StopReason
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None

 @json_schema_type
@@ -351,8 +343,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
     """

     delta: str
-    stop_reason: Optional[StopReason] = None
-    logprobs: Optional[List[TokenLogProbs]] = None
+    stop_reason: StopReason | None = None
+    logprobs: list[TokenLogProbs] | None = None

 class SystemMessageBehavior(Enum):
@@ -383,9 +375,9 @@ class ToolConfig(BaseModel):
         '{{function_definitions}}' to indicate where the function definitions should be inserted.
     """

-    tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
+    tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
+    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)

     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
@@ -399,15 +391,15 @@ class ToolConfig(BaseModel):
 @json_schema_type
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Message]
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
-    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
-    response_format: Optional[ResponseFormat] = None
-    stream: Optional[bool] = False
-    logprobs: Optional[LogProbConfig] = None
+    messages: list[Message]
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
+    tools: list[ToolDefinition] | None = Field(default_factory=list)
+    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
+    response_format: ResponseFormat | None = None
+    stream: bool | None = False
+    logprobs: LogProbConfig | None = None

 @json_schema_type
@@ -429,7 +421,7 @@ class ChatCompletionResponse(MetricResponseMixin):
     """

     completion_message: CompletionMessage
-    logprobs: Optional[List[TokenLogProbs]] = None
+    logprobs: list[TokenLogProbs] | None = None

 @json_schema_type
@@ -439,7 +431,7 @@ class EmbeddingsResponse(BaseModel):
     :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
     """

-    embeddings: List[List[float]]
+    embeddings: list[list[float]]

 @json_schema_type
@@ -451,7 +443,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):

 @json_schema_type
 class OpenAIImageURL(BaseModel):
     url: str
-    detail: Optional[str] = None
+    detail: str | None = None

 @json_schema_type
@@ -461,16 +453,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):

 OpenAIChatCompletionContentPartParam = Annotated[
-    Union[
-        OpenAIChatCompletionContentPartTextParam,
-        OpenAIChatCompletionContentPartImageParam,
-    ],
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")

-OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
+OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]

 @json_schema_type
@@ -484,7 +473,7 @@ class OpenAIUserMessageParam(BaseModel):

     role: Literal["user"] = "user"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None

 @json_schema_type
@@ -498,21 +487,21 @@ class OpenAISystemMessageParam(BaseModel):

     role: Literal["system"] = "system"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None

 @json_schema_type
 class OpenAIChatCompletionToolCallFunction(BaseModel):
-    name: Optional[str] = None
-    arguments: Optional[str] = None
+    name: str | None = None
+    arguments: str | None = None

 @json_schema_type
 class OpenAIChatCompletionToolCall(BaseModel):
-    index: Optional[int] = None
-    id: Optional[str] = None
+    index: int | None = None
+    id: str | None = None
     type: Literal["function"] = "function"
-    function: Optional[OpenAIChatCompletionToolCallFunction] = None
+    function: OpenAIChatCompletionToolCallFunction | None = None

 @json_schema_type
@@ -526,9 +515,9 @@ class OpenAIAssistantMessageParam(BaseModel):
     """

     role: Literal["assistant"] = "assistant"
-    content: Optional[OpenAIChatCompletionMessageContent] = None
-    name: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: OpenAIChatCompletionMessageContent | None = None
+    name: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None

 @json_schema_type
@@ -556,17 +545,15 @@ class OpenAIDeveloperMessageParam(BaseModel):

     role: Literal["developer"] = "developer"
     content: OpenAIChatCompletionMessageContent
-    name: Optional[str] = None
+    name: str | None = None

 OpenAIMessageParam = Annotated[
-    Union[
-        OpenAIUserMessageParam,
-        OpenAISystemMessageParam,
-        OpenAIAssistantMessageParam,
-        OpenAIToolMessageParam,
-        OpenAIDeveloperMessageParam,
-    ],
+    OpenAIUserMessageParam
+    | OpenAISystemMessageParam
+    | OpenAIAssistantMessageParam
+    | OpenAIToolMessageParam
+    | OpenAIDeveloperMessageParam,
     Field(discriminator="role"),
 ]
 register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@@ -580,14 +567,14 @@ class OpenAIResponseFormatText(BaseModel):
 @json_schema_type
 class OpenAIJSONSchema(TypedDict, total=False):
     name: str
-    description: Optional[str] = None
-    strict: Optional[bool] = None
+    description: str | None = None
+    strict: bool | None = None

     # Pydantic BaseModel cannot be used with a schema param, since it already
     # has one. And, we don't want to alias here because then have to handle
     # that alias when converting to OpenAI params. So, to support schema,
     # we use a TypedDict.
-    schema: Optional[Dict[str, Any]] = None
+    schema: dict[str, Any] | None = None

 @json_schema_type
@@ -602,11 +589,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):

 OpenAIResponseFormatParam = Annotated[
-    Union[
-        OpenAIResponseFormatText,
-        OpenAIResponseFormatJSONSchema,
-        OpenAIResponseFormatJSONObject,
-    ],
+    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -622,7 +605,7 @@ class OpenAITopLogProb(BaseModel):
     """

     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float
@@ -637,9 +620,9 @@ class OpenAITokenLogProb(BaseModel):
     """

     token: str
-    bytes: Optional[List[int]] = None
+    bytes: list[int] | None = None
     logprob: float
-    top_logprobs: List[OpenAITopLogProb]
+    top_logprobs: list[OpenAITopLogProb]

 @json_schema_type
@@ -650,8 +633,8 @@ class OpenAIChoiceLogprobs(BaseModel):
     :param refusal: (Optional) The log probabilities for the tokens in the message
     """

-    content: Optional[List[OpenAITokenLogProb]] = None
-    refusal: Optional[List[OpenAITokenLogProb]] = None
+    content: list[OpenAITokenLogProb] | None = None
+    refusal: list[OpenAITokenLogProb] | None = None

 @json_schema_type
@@ -664,10 +647,10 @@ class OpenAIChoiceDelta(BaseModel):
     :param tool_calls: (Optional) The tool calls of the delta
     """

-    content: Optional[str] = None
-    refusal: Optional[str] = None
-    role: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+    content: str | None = None
+    refusal: str | None = None
+    role: str | None = None
+    tool_calls: list[OpenAIChatCompletionToolCall] | None = None

 @json_schema_type
@@ -683,7 +666,7 @@ class OpenAIChunkChoice(BaseModel):

     delta: OpenAIChoiceDelta
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None

 @json_schema_type
@@ -699,7 +682,7 @@ class OpenAIChoice(BaseModel):

     message: OpenAIMessageParam
     finish_reason: str
     index: int
-    logprobs: Optional[OpenAIChoiceLogprobs] = None
+    logprobs: OpenAIChoiceLogprobs | None = None

 @json_schema_type
@@ -714,7 +697,7 @@ class OpenAIChatCompletion(BaseModel):
     """

     id: str
-    choices: List[OpenAIChoice]
+    choices: list[OpenAIChoice]
     object: Literal["chat.completion"] = "chat.completion"
     created: int
     model: str
@@ -732,7 +715,7 @@ class OpenAIChatCompletionChunk(BaseModel):
     """

     id: str
-    choices: List[OpenAIChunkChoice]
+    choices: list[OpenAIChunkChoice]
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int
     model: str
@ -748,10 +731,10 @@ class OpenAICompletionLogprobs(BaseModel):
:top_logprobs: (Optional) The top log probabilities for the tokens :top_logprobs: (Optional) The top log probabilities for the tokens
""" """
text_offset: Optional[List[int]] = None text_offset: list[int] | None = None
token_logprobs: Optional[List[float]] = None token_logprobs: list[float] | None = None
tokens: Optional[List[str]] = None tokens: list[str] | None = None
top_logprobs: Optional[List[Dict[str, float]]] = None top_logprobs: list[dict[str, float]] | None = None
@json_schema_type @json_schema_type
@ -767,7 +750,7 @@ class OpenAICompletionChoice(BaseModel):
finish_reason: str finish_reason: str
text: str text: str
index: int index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None logprobs: OpenAIChoiceLogprobs | None = None
@json_schema_type @json_schema_type
@ -782,7 +765,7 @@ class OpenAICompletion(BaseModel):
""" """
id: str id: str
choices: List[OpenAICompletionChoice] choices: list[OpenAICompletionChoice]
created: int created: int
model: str model: str
object: Literal["text_completion"] = "text_completion" object: Literal["text_completion"] = "text_completion"
@ -818,12 +801,12 @@ class EmbeddingTaskType(Enum):
@json_schema_type @json_schema_type
class BatchCompletionResponse(BaseModel): class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse] batch: list[CompletionResponse]
@json_schema_type @json_schema_type
class BatchChatCompletionResponse(BaseModel): class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse] batch: list[ChatCompletionResponse]
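For readers less familiar with the new spellings, a minimal self-contained sketch (hypothetical model and field names, not part of this commit; assumes Python 3.10+ with Pydantic installed) of how the builtin-generic and PEP 604 forms above behave exactly like the older Optional/List ones:

from pydantic import BaseModel


class ChoiceSketch(BaseModel):
    # "str | None" replaces Optional[str]; "list[str]" replaces List[str].
    content: str | None = None
    tool_call_ids: list[str] | None = None


print(ChoiceSketch())                       # content=None tool_call_ids=None
print(ChoiceSketch(tool_call_ids=["a"]))    # content=None tool_call_ids=['a']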
@runtime_checkable @runtime_checkable
@ -843,11 +826,11 @@ class Inference(Protocol):
self, self,
model_id: str, model_id: str,
content: InterleavedContent, content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None, sampling_params: SamplingParams | None = None,
response_format: Optional[ResponseFormat] = None, response_format: ResponseFormat | None = None,
stream: Optional[bool] = False, stream: bool | None = False,
logprobs: Optional[LogProbConfig] = None, logprobs: LogProbConfig | None = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
"""Generate a completion for the given content using the specified model. """Generate a completion for the given content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
@ -865,10 +848,10 @@ class Inference(Protocol):
async def batch_completion( async def batch_completion(
self, self,
model_id: str, model_id: str,
content_batch: List[InterleavedContent], content_batch: list[InterleavedContent],
sampling_params: Optional[SamplingParams] = None, sampling_params: SamplingParams | None = None,
response_format: Optional[ResponseFormat] = None, response_format: ResponseFormat | None = None,
logprobs: Optional[LogProbConfig] = None, logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse: ) -> BatchCompletionResponse:
raise NotImplementedError("Batch completion is not implemented") raise NotImplementedError("Batch completion is not implemented")
@ -876,16 +859,16 @@ class Inference(Protocol):
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,
messages: List[Message], messages: list[Message],
sampling_params: Optional[SamplingParams] = None, sampling_params: SamplingParams | None = None,
tools: Optional[List[ToolDefinition]] = None, tools: list[ToolDefinition] | None = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: ToolPromptFormat | None = None,
response_format: Optional[ResponseFormat] = None, response_format: ResponseFormat | None = None,
stream: Optional[bool] = False, stream: bool | None = False,
logprobs: Optional[LogProbConfig] = None, logprobs: LogProbConfig | None = None,
tool_config: Optional[ToolConfig] = None, tool_config: ToolConfig | None = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""Generate a chat completion for the given messages using the specified model. """Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
@ -916,12 +899,12 @@ class Inference(Protocol):
async def batch_chat_completion( async def batch_chat_completion(
self, self,
model_id: str, model_id: str,
messages_batch: List[List[Message]], messages_batch: list[list[Message]],
sampling_params: Optional[SamplingParams] = None, sampling_params: SamplingParams | None = None,
tools: Optional[List[ToolDefinition]] = None, tools: list[ToolDefinition] | None = None,
tool_config: Optional[ToolConfig] = None, tool_config: ToolConfig | None = None,
response_format: Optional[ResponseFormat] = None, response_format: ResponseFormat | None = None,
logprobs: Optional[LogProbConfig] = None, logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse: ) -> BatchChatCompletionResponse:
raise NotImplementedError("Batch chat completion is not implemented") raise NotImplementedError("Batch chat completion is not implemented")
@ -929,10 +912,10 @@ class Inference(Protocol):
async def embeddings( async def embeddings(
self, self,
model_id: str, model_id: str,
contents: List[str] | List[InterleavedContentItem], contents: list[str] | list[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none, text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: Optional[int] = None, output_dimension: int | None = None,
task_type: Optional[EmbeddingTaskType] = None, task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model. """Generate embeddings for content pieces using the specified model.
@ -950,25 +933,25 @@ class Inference(Protocol):
self, self,
# Standard OpenAI completion parameters # Standard OpenAI completion parameters
model: str, model: str,
prompt: Union[str, List[str], List[int], List[List[int]]], prompt: str | list[str] | list[int] | list[list[int]],
best_of: Optional[int] = None, best_of: int | None = None,
echo: Optional[bool] = None, echo: bool | None = None,
frequency_penalty: Optional[float] = None, frequency_penalty: float | None = None,
logit_bias: Optional[Dict[str, float]] = None, logit_bias: dict[str, float] | None = None,
logprobs: Optional[bool] = None, logprobs: bool | None = None,
max_tokens: Optional[int] = None, max_tokens: int | None = None,
n: Optional[int] = None, n: int | None = None,
presence_penalty: Optional[float] = None, presence_penalty: float | None = None,
seed: Optional[int] = None, seed: int | None = None,
stop: Optional[Union[str, List[str]]] = None, stop: str | list[str] | None = None,
stream: Optional[bool] = None, stream: bool | None = None,
stream_options: Optional[Dict[str, Any]] = None, stream_options: dict[str, Any] | None = None,
temperature: Optional[float] = None, temperature: float | None = None,
top_p: Optional[float] = None, top_p: float | None = None,
user: Optional[str] = None, user: str | None = None,
# vLLM-specific parameters # vLLM-specific parameters
guided_choice: Optional[List[str]] = None, guided_choice: list[str] | None = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: int | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model. """Generate an OpenAI-compatible completion for the given prompt using the specified model.
@ -996,29 +979,29 @@ class Inference(Protocol):
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, model: str,
messages: List[OpenAIMessageParam], messages: list[OpenAIMessageParam],
frequency_penalty: Optional[float] = None, frequency_penalty: float | None = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None, function_call: str | dict[str, Any] | None = None,
functions: Optional[List[Dict[str, Any]]] = None, functions: list[dict[str, Any]] | None = None,
logit_bias: Optional[Dict[str, float]] = None, logit_bias: dict[str, float] | None = None,
logprobs: Optional[bool] = None, logprobs: bool | None = None,
max_completion_tokens: Optional[int] = None, max_completion_tokens: int | None = None,
max_tokens: Optional[int] = None, max_tokens: int | None = None,
n: Optional[int] = None, n: int | None = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: bool | None = None,
presence_penalty: Optional[float] = None, presence_penalty: float | None = None,
response_format: Optional[OpenAIResponseFormatParam] = None, response_format: OpenAIResponseFormatParam | None = None,
seed: Optional[int] = None, seed: int | None = None,
stop: Optional[Union[str, List[str]]] = None, stop: str | list[str] | None = None,
stream: Optional[bool] = None, stream: bool | None = None,
stream_options: Optional[Dict[str, Any]] = None, stream_options: dict[str, Any] | None = None,
temperature: Optional[float] = None, temperature: float | None = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None, tool_choice: str | dict[str, Any] | None = None,
tools: Optional[List[Dict[str, Any]]] = None, tools: list[dict[str, Any]] | None = None,
top_logprobs: Optional[int] = None, top_logprobs: int | None = None,
top_p: Optional[float] = None, top_p: float | None = None,
user: Optional[str] = None, user: str | None = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model. """Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import List, Protocol, runtime_checkable from typing import Protocol, runtime_checkable
from pydantic import BaseModel from pydantic import BaseModel
@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class RouteInfo(BaseModel): class RouteInfo(BaseModel):
route: str route: str
method: str method: str
provider_types: List[str] provider_types: list[str]
@json_schema_type @json_schema_type
@ -30,7 +30,7 @@ class VersionInfo(BaseModel):
class ListRoutesResponse(BaseModel): class ListRoutesResponse(BaseModel):
data: List[RouteInfo] data: list[RouteInfo]
@runtime_checkable @runtime_checkable

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class CommonModelFields(BaseModel): class CommonModelFields(BaseModel):
metadata: Dict[str, Any] = Field( metadata: dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Any additional metadata for this model", description="Any additional metadata for this model",
) )
@ -46,14 +46,14 @@ class Model(CommonModelFields, Resource):
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str model_id: str
provider_id: Optional[str] = None provider_id: str | None = None
provider_model_id: Optional[str] = None provider_model_id: str | None = None
model_type: Optional[ModelType] = ModelType.llm model_type: ModelType | None = ModelType.llm
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class ListModelsResponse(BaseModel): class ListModelsResponse(BaseModel):
data: List[Model] data: list[Model]
@json_schema_type @json_schema_type
@ -73,7 +73,7 @@ class OpenAIModel(BaseModel):
class OpenAIListModelsResponse(BaseModel): class OpenAIListModelsResponse(BaseModel):
data: List[OpenAIModel] data: list[OpenAIModel]
@runtime_checkable @runtime_checkable
@ -95,10 +95,10 @@ class Models(Protocol):
async def register_model( async def register_model(
self, self,
model_id: str, model_id: str,
provider_model_id: Optional[str] = None, provider_model_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
model_type: Optional[ModelType] = None, model_type: ModelType | None = None,
) -> Model: ... ) -> Model: ...
@webmethod(route="/models/{model_id:path}", method="DELETE") @webmethod(route="/models/{model_id:path}", method="DELETE")

View file

@ -6,10 +6,9 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus from llama_stack.apis.common.job_types import JobStatus
@ -36,9 +35,9 @@ class DataConfig(BaseModel):
batch_size: int batch_size: int
shuffle: bool shuffle: bool
data_format: DatasetFormat data_format: DatasetFormat
validation_dataset_id: Optional[str] = None validation_dataset_id: str | None = None
packed: Optional[bool] = False packed: bool | None = False
train_on_input: Optional[bool] = False train_on_input: bool | None = False
@json_schema_type @json_schema_type
@ -51,10 +50,10 @@ class OptimizerConfig(BaseModel):
@json_schema_type @json_schema_type
class EfficiencyConfig(BaseModel): class EfficiencyConfig(BaseModel):
enable_activation_checkpointing: Optional[bool] = False enable_activation_checkpointing: bool | None = False
enable_activation_offloading: Optional[bool] = False enable_activation_offloading: bool | None = False
memory_efficient_fsdp_wrap: Optional[bool] = False memory_efficient_fsdp_wrap: bool | None = False
fsdp_cpu_offload: Optional[bool] = False fsdp_cpu_offload: bool | None = False
@json_schema_type @json_schema_type
@ -62,23 +61,23 @@ class TrainingConfig(BaseModel):
n_epochs: int n_epochs: int
max_steps_per_epoch: int = 1 max_steps_per_epoch: int = 1
gradient_accumulation_steps: int = 1 gradient_accumulation_steps: int = 1
max_validation_steps: Optional[int] = 1 max_validation_steps: int | None = 1
data_config: Optional[DataConfig] = None data_config: DataConfig | None = None
optimizer_config: Optional[OptimizerConfig] = None optimizer_config: OptimizerConfig | None = None
efficiency_config: Optional[EfficiencyConfig] = None efficiency_config: EfficiencyConfig | None = None
dtype: Optional[str] = "bf16" dtype: str | None = "bf16"
@json_schema_type @json_schema_type
class LoraFinetuningConfig(BaseModel): class LoraFinetuningConfig(BaseModel):
type: Literal["LoRA"] = "LoRA" type: Literal["LoRA"] = "LoRA"
lora_attn_modules: List[str] lora_attn_modules: list[str]
apply_lora_to_mlp: bool apply_lora_to_mlp: bool
apply_lora_to_output: bool apply_lora_to_output: bool
rank: int rank: int
alpha: int alpha: int
use_dora: Optional[bool] = False use_dora: bool | None = False
quantize_base: Optional[bool] = False quantize_base: bool | None = False
@json_schema_type @json_schema_type
@ -88,7 +87,7 @@ class QATFinetuningConfig(BaseModel):
group_size: int group_size: int
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")] AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig") register_schema(AlgorithmConfig, name="AlgorithmConfig")
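A small self-contained sketch (toy config classes; assumes Pydantic v2's TypeAdapter) of how an Annotated `|` union with a discriminator, like AlgorithmConfig above, dispatches on the `type` tag during validation:

from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class LoraCfg(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    rank: int = 8


class QatCfg(BaseModel):
    type: Literal["QAT"] = "QAT"
    group_size: int = 32


AlgoCfg = Annotated[LoraCfg | QatCfg, Field(discriminator="type")]

cfg = TypeAdapter(AlgoCfg).validate_python({"type": "QAT", "group_size": 64})
print(type(cfg).__name__, cfg.group_size)  # QatCfg 64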
@ -97,7 +96,7 @@ class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job.""" """Stream of logs from a finetuning job."""
job_uuid: str job_uuid: str
log_lines: List[str] log_lines: list[str]
@json_schema_type @json_schema_type
@ -131,8 +130,8 @@ class PostTrainingRLHFRequest(BaseModel):
training_config: TrainingConfig training_config: TrainingConfig
# TODO: define these # TODO: define these
hyperparam_search_config: Dict[str, Any] hyperparam_search_config: dict[str, Any]
logger_config: Dict[str, Any] logger_config: dict[str, Any]
class PostTrainingJob(BaseModel): class PostTrainingJob(BaseModel):
@ -146,17 +145,17 @@ class PostTrainingJobStatusResponse(BaseModel):
job_uuid: str job_uuid: str
status: JobStatus status: JobStatus
scheduled_at: Optional[datetime] = None scheduled_at: datetime | None = None
started_at: Optional[datetime] = None started_at: datetime | None = None
completed_at: Optional[datetime] = None completed_at: datetime | None = None
resources_allocated: Optional[Dict[str, Any]] = None resources_allocated: dict[str, Any] | None = None
checkpoints: List[Checkpoint] = Field(default_factory=list) checkpoints: list[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel): class ListPostTrainingJobsResponse(BaseModel):
data: List[PostTrainingJob] data: list[PostTrainingJob]
@json_schema_type @json_schema_type
@ -164,7 +163,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job.""" """Artifacts of a finetuning job."""
job_uuid: str job_uuid: str
checkpoints: List[Checkpoint] = Field(default_factory=list) checkpoints: list[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals # TODO(ashwin): metrics, evals
@ -175,14 +174,14 @@ class PostTraining(Protocol):
self, self,
job_uuid: str, job_uuid: str,
training_config: TrainingConfig, training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any], hyperparam_search_config: dict[str, Any],
logger_config: Dict[str, Any], logger_config: dict[str, Any],
model: Optional[str] = Field( model: str | None = Field(
default=None, default=None,
description="Model descriptor for training if not in provider config`", description="Model descriptor for training if not in provider config`",
), ),
checkpoint_dir: Optional[str] = None, checkpoint_dir: str | None = None,
algorithm_config: Optional[AlgorithmConfig] = None, algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")
@ -192,8 +191,8 @@ class PostTraining(Protocol):
finetuned_model: str, finetuned_model: str,
algorithm_config: DPOAlignmentConfig, algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig, training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any], hyperparam_search_config: dict[str, Any],
logger_config: Dict[str, Any], logger_config: dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET") @webmethod(route="/post-training/jobs", method="GET")

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel from pydantic import BaseModel
@ -17,12 +17,12 @@ class ProviderInfo(BaseModel):
api: str api: str
provider_id: str provider_id: str
provider_type: str provider_type: str
config: Dict[str, Any] config: dict[str, Any]
health: HealthResponse health: HealthResponse
class ListProvidersResponse(BaseModel): class ListProvidersResponse(BaseModel):
data: List[ProviderInfo] data: list[ProviderInfo]
@runtime_checkable @runtime_checkable

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -27,16 +27,16 @@ class SafetyViolation(BaseModel):
violation_level: ViolationLevel violation_level: ViolationLevel
# what message should you convey to the user # what message should you convey to the user
user_message: Optional[str] = None user_message: str | None = None
# additional metadata (including specific violation codes) more for # additional metadata (including specific violation codes) more for
# debugging, telemetry # debugging, telemetry
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type @json_schema_type
class RunShieldResponse(BaseModel): class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None violation: SafetyViolation | None = None
class ShieldStore(Protocol): class ShieldStore(Protocol):
@ -52,6 +52,6 @@ class Safety(Protocol):
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,
messages: List[Message], messages: list[Message],
params: Dict[str, Any] = None, params: dict[str, Any] = None,
) -> RunShieldResponse: ... ) -> RunShieldResponse: ...
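When an argument is genuinely optional, the modern spelling pairs the builtin-generic annotation with an explicit `| None` default, as in most signatures in this change; a tiny sketch (hypothetical helper, not taken from the diff):

from typing import Any


def run_checks(shield_id: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
    # Copy the optional mapping instead of mutating a shared default.
    merged = dict(params or {})
    merged.setdefault("strict", True)
    return merged


print(run_checks("shield-1"))                     # {'strict': True}
print(run_checks("shield-1", {"strict": False}))  # {'strict': False}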

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel from pydantic import BaseModel
@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value # mapping of metric to value
ScoringResultRow = Dict[str, Any] ScoringResultRow = dict[str, Any]
@json_schema_type @json_schema_type
@ -24,15 +24,15 @@ class ScoringResult(BaseModel):
:param aggregated_results: Map of metric name to aggregated value :param aggregated_results: Map of metric name to aggregated value
""" """
score_rows: List[ScoringResultRow] score_rows: list[ScoringResultRow]
# aggregated metrics to value # aggregated metrics to value
aggregated_results: Dict[str, Any] aggregated_results: dict[str, Any]
@json_schema_type @json_schema_type
class ScoreBatchResponse(BaseModel): class ScoreBatchResponse(BaseModel):
dataset_id: Optional[str] = None dataset_id: str | None = None
results: Dict[str, ScoringResult] results: dict[str, ScoringResult]
@json_schema_type @json_schema_type
@ -44,7 +44,7 @@ class ScoreResponse(BaseModel):
""" """
# each key in the dict is a scoring function name # each key in the dict is a scoring function name
results: Dict[str, ScoringResult] results: dict[str, ScoringResult]
class ScoringFunctionStore(Protocol): class ScoringFunctionStore(Protocol):
@ -59,15 +59,15 @@ class Scoring(Protocol):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]], scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ... ) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score", method="POST") @webmethod(route="/scoring/score", method="POST")
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: list[dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]], scoring_functions: dict[str, ScoringFnParams | None],
) -> ScoreResponse: ) -> ScoreResponse:
"""Score a list of rows. """Score a list of rows.

View file

@ -6,18 +6,14 @@
from enum import Enum from enum import Enum
from typing import ( from typing import (
Annotated,
Any, Any,
Dict,
List,
Literal, Literal,
Optional,
Protocol, Protocol,
Union,
runtime_checkable, runtime_checkable,
) )
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
@ -46,12 +42,12 @@ class AggregationFunctionType(Enum):
class LLMAsJudgeScoringFnParams(BaseModel): class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
judge_model: str judge_model: str
prompt_template: Optional[str] = None prompt_template: str | None = None
judge_score_regexes: Optional[List[str]] = Field( judge_score_regexes: list[str] | None = Field(
description="Regexes to extract the answer from generated response", description="Regexes to extract the answer from generated response",
default_factory=list, default_factory=list,
) )
aggregation_functions: Optional[List[AggregationFunctionType]] = Field( aggregation_functions: list[AggregationFunctionType] | None = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=list,
) )
@ -60,11 +56,11 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type @json_schema_type
class RegexParserScoringFnParams(BaseModel): class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: Optional[List[str]] = Field( parsing_regexes: list[str] | None = Field(
description="Regex to extract the answer from generated response", description="Regex to extract the answer from generated response",
default_factory=list, default_factory=list,
) )
aggregation_functions: Optional[List[AggregationFunctionType]] = Field( aggregation_functions: list[AggregationFunctionType] | None = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=list,
) )
@ -73,33 +69,29 @@ class RegexParserScoringFnParams(BaseModel):
@json_schema_type @json_schema_type
class BasicScoringFnParams(BaseModel): class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: Optional[List[AggregationFunctionType]] = Field( aggregation_functions: list[AggregationFunctionType] | None = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=list,
) )
ScoringFnParams = Annotated[ ScoringFnParams = Annotated[
Union[ LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(ScoringFnParams, name="ScoringFnParams") register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel): class CommonScoringFnFields(BaseModel):
description: Optional[str] = None description: str | None = None
metadata: Dict[str, Any] = Field( metadata: dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Any additional metadata for this definition", description="Any additional metadata for this definition",
) )
return_type: ParamType = Field( return_type: ParamType = Field(
description="The return type of the deterministic function", description="The return type of the deterministic function",
) )
params: Optional[ScoringFnParams] = Field( params: ScoringFnParams | None = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval", description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None, default=None,
) )
@ -120,12 +112,12 @@ class ScoringFn(CommonScoringFnFields, Resource):
class ScoringFnInput(CommonScoringFnFields, BaseModel): class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str scoring_fn_id: str
provider_id: Optional[str] = None provider_id: str | None = None
provider_scoring_fn_id: Optional[str] = None provider_scoring_fn_id: str | None = None
class ListScoringFunctionsResponse(BaseModel): class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn] data: list[ScoringFn]
@runtime_checkable @runtime_checkable
@ -142,7 +134,7 @@ class ScoringFunctions(Protocol):
scoring_fn_id: str, scoring_fn_id: str,
description: str, description: str,
return_type: ParamType, return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None, provider_scoring_fn_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
params: Optional[ScoringFnParams] = None, params: ScoringFnParams | None = None,
) -> None: ... ) -> None: ...

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel from pydantic import BaseModel
@ -14,7 +14,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class CommonShieldFields(BaseModel): class CommonShieldFields(BaseModel):
params: Optional[Dict[str, Any]] = None params: dict[str, Any] | None = None
@json_schema_type @json_schema_type
@ -34,12 +34,12 @@ class Shield(CommonShieldFields, Resource):
class ShieldInput(CommonShieldFields): class ShieldInput(CommonShieldFields):
shield_id: str shield_id: str
provider_id: Optional[str] = None provider_id: str | None = None
provider_shield_id: Optional[str] = None provider_shield_id: str | None = None
class ListShieldsResponse(BaseModel): class ListShieldsResponse(BaseModel):
data: List[Shield] data: list[Shield]
@runtime_checkable @runtime_checkable
@ -55,7 +55,7 @@ class Shields(Protocol):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
provider_shield_id: Optional[str] = None, provider_shield_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
params: Optional[Dict[str, Any]] = None, params: dict[str, Any] | None = None,
) -> Shield: ... ) -> Shield: ...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union from typing import Any, Protocol
from pydantic import BaseModel from pydantic import BaseModel
@ -28,24 +28,24 @@ class FilteringFunction(Enum):
class SyntheticDataGenerationRequest(BaseModel): class SyntheticDataGenerationRequest(BaseModel):
"""Request to generate synthetic data. A small batch of prompts and a filtering function""" """Request to generate synthetic data. A small batch of prompts and a filtering function"""
dialogs: List[Message] dialogs: list[Message]
filtering_function: FilteringFunction = FilteringFunction.none filtering_function: FilteringFunction = FilteringFunction.none
model: Optional[str] = None model: str | None = None
@json_schema_type @json_schema_type
class SyntheticDataGenerationResponse(BaseModel): class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.""" """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
synthetic_data: List[Dict[str, Any]] synthetic_data: list[dict[str, Any]]
statistics: Optional[Dict[str, Any]] = None statistics: dict[str, Any] | None = None
class SyntheticDataGeneration(Protocol): class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic-data-generation/generate") @webmethod(route="/synthetic-data-generation/generate")
def synthetic_data_generate( def synthetic_data_generate(
self, self,
dialogs: List[Message], dialogs: list[Message],
filtering_function: FilteringFunction = FilteringFunction.none, filtering_function: FilteringFunction = FilteringFunction.none,
model: Optional[str] = None, model: str | None = None,
) -> Union[SyntheticDataGenerationResponse]: ... ) -> SyntheticDataGenerationResponse: ...

View file

@ -7,18 +7,14 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import ( from typing import (
Annotated,
Any, Any,
Dict,
List,
Literal, Literal,
Optional,
Protocol, Protocol,
Union,
runtime_checkable, runtime_checkable,
) )
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import Primitive from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -37,11 +33,11 @@ class SpanStatus(Enum):
class Span(BaseModel): class Span(BaseModel):
span_id: str span_id: str
trace_id: str trace_id: str
parent_span_id: Optional[str] = None parent_span_id: str | None = None
name: str name: str
start_time: datetime start_time: datetime
end_time: Optional[datetime] = None end_time: datetime | None = None
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict) attributes: dict[str, Any] | None = Field(default_factory=dict)
def set_attribute(self, key: str, value: Any): def set_attribute(self, key: str, value: Any):
if self.attributes is None: if self.attributes is None:
@ -54,7 +50,7 @@ class Trace(BaseModel):
trace_id: str trace_id: str
root_span_id: str root_span_id: str
start_time: datetime start_time: datetime
end_time: Optional[datetime] = None end_time: datetime | None = None
@json_schema_type @json_schema_type
@ -78,7 +74,7 @@ class EventCommon(BaseModel):
trace_id: str trace_id: str
span_id: str span_id: str
timestamp: datetime timestamp: datetime
attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict) attributes: dict[str, Primitive] | None = Field(default_factory=dict)
@json_schema_type @json_schema_type
@ -92,15 +88,15 @@ class UnstructuredLogEvent(EventCommon):
class MetricEvent(EventCommon): class MetricEvent(EventCommon):
type: Literal[EventType.METRIC.value] = EventType.METRIC.value type: Literal[EventType.METRIC.value] = EventType.METRIC.value
metric: str # this would be an enum metric: str # this would be an enum
value: Union[int, float] value: int | float
unit: str unit: str
@json_schema_type @json_schema_type
class MetricInResponse(BaseModel): class MetricInResponse(BaseModel):
metric: str metric: str
value: Union[int, float] value: int | float
unit: Optional[str] = None unit: str | None = None
# This is a short term solution to allow inference API to return metrics # This is a short term solution to allow inference API to return metrics
@ -124,7 +120,7 @@ class MetricInResponse(BaseModel):
class MetricResponseMixin(BaseModel): class MetricResponseMixin(BaseModel):
metrics: Optional[List[MetricInResponse]] = None metrics: list[MetricInResponse] | None = None
@json_schema_type @json_schema_type
@ -137,7 +133,7 @@ class StructuredLogType(Enum):
class SpanStartPayload(BaseModel): class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
name: str name: str
parent_span_id: Optional[str] = None parent_span_id: str | None = None
@json_schema_type @json_schema_type
@ -147,10 +143,7 @@ class SpanEndPayload(BaseModel):
StructuredLogPayload = Annotated[ StructuredLogPayload = Annotated[
Union[ SpanStartPayload | SpanEndPayload,
SpanStartPayload,
SpanEndPayload,
],
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(StructuredLogPayload, name="StructuredLogPayload") register_schema(StructuredLogPayload, name="StructuredLogPayload")
@ -163,11 +156,7 @@ class StructuredLogEvent(EventCommon):
Event = Annotated[ Event = Annotated[
Union[ UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(Event, name="Event") register_schema(Event, name="Event")
@ -184,7 +173,7 @@ class EvalTrace(BaseModel):
@json_schema_type @json_schema_type
class SpanWithStatus(Span): class SpanWithStatus(Span):
status: Optional[SpanStatus] = None status: SpanStatus | None = None
@json_schema_type @json_schema_type
@ -203,15 +192,15 @@ class QueryCondition(BaseModel):
class QueryTracesResponse(BaseModel): class QueryTracesResponse(BaseModel):
data: List[Trace] data: list[Trace]
class QuerySpansResponse(BaseModel): class QuerySpansResponse(BaseModel):
data: List[Span] data: list[Span]
class QuerySpanTreeResponse(BaseModel): class QuerySpanTreeResponse(BaseModel):
data: Dict[str, SpanWithStatus] data: dict[str, SpanWithStatus]
@runtime_checkable @runtime_checkable
@ -222,10 +211,10 @@ class Telemetry(Protocol):
@webmethod(route="/telemetry/traces", method="POST") @webmethod(route="/telemetry/traces", method="POST")
async def query_traces( async def query_traces(
self, self,
attribute_filters: Optional[List[QueryCondition]] = None, attribute_filters: list[QueryCondition] | None = None,
limit: Optional[int] = 100, limit: int | None = 100,
offset: Optional[int] = 0, offset: int | None = 0,
order_by: Optional[List[str]] = None, order_by: list[str] | None = None,
) -> QueryTracesResponse: ... ) -> QueryTracesResponse: ...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
@ -238,23 +227,23 @@ class Telemetry(Protocol):
async def get_span_tree( async def get_span_tree(
self, self,
span_id: str, span_id: str,
attributes_to_return: Optional[List[str]] = None, attributes_to_return: list[str] | None = None,
max_depth: Optional[int] = None, max_depth: int | None = None,
) -> QuerySpanTreeResponse: ... ) -> QuerySpanTreeResponse: ...
@webmethod(route="/telemetry/spans", method="POST") @webmethod(route="/telemetry/spans", method="POST")
async def query_spans( async def query_spans(
self, self,
attribute_filters: List[QueryCondition], attribute_filters: list[QueryCondition],
attributes_to_return: List[str], attributes_to_return: list[str],
max_depth: Optional[int] = None, max_depth: int | None = None,
) -> QuerySpansResponse: ... ) -> QuerySpansResponse: ...
@webmethod(route="/telemetry/spans/export", method="POST") @webmethod(route="/telemetry/spans/export", method="POST")
async def save_spans_to_dataset( async def save_spans_to_dataset(
self, self,
attribute_filters: List[QueryCondition], attribute_filters: list[QueryCondition],
attributes_to_save: List[str], attributes_to_save: list[str],
dataset_id: str, dataset_id: str,
max_depth: Optional[int] = None, max_depth: int | None = None,
) -> None: ... ) -> None: ...

View file

@ -5,10 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -29,13 +29,13 @@ class RAGDocument(BaseModel):
document_id: str document_id: str
content: InterleavedContent | URL content: InterleavedContent | URL
mime_type: str | None = None mime_type: str | None = None
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type @json_schema_type
class RAGQueryResult(BaseModel): class RAGQueryResult(BaseModel):
content: Optional[InterleavedContent] = None content: InterleavedContent | None = None
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type @json_schema_type
@ -59,10 +59,7 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
RAGQueryGeneratorConfig = Annotated[ RAGQueryGeneratorConfig = Annotated[
Union[ DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig") register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@ -83,7 +80,7 @@ class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST") @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
async def insert( async def insert(
self, self,
documents: List[RAGDocument], documents: list[RAGDocument],
vector_db_id: str, vector_db_id: str,
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
) -> None: ) -> None:
@ -94,8 +91,8 @@ class RAGToolRuntime(Protocol):
async def query( async def query(
self, self,
content: InterleavedContent, content: InterleavedContent,
vector_db_ids: List[str], vector_db_ids: list[str],
query_config: Optional[RAGQueryConfig] = None, query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent""" """Query the RAG system for context; typically invoked by the agent"""
... ...

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional from typing import Any, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
@ -24,7 +24,7 @@ class ToolParameter(BaseModel):
parameter_type: str parameter_type: str
description: str description: str
required: bool = Field(default=True) required: bool = Field(default=True)
default: Optional[Any] = None default: Any | None = None
@json_schema_type @json_schema_type
@ -40,39 +40,39 @@ class Tool(Resource):
toolgroup_id: str toolgroup_id: str
tool_host: ToolHost tool_host: ToolHost
description: str description: str
parameters: List[ToolParameter] parameters: list[ToolParameter]
metadata: Optional[Dict[str, Any]] = None metadata: dict[str, Any] | None = None
@json_schema_type @json_schema_type
class ToolDef(BaseModel): class ToolDef(BaseModel):
name: str name: str
description: Optional[str] = None description: str | None = None
parameters: Optional[List[ToolParameter]] = None parameters: list[ToolParameter] | None = None
metadata: Optional[Dict[str, Any]] = None metadata: dict[str, Any] | None = None
@json_schema_type @json_schema_type
class ToolGroupInput(BaseModel): class ToolGroupInput(BaseModel):
toolgroup_id: str toolgroup_id: str
provider_id: str provider_id: str
args: Optional[Dict[str, Any]] = None args: dict[str, Any] | None = None
mcp_endpoint: Optional[URL] = None mcp_endpoint: URL | None = None
@json_schema_type @json_schema_type
class ToolGroup(Resource): class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None mcp_endpoint: URL | None = None
args: Optional[Dict[str, Any]] = None args: dict[str, Any] | None = None
@json_schema_type @json_schema_type
class ToolInvocationResult(BaseModel): class ToolInvocationResult(BaseModel):
content: Optional[InterleavedContent] = None content: InterleavedContent | None = None
error_message: Optional[str] = None error_message: str | None = None
error_code: Optional[int] = None error_code: int | None = None
metadata: Optional[Dict[str, Any]] = None metadata: dict[str, Any] | None = None
class ToolStore(Protocol): class ToolStore(Protocol):
@ -81,11 +81,11 @@ class ToolStore(Protocol):
class ListToolGroupsResponse(BaseModel): class ListToolGroupsResponse(BaseModel):
data: List[ToolGroup] data: list[ToolGroup]
class ListToolsResponse(BaseModel): class ListToolsResponse(BaseModel):
data: List[Tool] data: list[Tool]
class ListToolDefsResponse(BaseModel): class ListToolDefsResponse(BaseModel):
@ -100,8 +100,8 @@ class ToolGroups(Protocol):
self, self,
toolgroup_id: str, toolgroup_id: str,
provider_id: str, provider_id: str,
mcp_endpoint: Optional[URL] = None, mcp_endpoint: URL | None = None,
args: Optional[Dict[str, Any]] = None, args: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Register a tool group""" """Register a tool group"""
... ...
@ -118,7 +118,7 @@ class ToolGroups(Protocol):
... ...
@webmethod(route="/tools", method="GET") @webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse: async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group""" """List tools with optional tool group"""
... ...
@ -151,10 +151,10 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET") @webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ... ) -> ListToolDefsResponse: ...
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments""" """Run a tool with the given arguments"""
... ...

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import List, Literal, Optional, Protocol, runtime_checkable from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel from pydantic import BaseModel
@ -33,11 +33,11 @@ class VectorDBInput(BaseModel):
vector_db_id: str vector_db_id: str
embedding_model: str embedding_model: str
embedding_dimension: int embedding_dimension: int
provider_vector_db_id: Optional[str] = None provider_vector_db_id: str | None = None
class ListVectorDBsResponse(BaseModel): class ListVectorDBsResponse(BaseModel):
data: List[VectorDB] data: list[VectorDB]
@runtime_checkable @runtime_checkable
@ -57,9 +57,9 @@ class VectorDBs(Protocol):
self, self,
vector_db_id: str, vector_db_id: str,
embedding_model: str, embedding_model: str,
embedding_dimension: Optional[int] = 384, embedding_dimension: int | None = 384,
provider_id: Optional[str] = None, provider_id: str | None = None,
provider_vector_db_id: Optional[str] = None, provider_vector_db_id: str | None = None,
) -> VectorDB: ... ) -> VectorDB: ...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE") @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")

View file

@ -8,7 +8,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -20,17 +20,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel): class Chunk(BaseModel):
content: InterleavedContent content: InterleavedContent
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type @json_schema_type
class QueryChunksResponse(BaseModel): class QueryChunksResponse(BaseModel):
chunks: List[Chunk] chunks: list[Chunk]
scores: List[float] scores: list[float]
class VectorDBStore(Protocol): class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ... def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
@runtime_checkable @runtime_checkable
@ -44,8 +44,8 @@ class VectorIO(Protocol):
async def insert_chunks( async def insert_chunks(
self, self,
vector_db_id: str, vector_db_id: str,
chunks: List[Chunk], chunks: list[Chunk],
ttl_seconds: Optional[int] = None, ttl_seconds: int | None = None,
) -> None: ... ) -> None: ...
@webmethod(route="/vector-io/query", method="POST") @webmethod(route="/vector-io/query", method="POST")
@ -53,5 +53,5 @@ class VectorIO(Protocol):
self, self,
vector_db_id: str, vector_db_id: str,
query: InterleavedContent, query: InterleavedContent,
params: Optional[Dict[str, Any]] = None, params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ... ) -> QueryChunksResponse: ...

View file

@ -13,7 +13,6 @@ from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional
import httpx import httpx
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -102,7 +101,7 @@ class DownloadTask:
output_file: str output_file: str
total_size: int = 0 total_size: int = 0
downloaded_size: int = 0 downloaded_size: int = 0
task_id: Optional[int] = None task_id: int | None = None
retries: int = 0 retries: int = 0
max_retries: int = 3 max_retries: int = 3
@ -262,7 +261,7 @@ class ParallelDownloader:
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]") self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
def has_disk_space(self, tasks: List[DownloadTask]) -> bool: def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
try: try:
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks) total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file)) dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
@ -282,7 +281,7 @@ class ParallelDownloader:
except Exception as e: except Exception as e:
raise DownloadError(f"Failed to check disk space: {str(e)}") from e raise DownloadError(f"Failed to check disk space: {str(e)}") from e
async def download_all(self, tasks: List[DownloadTask]) -> None: async def download_all(self, tasks: list[DownloadTask]) -> None:
if not tasks: if not tasks:
raise ValueError("No download tasks provided") raise ValueError("No download tasks provided")
@ -391,20 +390,20 @@ def _meta_download(
class ModelEntry(BaseModel): class ModelEntry(BaseModel):
model_id: str model_id: str
files: Dict[str, str] files: dict[str, str]
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class Manifest(BaseModel): class Manifest(BaseModel):
models: List[ModelEntry] models: list[ModelEntry]
expires_on: datetime expires_on: datetime
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int): def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
with open(manifest_file, "r") as f: with open(manifest_file) as f:
d = json.load(f) d = json.load(f)
manifest = Manifest(**d) manifest = Manifest(**d)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict from typing import Any
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
@ -22,7 +22,7 @@ class PromptGuardModel(BaseModel):
max_seq_length: int = 512 max_seq_length: int = 512
is_instruct_model: bool = False is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict) arch_args: dict[str, Any] = Field(default_factory=dict)
def descriptor(self) -> str: def descriptor(self) -> str:
return self.model_id return self.model_id
@ -44,11 +44,11 @@ def prompt_guard_model_skus():
] ]
def prompt_guard_model_sku_map() -> Dict[str, Any]: def prompt_guard_model_sku_map() -> dict[str, Any]:
return {model.model_id: model for model in prompt_guard_model_skus()} return {model.model_id: model for model in prompt_guard_model_skus()}
def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]: def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
return { return {
model.model_id: LlamaDownloadInfo( model.model_id: LlamaDownloadInfo(
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id, folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,

View file

@ -13,7 +13,6 @@ import sys
import textwrap import textwrap
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Dict, Optional
import yaml import yaml
from prompt_toolkit import prompt from prompt_toolkit import prompt
@ -46,14 +45,14 @@ from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@lru_cache() @lru_cache
def available_templates_specs() -> Dict[str, BuildConfig]: def available_templates_specs() -> dict[str, BuildConfig]:
import yaml import yaml
template_specs = {} template_specs = {}
for p in TEMPLATES_PATH.rglob("*build.yaml"): for p in TEMPLATES_PATH.rglob("*build.yaml"):
template_name = p.parent.name template_name = p.parent.name
with open(p, "r") as f: with open(p) as f:
build_config = BuildConfig(**yaml.safe_load(f)) build_config = BuildConfig(**yaml.safe_load(f))
template_specs[template_name] = build_config template_specs[template_name] = build_config
return template_specs return template_specs
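Two of the smaller pyupgrade rewrites visible here are the bare @lru_cache decorator (valid since Python 3.8) and open(path) without the redundant "r" mode. A minimal sketch under those assumptions (hypothetical function and path):

from functools import lru_cache


@lru_cache
def read_template(path: str) -> str:
    # open() defaults to text mode "r", so the explicit mode argument is redundant.
    with open(path) as f:
        return f.read()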
@ -178,7 +177,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if not available_providers: if not available_providers:
continue continue
api_provider = prompt( api_provider = prompt(
"> Enter provider for API {}: ".format(api.value), f"> Enter provider for API {api.value}: ",
completer=WordCompleter(available_providers), completer=WordCompleter(available_providers),
complete_while_typing=True, complete_while_typing=True,
validator=Validator.from_callable( validator=Validator.from_callable(
@ -201,7 +200,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec) build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
else: else:
with open(args.config, "r") as f: with open(args.config) as f:
try: try:
build_config = BuildConfig(**yaml.safe_load(f)) build_config = BuildConfig(**yaml.safe_load(f))
except Exception as e: except Exception as e:
@ -332,9 +331,9 @@ def _generate_run_config(
def _run_stack_build_command_from_build_config( def _run_stack_build_command_from_build_config(
build_config: BuildConfig, build_config: BuildConfig,
image_name: Optional[str] = None, image_name: str | None = None,
template_name: Optional[str] = None, template_name: str | None = None,
config_path: Optional[str] = None, config_path: str | None = None,
) -> str: ) -> str:
image_name = image_name or build_config.image_name image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value: if build_config.image_type == LlamaStackImageType.CONTAINER.value:
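Besides the Optional rewrites, this file picks up three other pyupgrade simplifications: @lru_cache() loses its empty parentheses, open(p, "r") drops the redundant mode (text read is the default), and "...".format(...) becomes an f-string. A small self-contained sketch of the same idioms, assuming PyYAML is installed (the function name and usage are made up for illustration):

from functools import lru_cache
from pathlib import Path

import yaml


@lru_cache  # bare decorator: no call parentheses needed without arguments
def load_build_config(path: str) -> dict:
    with open(Path(path)) as f:  # "r" is the default mode
        return yaml.safe_load(f)


api = "inference"
print(f"> Enter provider for API {api}: ")  # f-string instead of str.format()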


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Iterable from collections.abc import Iterable
from rich.console import Console from rich.console import Console
from rich.table import Table from rich.table import Table


@ -9,7 +9,6 @@ import hashlib
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional
from rich.console import Console from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn from rich.progress import Progress, SpinnerColumn, TextColumn
@ -21,7 +20,7 @@ from llama_stack.cli.subcommand import Subcommand
class VerificationResult: class VerificationResult:
filename: str filename: str
expected_hash: str expected_hash: str
actual_hash: Optional[str] actual_hash: str | None
exists: bool exists: bool
matches: bool matches: bool
@ -60,9 +59,9 @@ def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
return md5_hash.hexdigest() return md5_hash.hexdigest()
def load_checksums(checklist_path: Path) -> Dict[str, str]: def load_checksums(checklist_path: Path) -> dict[str, str]:
checksums = {} checksums = {}
with open(checklist_path, "r") as f: with open(checklist_path) as f:
for line in f: for line in f:
if line.strip(): if line.strip():
md5sum, filepath = line.strip().split(" ", 1) md5sum, filepath = line.strip().split(" ", 1)
@ -72,7 +71,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
return checksums return checksums
def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]: def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
results = [] results = []
with Progress( with Progress(


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, Optional from typing import Any
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -14,8 +14,8 @@ logger = get_logger(__name__, category="core")
def check_access( def check_access(
obj_identifier: str, obj_identifier: str,
obj_attributes: Optional[AccessAttributes], obj_attributes: AccessAttributes | None,
user_attributes: Optional[Dict[str, Any]] = None, user_attributes: dict[str, Any] | None = None,
) -> bool: ) -> bool:
"""Check if the current user has access to the given object, based on access attributes. """Check if the current user has access to the given object, based on access attributes.


@ -8,7 +8,7 @@ import inspect
import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from enum import Enum from enum import Enum
from typing import Any, Type, Union, get_args, get_origin from typing import Any, Union, get_args, get_origin
import httpx import httpx
from pydantic import BaseModel, parse_obj_as from pydantic import BaseModel, parse_obj_as
@ -27,7 +27,7 @@ async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
return impl return impl
def create_api_client_class(protocol) -> Type: def create_api_client_class(protocol) -> type:
if protocol in _CLIENT_CLASSES: if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol] return _CLIENT_CLASSES[protocol]
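Here typing.Type becomes the builtin type, which is valid as an annotation on Python 3.9+. A hedged sketch of the same caching-factory shape (simplified, not the actual client-class machinery):

_CLASS_CACHE: dict[str, type] = {}


def make_client_class(name: str) -> type:
    """Create (and cache) a trivial class for the given protocol name."""
    if name in _CLASS_CACHE:
        return _CLASS_CACHE[name]
    cls = type(f"{name}Client", (), {})
    _CLASS_CACHE[name] = cls
    return cls


print(make_client_class("Inference"))  # <class '__main__.InferenceClient'>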


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
import textwrap import textwrap
from typing import Any, Dict from typing import Any
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION, LLAMA_STACK_RUN_CONFIG_VERSION,
@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider: def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type] provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
try: try:
@ -120,8 +120,8 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
def upgrade_from_routing_table( def upgrade_from_routing_table(
config_dict: Dict[str, Any], config_dict: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
def get_providers(entries): def get_providers(entries):
return [ return [
Provider( Provider(
@ -163,7 +163,7 @@ def upgrade_from_routing_table(
return config_dict return config_dict
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig: def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None) version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION: if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict) return StackRunConfig(**config_dict)


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Annotated, Any, Dict, List, Optional, Union from typing import Annotated, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -30,7 +30,7 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]] RoutingKey = str | list[str]
class AccessAttributes(BaseModel): class AccessAttributes(BaseModel):
@ -47,17 +47,17 @@ class AccessAttributes(BaseModel):
""" """
# Standard attribute categories - the minimal set we need now # Standard attribute categories - the minimal set we need now
roles: Optional[List[str]] = Field( roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')" default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
) )
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')") teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: Optional[List[str]] = Field( projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')" default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
) )
namespaces: Optional[List[str]] = Field( namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation" default=None, description="Namespace-based access control for resource isolation"
) )
@ -106,7 +106,7 @@ class ResourceWithACL(Resource):
# ^ User must have access to the customer-insights project AND have confidential namespace # ^ User must have access to the customer-insights project AND have confidential namespace
""" """
access_attributes: Optional[AccessAttributes] = None access_attributes: AccessAttributes | None = None
# Use the extended Resource for all routable objects # Use the extended Resource for all routable objects
@ -142,41 +142,21 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL):
pass pass
RoutableObject = Union[ RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
Model,
Shield,
VectorDB,
Dataset,
ScoringFn,
Benchmark,
Tool,
ToolGroup,
]
RoutableObjectWithProvider = Annotated[ RoutableObjectWithProvider = Annotated[
Union[ ModelWithACL
ModelWithACL, | ShieldWithACL
ShieldWithACL, | VectorDBWithACL
VectorDBWithACL, | DatasetWithACL
DatasetWithACL, | ScoringFnWithACL
ScoringFnWithACL, | BenchmarkWithACL
BenchmarkWithACL, | ToolWithACL
ToolWithACL, | ToolGroupWithACL,
ToolGroupWithACL,
],
Field(discriminator="type"), Field(discriminator="type"),
] ]
RoutedProtocol = Union[ RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
Inference,
Safety,
VectorIO,
DatasetIO,
Scoring,
Eval,
ToolRuntime,
]
# Example: /inference, /safety # Example: /inference, /safety
@ -184,15 +164,15 @@ class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router" provider_type: str = "router"
config_class: str = "" config_class: str = ""
container_image: Optional[str] = None container_image: str | None = None
routing_table_api: Api routing_table_api: Api
module: str module: str
provider_data_validator: Optional[str] = Field( provider_data_validator: str | None = Field(
default=None, default=None,
) )
@property @property
def pip_packages(self) -> List[str]: def pip_packages(self) -> list[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec") raise AssertionError("Should not be called on AutoRoutedProviderSpec")
@ -200,20 +180,20 @@ class AutoRoutedProviderSpec(ProviderSpec):
class RoutingTableProviderSpec(ProviderSpec): class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table" provider_type: str = "routing_table"
config_class: str = "" config_class: str = ""
container_image: Optional[str] = None container_image: str | None = None
router_api: Api router_api: Api
module: str module: str
pip_packages: List[str] = Field(default_factory=list) pip_packages: list[str] = Field(default_factory=list)
class DistributionSpec(BaseModel): class DistributionSpec(BaseModel):
description: Optional[str] = Field( description: str | None = Field(
default="", default="",
description="Description of the distribution", description="Description of the distribution",
) )
container_image: Optional[str] = None container_image: str | None = None
providers: Dict[str, Union[str, List[str]]] = Field( providers: dict[str, str | list[str]] = Field(
default_factory=dict, default_factory=dict,
description=""" description="""
Provider Types for each of the APIs provided by this distribution. If you Provider Types for each of the APIs provided by this distribution. If you
@ -225,12 +205,12 @@ in the runtime configuration to help route to the correct provider.""",
class Provider(BaseModel): class Provider(BaseModel):
provider_id: str provider_id: str
provider_type: str provider_type: str
config: Dict[str, Any] config: dict[str, Any]
class LoggingConfig(BaseModel): class LoggingConfig(BaseModel):
category_levels: Dict[str, str] = Field( category_levels: dict[str, str] = Field(
default_factory=Dict, default_factory=dict,
description=""" description="""
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""", Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
) )
@ -248,7 +228,7 @@ class AuthenticationConfig(BaseModel):
..., ...,
description="Type of authentication provider (e.g., 'kubernetes', 'custom')", description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
) )
config: Dict[str, str] = Field( config: dict[str, str] = Field(
..., ...,
description="Provider-specific configuration", description="Provider-specific configuration",
) )
@ -261,15 +241,15 @@ class ServerConfig(BaseModel):
ge=1024, ge=1024,
le=65535, le=65535,
) )
tls_certfile: Optional[str] = Field( tls_certfile: str | None = Field(
default=None, default=None,
description="Path to TLS certificate file for HTTPS", description="Path to TLS certificate file for HTTPS",
) )
tls_keyfile: Optional[str] = Field( tls_keyfile: str | None = Field(
default=None, default=None,
description="Path to TLS key file for HTTPS", description="Path to TLS key file for HTTPS",
) )
auth: Optional[AuthenticationConfig] = Field( auth: AuthenticationConfig | None = Field(
default=None, default=None,
description="Authentication configuration for the server", description="Authentication configuration for the server",
) )
@ -285,23 +265,23 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
this could be just a hash this could be just a hash
""", """,
) )
container_image: Optional[str] = Field( container_image: str | None = Field(
default=None, default=None,
description="Reference to the container image if this package refers to a container", description="Reference to the container image if this package refers to a container",
) )
apis: List[str] = Field( apis: list[str] = Field(
default_factory=list, default_factory=list,
description=""" description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
) )
providers: Dict[str, List[Provider]] = Field( providers: dict[str, list[Provider]] = Field(
description=""" description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference) One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary. can be instantiated multiple times (with different configs) if necessary.
""", """,
) )
metadata_store: Optional[KVStoreConfig] = Field( metadata_store: KVStoreConfig | None = Field(
default=None, default=None,
description=""" description="""
Configuration for the persistence store used by the distribution registry. If not specified, Configuration for the persistence store used by the distribution registry. If not specified,
@ -309,22 +289,22 @@ a default SQLite store will be used.""",
) )
# registry of "resources" in the distribution # registry of "resources" in the distribution
models: List[ModelInput] = Field(default_factory=list) models: list[ModelInput] = Field(default_factory=list)
shields: List[ShieldInput] = Field(default_factory=list) shields: list[ShieldInput] = Field(default_factory=list)
vector_dbs: List[VectorDBInput] = Field(default_factory=list) vector_dbs: list[VectorDBInput] = Field(default_factory=list)
datasets: List[DatasetInput] = Field(default_factory=list) datasets: list[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list) scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
benchmarks: List[BenchmarkInput] = Field(default_factory=list) benchmarks: list[BenchmarkInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list) tool_groups: list[ToolGroupInput] = Field(default_factory=list)
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging") logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
server: ServerConfig = Field( server: ServerConfig = Field(
default_factory=ServerConfig, default_factory=ServerConfig,
description="Configuration for the HTTP(S) server", description="Configuration for the HTTP(S) server",
) )
external_providers_dir: Optional[str] = Field( external_providers_dir: str | None = Field(
default=None, default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.", description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
) )
@ -338,11 +318,11 @@ class BuildConfig(BaseModel):
default="conda", default="conda",
description="Type of package to build (conda | container | venv)", description="Type of package to build (conda | container | venv)",
) )
image_name: Optional[str] = Field( image_name: str | None = Field(
default=None, default=None,
description="Name of the distribution to build", description="Name of the distribution to build",
) )
external_providers_dir: Optional[str] = Field( external_providers_dir: str | None = Field(
default=None, default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. " description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.", "pip_packages MUST contain the provider package name.",


@ -7,7 +7,7 @@
import glob import glob
import importlib import importlib
import os import os
from typing import Any, Dict, List from typing import Any
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel
@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import (
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
def stack_apis() -> List[Api]: def stack_apis() -> list[Api]:
return list(Api) return list(Api)
@ -33,7 +33,7 @@ class AutoRoutedApiInfo(BaseModel):
router_api: Api router_api: Api
def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
return [ return [
AutoRoutedApiInfo( AutoRoutedApiInfo(
routing_table_api=Api.models, routing_table_api=Api.models,
@ -66,12 +66,12 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
] ]
def providable_apis() -> List[Api]: def providable_apis() -> list[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()} routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers] return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec: def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"]) adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec( spec = remote_provider_spec(
api=api, api=api,
@ -81,7 +81,7 @@ def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderS
return spec return spec
def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec: def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec( spec = InlineProviderSpec(
api=api, api=api,
provider_type=f"inline::{provider_name}", provider_type=f"inline::{provider_name}",
@ -98,7 +98,7 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
def get_provider_registry( def get_provider_registry(
config=None, config=None,
) -> Dict[Api, Dict[str, ProviderSpec]]: ) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers. """Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files. This function loads both built-in providers and external providers from YAML files.
@ -133,7 +133,7 @@ def get_provider_registry(
ValueError: If any provider spec is invalid ValueError: If any provider spec is invalid
""" """
ret: Dict[Api, Dict[str, ProviderSpec]] = {} ret: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis(): for api in providable_apis():
name = api.name.lower() name = api.name.lower()
logger.debug(f"Importing module {name}") logger.debug(f"Importing module {name}")


@ -12,7 +12,7 @@ import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Optional, TypeVar, Union, get_args, get_origin from typing import Any, TypeVar, Union, get_args, get_origin
import httpx import httpx
import yaml import yaml
@ -119,8 +119,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self, self,
config_path_or_template_name: str, config_path_or_template_name: str,
skip_logger_removal: bool = False, skip_logger_removal: bool = False,
custom_provider_registry: Optional[ProviderRegistry] = None, custom_provider_registry: ProviderRegistry | None = None,
provider_data: Optional[dict[str, Any]] = None, provider_data: dict[str, Any] | None = None,
): ):
super().__init__() super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient( self.async_client = AsyncLlamaStackAsLibraryClient(
@ -181,8 +181,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__( def __init__(
self, self,
config_path_or_template_name: str, config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None, custom_provider_registry: ProviderRegistry | None = None,
provider_data: Optional[dict[str, Any]] = None, provider_data: dict[str, Any] | None = None,
): ):
super().__init__() super().__init__()
# when using the library client, we should not log to console since many # when using the library client, we should not log to console since many
@ -371,7 +371,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return await response.parse() return await response.parse()
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict: def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
if not body: if not body:
return {} return {}


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
from typing import Any, Dict from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@ -73,14 +73,14 @@ class ProviderImpl(Providers):
raise ValueError(f"Provider {provider_id} not found") raise ValueError(f"Provider {provider_id} not found")
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]: async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
"""Get health status for all providers. """Get health status for all providers.
Returns: Returns:
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses. Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
Each API maps to a dictionary of provider IDs to their health responses. Each API maps to a dictionary of provider IDs to their health responses.
""" """
providers_health: Dict[str, Dict[str, HealthResponse]] = {} providers_health: dict[str, dict[str, HealthResponse]] = {}
timeout = 1.0 timeout = 1.0
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None: async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:


@ -7,7 +7,8 @@
import contextvars import contextvars
import json import json
import logging import logging
from typing import Any, ContextManager, Dict, List, Optional from contextlib import AbstractContextManager
from typing import Any
from .utils.dynamic import instantiate_class_type from .utils.dynamic import instantiate_class_type
@ -17,11 +18,11 @@ log = logging.getLogger(__name__)
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager): class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data""" """Context manager for request provider data"""
def __init__( def __init__(
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
): ):
self.provider_data = provider_data or {} self.provider_data = provider_data or {}
if auth_attributes: if auth_attributes:
@ -63,7 +64,7 @@ class NeedsRequestProviderData:
return None return None
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]: def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
"""Parse provider data from request headers""" """Parse provider data from request headers"""
keys = [ keys = [
"X-LlamaStack-Provider-Data", "X-LlamaStack-Provider-Data",
@ -86,14 +87,14 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
def request_provider_data_context( def request_provider_data_context(
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
) -> ContextManager: ) -> AbstractContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context""" """Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers) provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes) return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> Optional[Dict[str, List[str]]]: def get_auth_attributes() -> dict[str, list[str]] | None:
"""Helper to retrieve auth attributes from the provider data context""" """Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get() provider_data = PROVIDER_DATA_VAR.get()
if not provider_data: if not provider_data:
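typing.ContextManager is deprecated in favor of contextlib.AbstractContextManager, which this hunk now uses both as a base class and as a return annotation. A minimal sketch of the idea on Python 3.10+, using a simplified context variable rather than the real provider-data plumbing:

import contextvars
from contextlib import AbstractContextManager
from typing import Any

_DATA: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar("data", default=None)


class DataContext(AbstractContextManager):
    """Set a context variable for the duration of a with-block."""

    def __init__(self, data: dict[str, Any] | None = None):
        self.data = data or {}
        self._token = None

    def __enter__(self):
        self._token = _DATA.set(self.data)
        return self

    def __exit__(self, exc_type, exc, tb):
        _DATA.reset(self._token)
        return False


def data_context(data: dict[str, Any] | None = None) -> AbstractContextManager:
    return DataContext(data)


with data_context({"teams": ["ml-team"]}):
    print(_DATA.get())  # {'teams': ['ml-team']}
print(_DATA.get())  # None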


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import importlib import importlib
import inspect import inspect
from typing import Any, Dict, List, Set, Tuple from typing import Any
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
@ -58,7 +58,7 @@ class InvalidProviderError(Exception):
pass pass
def api_protocol_map() -> Dict[Api, Any]: def api_protocol_map() -> dict[Api, Any]:
return { return {
Api.providers: ProvidersAPI, Api.providers: ProvidersAPI,
Api.agents: Agents, Api.agents: Agents,
@ -83,7 +83,7 @@ def api_protocol_map() -> Dict[Api, Any]:
} }
def additional_protocols_map() -> Dict[Api, Any]: def additional_protocols_map() -> dict[Api, Any]:
return { return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models), Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups), Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
@ -104,14 +104,14 @@ class ProviderWithSpec(Provider):
spec: ProviderSpec spec: ProviderSpec
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]] ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
async def resolve_impls( async def resolve_impls(
run_config: StackRunConfig, run_config: StackRunConfig,
provider_registry: ProviderRegistry, provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
) -> Dict[Api, Any]: ) -> dict[Api, Any]:
""" """
Resolves provider implementations by: Resolves provider implementations by:
1. Validating and organizing providers. 1. Validating and organizing providers.
@ -136,7 +136,7 @@ async def resolve_impls(
return await instantiate_providers(sorted_providers, router_apis, dist_registry) return await instantiate_providers(sorted_providers, router_apis, dist_registry)
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]: def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs.""" """Generates specifications for automatically routed APIs."""
specs = {} specs = {}
for info in builtin_automatically_routed_apis(): for info in builtin_automatically_routed_apis():
@ -178,10 +178,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
def validate_and_prepare_providers( def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api] run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
) -> Dict[str, Dict[str, ProviderWithSpec]]: ) -> dict[str, dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary.""" """Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {} providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items(): for api_str, providers in run_config.providers.items():
api = Api(api_str) api = Api(api_str)
@ -222,10 +222,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
def sort_providers_by_deps( def sort_providers_by_deps(
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> List[Tuple[str, ProviderWithSpec]]: ) -> list[tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies.""" """Sorts providers based on their dependencies."""
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort( sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()} {k: list(v.values()) for k, v in providers_with_specs.items()}
) )
@ -236,11 +236,11 @@ def sort_providers_by_deps(
async def instantiate_providers( async def instantiate_providers(
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
) -> Dict: ) -> dict:
"""Instantiates providers asynchronously while managing dependencies.""" """Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {} impls: dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies} deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies: for a in provider.spec.optional_api_dependencies:
@ -263,9 +263,9 @@ async def instantiate_providers(
def topological_sort( def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]], providers_with_specs: dict[str, list[ProviderWithSpec]],
) -> List[Tuple[str, ProviderWithSpec]]: ) -> list[tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: Set[str], stack: List[str]): def dfs(kv, visited: set[str], stack: list[str]):
api_str, providers = kv api_str, providers = kv
visited.add(api_str) visited.add(api_str)
@ -280,8 +280,8 @@ def topological_sort(
stack.append(api_str) stack.append(api_str)
visited: Set[str] = set() visited: set[str] = set()
stack: List[str] = [] stack: list[str] = []
for api_str, providers in providers_with_specs.items(): for api_str, providers in providers_with_specs.items():
if api_str not in visited: if api_str not in visited:
@ -298,8 +298,8 @@ def topological_sort(
# returns a class implementing the protocol corresponding to the Api # returns a class implementing the protocol corresponding to the Api
async def instantiate_provider( async def instantiate_provider(
provider: ProviderWithSpec, provider: ProviderWithSpec,
deps: Dict[Api, Any], deps: dict[Api, Any],
inner_impls: Dict[str, Any], inner_impls: dict[str, Any],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
): ):
protocols = api_protocol_map() protocols = api_protocol_map()
@ -391,8 +391,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
async def resolve_remote_stack_impls( async def resolve_remote_stack_impls(
config: RemoteProviderConfig, config: RemoteProviderConfig,
apis: List[str], apis: list[str],
) -> Dict[Api, Any]: ) -> dict[Api, Any]:
protocols = api_protocol_map() protocols = api_protocol_map()
additional_protocols = additional_protocols_map() additional_protocols = additional_protocols_map()
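This file is a dense example of builtin generics in signatures: Dict/List/Set/Tuple all become dict/list/set/tuple with no typing import. A small hedged sketch of a dependency sort written with the modernized annotations (a simplified stand-in, not the resolver's actual algorithm):

def topological_sort(deps: dict[str, list[str]]) -> list[str]:
    """Return keys ordered so each appears after its dependencies (simplified DFS)."""
    visited: set[str] = set()
    order: list[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        for dep in deps.get(node, []):
            if dep not in visited:
                dfs(dep)
        order.append(node)

    for node in deps:
        if node not in visited:
            dfs(node)
    return order


print(topological_sort({"agents": ["inference", "safety"], "safety": ["inference"], "inference": []}))
# ['inference', 'safety', 'agents']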


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
@ -23,7 +23,7 @@ from .routing_tables import (
async def get_routing_table_impl( async def get_routing_table_impl(
api: Api, api: Api,
impls_by_provider_id: Dict[str, RoutedProtocol], impls_by_provider_id: dict[str, RoutedProtocol],
_deps, _deps,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
) -> Any: ) -> Any:
@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any: async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
from .routers import ( from .routers import (
DatasetIORouter, DatasetIORouter,
EvalRouter, EvalRouter,


@ -6,12 +6,12 @@
import asyncio import asyncio
import time import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from collections.abc import AsyncGenerator, AsyncIterator
from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL, URL,
@ -100,9 +100,9 @@ class VectorIORouter(VectorIO):
self, self,
vector_db_id: str, vector_db_id: str,
embedding_model: str, embedding_model: str,
embedding_dimension: Optional[int] = 384, embedding_dimension: int | None = 384,
provider_id: Optional[str] = None, provider_id: str | None = None,
provider_vector_db_id: Optional[str] = None, provider_vector_db_id: str | None = None,
) -> None: ) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db( await self.routing_table.register_vector_db(
@ -116,8 +116,8 @@ class VectorIORouter(VectorIO):
async def insert_chunks( async def insert_chunks(
self, self,
vector_db_id: str, vector_db_id: str,
chunks: List[Chunk], chunks: list[Chunk],
ttl_seconds: Optional[int] = None, ttl_seconds: int | None = None,
) -> None: ) -> None:
logger.debug( logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
@ -128,7 +128,7 @@ class VectorIORouter(VectorIO):
self, self,
vector_db_id: str, vector_db_id: str,
query: InterleavedContent, query: InterleavedContent,
params: Optional[Dict[str, Any]] = None, params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -140,7 +140,7 @@ class InferenceRouter(Inference):
def __init__( def __init__(
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None, telemetry: Telemetry | None = None,
) -> None: ) -> None:
logger.debug("Initializing InferenceRouter") logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table self.routing_table = routing_table
@ -160,10 +160,10 @@ class InferenceRouter(Inference):
async def register_model( async def register_model(
self, self,
model_id: str, model_id: str,
provider_model_id: Optional[str] = None, provider_model_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
model_type: Optional[ModelType] = None, model_type: ModelType | None = None,
) -> None: ) -> None:
logger.debug( logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
@ -176,7 +176,7 @@ class InferenceRouter(Inference):
completion_tokens: int, completion_tokens: int,
total_tokens: int, total_tokens: int,
model: Model, model: Model,
) -> List[MetricEvent]: ) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics. """Constructs a list of MetricEvent objects containing token usage metrics.
Args: Args:
@ -221,7 +221,7 @@ class InferenceRouter(Inference):
completion_tokens: int, completion_tokens: int,
total_tokens: int, total_tokens: int,
model: Model, model: Model,
) -> List[MetricInResponse]: ) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry: if self.telemetry:
for metric in metrics: for metric in metrics:
@ -230,9 +230,9 @@ class InferenceRouter(Inference):
async def _count_tokens( async def _count_tokens(
self, self,
messages: List[Message] | InterleavedContent, messages: list[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: ToolPromptFormat | None = None,
) -> Optional[int]: ) -> int | None:
if isinstance(messages, list): if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else: else:
@ -242,16 +242,16 @@ class InferenceRouter(Inference):
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,
messages: List[Message], messages: list[Message],
sampling_params: Optional[SamplingParams] = None, sampling_params: SamplingParams | None = None,
response_format: Optional[ResponseFormat] = None, response_format: ResponseFormat | None = None,
tools: Optional[List[ToolDefinition]] = None, tools: list[ToolDefinition] | None = None,
tool_choice: Optional[ToolChoice] = None, tool_choice: ToolChoice | None = None,
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: ToolPromptFormat | None = None,
stream: Optional[bool] = False, stream: bool | None = False,
logprobs: Optional[LogProbConfig] = None, logprobs: LogProbConfig | None = None,
tool_config: Optional[ToolConfig] = None, tool_config: ToolConfig | None = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug( logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
) )
@ -351,12 +351,12 @@ class InferenceRouter(Inference):
async def batch_chat_completion( async def batch_chat_completion(
self, self,
model_id: str, model_id: str,
messages_batch: List[List[Message]], messages_batch: list[list[Message]],
tools: Optional[List[ToolDefinition]] = None, tools: list[ToolDefinition] | None = None,
tool_config: Optional[ToolConfig] = None, tool_config: ToolConfig | None = None,
sampling_params: Optional[SamplingParams] = None, sampling_params: SamplingParams | None = None,
response_format: Optional[ResponseFormat] = None, response_format: ResponseFormat | None = None,
logprobs: Optional[LogProbConfig] = None, logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse: ) -> BatchChatCompletionResponse:
logger.debug( logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@ -376,10 +376,10 @@ class InferenceRouter(Inference):
self, self,
model_id: str, model_id: str,
content: InterleavedContent, content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None, sampling_params: SamplingParams | None = None,
response_format: Optional[ResponseFormat] = None, response_format: ResponseFormat | None = None,
stream: Optional[bool] = False, stream: bool | None = False,
logprobs: Optional[LogProbConfig] = None, logprobs: LogProbConfig | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
@ -439,10 +439,10 @@ class InferenceRouter(Inference):
async def batch_completion( async def batch_completion(
self, self,
model_id: str, model_id: str,
content_batch: List[InterleavedContent], content_batch: list[InterleavedContent],
sampling_params: Optional[SamplingParams] = None, sampling_params: SamplingParams | None = None,
response_format: Optional[ResponseFormat] = None, response_format: ResponseFormat | None = None,
logprobs: Optional[LogProbConfig] = None, logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse: ) -> BatchCompletionResponse:
logger.debug( logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@ -453,10 +453,10 @@ class InferenceRouter(Inference):
async def embeddings( async def embeddings(
self, self,
model_id: str, model_id: str,
contents: List[str] | List[InterleavedContentItem], contents: list[str] | list[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none, text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: Optional[int] = None, output_dimension: int | None = None,
task_type: Optional[EmbeddingTaskType] = None, task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}") logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
@ -475,24 +475,24 @@ class InferenceRouter(Inference):
async def openai_completion( async def openai_completion(
self, self,
model: str, model: str,
prompt: Union[str, List[str], List[int], List[List[int]]], prompt: str | list[str] | list[int] | list[list[int]],
best_of: Optional[int] = None, best_of: int | None = None,
echo: Optional[bool] = None, echo: bool | None = None,
frequency_penalty: Optional[float] = None, frequency_penalty: float | None = None,
logit_bias: Optional[Dict[str, float]] = None, logit_bias: dict[str, float] | None = None,
logprobs: Optional[bool] = None, logprobs: bool | None = None,
max_tokens: Optional[int] = None, max_tokens: int | None = None,
n: Optional[int] = None, n: int | None = None,
presence_penalty: Optional[float] = None, presence_penalty: float | None = None,
seed: Optional[int] = None, seed: int | None = None,
stop: Optional[Union[str, List[str]]] = None, stop: str | list[str] | None = None,
stream: Optional[bool] = None, stream: bool | None = None,
stream_options: Optional[Dict[str, Any]] = None, stream_options: dict[str, Any] | None = None,
temperature: Optional[float] = None, temperature: float | None = None,
top_p: Optional[float] = None, top_p: float | None = None,
user: Optional[str] = None, user: str | None = None,
guided_choice: Optional[List[str]] = None, guided_choice: list[str] | None = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: int | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
logger.debug( logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@ -531,29 +531,29 @@ class InferenceRouter(Inference):
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, model: str,
messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)], messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: Optional[float] = None, frequency_penalty: float | None = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None, function_call: str | dict[str, Any] | None = None,
functions: Optional[List[Dict[str, Any]]] = None, functions: list[dict[str, Any]] | None = None,
logit_bias: Optional[Dict[str, float]] = None, logit_bias: dict[str, float] | None = None,
logprobs: Optional[bool] = None, logprobs: bool | None = None,
max_completion_tokens: Optional[int] = None, max_completion_tokens: int | None = None,
max_tokens: Optional[int] = None, max_tokens: int | None = None,
n: Optional[int] = None, n: int | None = None,
parallel_tool_calls: Optional[bool] = None, parallel_tool_calls: bool | None = None,
presence_penalty: Optional[float] = None, presence_penalty: float | None = None,
response_format: Optional[OpenAIResponseFormatParam] = None, response_format: OpenAIResponseFormatParam | None = None,
seed: Optional[int] = None, seed: int | None = None,
stop: Optional[Union[str, List[str]]] = None, stop: str | list[str] | None = None,
stream: Optional[bool] = None, stream: bool | None = None,
stream_options: Optional[Dict[str, Any]] = None, stream_options: dict[str, Any] | None = None,
temperature: Optional[float] = None, temperature: float | None = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None, tool_choice: str | dict[str, Any] | None = None,
tools: Optional[List[Dict[str, Any]]] = None, tools: list[dict[str, Any]] | None = None,
top_logprobs: Optional[int] = None, top_logprobs: int | None = None,
top_p: Optional[float] = None, top_p: float | None = None,
user: Optional[str] = None, user: str | None = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug( logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
) )
@ -602,7 +602,7 @@ class InferenceRouter(Inference):
provider = self.routing_table.get_provider_impl(model_obj.identifier) provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_chat_completion(**params) return await provider.openai_chat_completion(**params)
async def health(self) -> Dict[str, HealthResponse]: async def health(self) -> dict[str, HealthResponse]:
health_statuses = {} health_statuses = {}
timeout = 0.5 timeout = 0.5
for provider_id, impl in self.routing_table.impls_by_provider_id.items(): for provider_id, impl in self.routing_table.impls_by_provider_id.items():
@ -645,9 +645,9 @@ class SafetyRouter(Safety):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
provider_shield_id: Optional[str] = None, provider_shield_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
params: Optional[Dict[str, Any]] = None, params: dict[str, Any] | None = None,
) -> Shield: ) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}") logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
@ -655,8 +655,8 @@ class SafetyRouter(Safety):
async def run_shield( async def run_shield(
self, self,
shield_id: str, shield_id: str,
messages: List[Message], messages: list[Message],
params: Dict[str, Any] = None, params: dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}") logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield( return await self.routing_table.get_provider_impl(shield_id).run_shield(
@ -686,8 +686,8 @@ class DatasetIORouter(DatasetIO):
self, self,
purpose: DatasetPurpose, purpose: DatasetPurpose,
source: DataSource, source: DataSource,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
dataset_id: Optional[str] = None, dataset_id: str | None = None,
) -> None: ) -> None:
logger.debug( logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
@ -702,8 +702,8 @@ class DatasetIORouter(DatasetIO):
async def iterrows( async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
start_index: Optional[int] = None, start_index: int | None = None,
limit: Optional[int] = None, limit: int | None = None,
) -> PaginatedResponse: ) -> PaginatedResponse:
logger.debug( logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
@ -714,7 +714,7 @@ class DatasetIORouter(DatasetIO):
limit=limit, limit=limit,
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows( return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id, dataset_id=dataset_id,
@ -741,7 +741,7 @@ class ScoringRouter(Scoring):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}") logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
@ -762,8 +762,8 @@ class ScoringRouter(Scoring):
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: list[dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse: ) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {} res = {}
@ -808,8 +808,8 @@ class EvalRouter(Eval):
async def evaluate_rows( async def evaluate_rows(
self, self,
benchmark_id: str, benchmark_id: str,
input_rows: List[Dict[str, Any]], input_rows: list[dict[str, Any]],
scoring_functions: List[str], scoring_functions: list[str],
benchmark_config: BenchmarkConfig, benchmark_config: BenchmarkConfig,
) -> EvaluateResponse: ) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
@ -863,8 +863,8 @@ class ToolRuntimeRouter(ToolRuntime):
async def query( async def query(
self, self,
content: InterleavedContent, content: InterleavedContent,
vector_db_ids: List[str], vector_db_ids: list[str],
query_config: Optional[RAGQueryConfig] = None, query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query( return await self.routing_table.get_provider_impl("knowledge_search").query(
@ -873,7 +873,7 @@ class ToolRuntimeRouter(ToolRuntime):
async def insert( async def insert(
self, self,
documents: List[RAGDocument], documents: list[RAGDocument],
vector_db_id: str, vector_db_id: str,
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
) -> None: ) -> None:
@ -904,7 +904,7 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("ToolRuntimeRouter.shutdown") logger.debug("ToolRuntimeRouter.shutdown")
pass pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any: async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool( return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name, tool_name=tool_name,
@ -912,7 +912,7 @@ class ToolRuntimeRouter(ToolRuntime):
) )
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
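The router module now takes AsyncGenerator and AsyncIterator from collections.abc instead of typing (the typing aliases are deprecated since Python 3.9), and return types such as Union[ChatCompletionResponse, AsyncIterator[...]] become |-unions. A toy sketch of the same annotation style; the streaming function is purely illustrative, not the router's API:

import asyncio
from collections.abc import AsyncIterator


async def stream_chunks(text: str) -> AsyncIterator[str]:
    # AsyncIterator imported from collections.abc, not typing
    for word in text.split():
        await asyncio.sleep(0)  # yield control; stand-in for real streaming I/O
        yield word


async def main() -> None:
    async for chunk in stream_chunks("hello from the router"):
        print(chunk)


asyncio.run(main())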

View file

@ -7,7 +7,7 @@
import logging import logging
import time import time
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any
from pydantic import TypeAdapter from pydantic import TypeAdapter
@ -106,20 +106,20 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
raise ValueError(f"Unregister not supported for {api}") raise ValueError(f"Unregister not supported for {api}")
Registry = Dict[str, List[RoutableObjectWithProvider]] Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable): class CommonRoutingTableImpl(RoutingTable):
def __init__( def __init__(
self, self,
impls_by_provider_id: Dict[str, RoutedProtocol], impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
) -> None: ) -> None:
self.impls_by_provider_id = impls_by_provider_id self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry self.dist_registry = dist_registry
async def initialize(self) -> None: async def initialize(self) -> None:
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs: for obj in objs:
if cls is None: if cls is None:
obj.provider_id = provider_id obj.provider_id = provider_id
@ -154,7 +154,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values(): for p in self.impls_by_provider_id.values():
await p.shutdown() await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
def apiname_object(): def apiname_object():
if isinstance(self, ModelsRoutingTable): if isinstance(self, ModelsRoutingTable):
return ("Inference", "model") return ("Inference", "model")
@ -192,7 +192,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`") raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry # Get from disk registry
obj = await self.dist_registry.get(type, identifier) obj = await self.dist_registry.get(type, identifier)
if not obj: if not obj:
@ -236,7 +236,7 @@ class CommonRoutingTableImpl(RoutingTable):
await self.dist_registry.register(obj) await self.dist_registry.register(obj)
return obj return obj
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]: async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all() objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type] filtered_objs = [obj for obj in objs if obj.type == type]
@ -277,10 +277,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def register_model( async def register_model(
self, self,
model_id: str, model_id: str,
provider_model_id: Optional[str] = None, provider_model_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
model_type: Optional[ModelType] = None, model_type: ModelType | None = None,
) -> Model: ) -> Model:
if provider_model_id is None: if provider_model_id is None:
provider_model_id = model_id provider_model_id = model_id
@ -328,9 +328,9 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def register_shield( async def register_shield(
self, self,
shield_id: str, shield_id: str,
provider_shield_id: Optional[str] = None, provider_shield_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
params: Optional[Dict[str, Any]] = None, params: dict[str, Any] | None = None,
) -> Shield: ) -> Shield:
if provider_shield_id is None: if provider_shield_id is None:
provider_shield_id = shield_id provider_shield_id = shield_id
@ -368,9 +368,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
self, self,
vector_db_id: str, vector_db_id: str,
embedding_model: str, embedding_model: str,
embedding_dimension: Optional[int] = 384, embedding_dimension: int | None = 384,
provider_id: Optional[str] = None, provider_id: str | None = None,
provider_vector_db_id: Optional[str] = None, provider_vector_db_id: str | None = None,
) -> VectorDB: ) -> VectorDB:
if provider_vector_db_id is None: if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id provider_vector_db_id = vector_db_id
@ -423,8 +423,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
self, self,
purpose: DatasetPurpose, purpose: DatasetPurpose,
source: DataSource, source: DataSource,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
dataset_id: Optional[str] = None, dataset_id: str | None = None,
) -> Dataset: ) -> Dataset:
if isinstance(source, dict): if isinstance(source, dict):
if source["type"] == "uri": if source["type"] == "uri":
@ -489,9 +489,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
scoring_fn_id: str, scoring_fn_id: str,
description: str, description: str,
return_type: ParamType, return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None, provider_scoring_fn_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
params: Optional[ScoringFnParams] = None, params: ScoringFnParams | None = None,
) -> None: ) -> None:
if provider_scoring_fn_id is None: if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id provider_scoring_fn_id = scoring_fn_id
@ -528,10 +528,10 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
self, self,
benchmark_id: str, benchmark_id: str,
dataset_id: str, dataset_id: str,
scoring_functions: List[str], scoring_functions: list[str],
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
provider_benchmark_id: Optional[str] = None, provider_benchmark_id: str | None = None,
provider_id: Optional[str] = None, provider_id: str | None = None,
) -> None: ) -> None:
if metadata is None: if metadata is None:
metadata = {} metadata = {}
@ -556,7 +556,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse: async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool") tools = await self.get_all_with_type("tool")
if toolgroup_id: if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
@ -578,8 +578,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self, self,
toolgroup_id: str, toolgroup_id: str,
provider_id: str, provider_id: str,
mcp_endpoint: Optional[URL] = None, mcp_endpoint: URL | None = None,
args: Optional[Dict[str, Any]] = None, args: dict[str, Any] | None = None,
) -> None: ) -> None:
tools = [] tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)

View file

@ -7,7 +7,6 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import Dict, List, Optional
from urllib.parse import parse_qs from urllib.parse import parse_qs
import httpx import httpx
@ -22,7 +21,7 @@ logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel): class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint.""" """The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field( access_attributes: AccessAttributes | None = Field(
default=None, default=None,
description=""" description="""
Structured user attributes for attribute-based access control. Structured user attributes for attribute-based access control.
@ -44,7 +43,7 @@ class AuthResponse(BaseModel):
""", """,
) )
message: Optional[str] = Field( message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result." default=None, description="Optional message providing additional context about the authentication result."
) )
@ -52,9 +51,9 @@ class AuthResponse(BaseModel):
class AuthRequestContext(BaseModel): class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated") path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)") headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field( params: dict[str, list[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists" description="Query parameters from the original request, parsed as dictionary of lists"
) )
@ -76,14 +75,14 @@ class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers.""" """Base configuration for authentication providers."""
provider_type: AuthProviderType = Field(..., description="Type of authentication provider") provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
config: Dict[str, str] = Field(..., description="Provider-specific configuration") config: dict[str, str] = Field(..., description="Provider-specific configuration")
class AuthProvider(ABC): class AuthProvider(ABC):
"""Abstract base class for authentication providers.""" """Abstract base class for authentication providers."""
@abstractmethod @abstractmethod
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]: async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a token and return access attributes.""" """Validate a token and return access attributes."""
pass pass
@ -96,7 +95,7 @@ class AuthProvider(ABC):
class KubernetesAuthProvider(AuthProvider): class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: Dict[str, str]): def __init__(self, config: dict[str, str]):
self.api_server_url = config["api_server_url"] self.api_server_url = config["api_server_url"]
self.ca_cert_path = config.get("ca_cert_path") self.ca_cert_path = config.get("ca_cert_path")
self._client = None self._client = None
@ -120,7 +119,7 @@ class KubernetesAuthProvider(AuthProvider):
self._client = ApiClient(configuration) self._client = ApiClient(configuration)
return self._client return self._client
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]: async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a Kubernetes token and return access attributes.""" """Validate a Kubernetes token and return access attributes."""
try: try:
client = await self._get_client() client = await self._get_client()
@ -166,11 +165,11 @@ class KubernetesAuthProvider(AuthProvider):
class CustomAuthProvider(AuthProvider): class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint.""" """Custom authentication provider that uses an external endpoint."""
def __init__(self, config: Dict[str, str]): def __init__(self, config: dict[str, str]):
self.endpoint = config["endpoint"] self.endpoint = config["endpoint"]
self._client = None self._client = None
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]: async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a token using the custom authentication endpoint.""" """Validate a token using the custom authentication endpoint."""
if not self.endpoint: if not self.endpoint:
raise ValueError("Authentication endpoint not configured") raise ValueError("Authentication endpoint not configured")

View file

@ -6,7 +6,6 @@
import inspect import inspect
import re import re
from typing import Dict, List
from pydantic import BaseModel from pydantic import BaseModel
@ -29,7 +28,7 @@ def toolgroup_protocol_map():
} }
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]: def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
apis = {} apis = {}
protocols = api_protocol_map() protocols = api_protocol_map()

View file

@ -15,7 +15,7 @@ import warnings
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version from importlib.metadata import version as parse_version
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Union from typing import Annotated, Any
import yaml import yaml
from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Body, FastAPI, HTTPException, Request
@ -24,7 +24,6 @@ from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError from openai import BadRequestError
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
@ -91,7 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception):
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}) return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]: def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
if isinstance(exc, ValidationError): if isinstance(exc, ValidationError):
exc = RequestValidationError(exc.errors()) exc = RequestValidationError(exc.errors())
@ -315,7 +314,7 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
def main(args: Optional[argparse.Namespace] = None): def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server.""" """Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument( parser.add_argument(
@ -385,7 +384,7 @@ def main(args: Optional[argparse.Namespace] = None):
raise ValueError("Either --yaml-config or --template must be provided") raise ValueError("Either --yaml-config or --template must be provided")
logger_config = None logger_config = None
with open(config_file, "r") as fp: with open(config_file) as fp:
config_contents = yaml.safe_load(fp) config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg) logger_config = LoggingConfig(**cfg)
@ -517,7 +516,7 @@ def main(args: Optional[argparse.Namespace] = None):
uvicorn.run(**uvicorn_config) uvicorn.run(**uvicorn_config)
def extract_path_params(route: str) -> List[str]: def extract_path_params(route: str) -> list[str]:
segments = route.split("/") segments = route.split("/")
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")] params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
# to handle path params like {param:path} # to handle path params like {param:path}

View file

@ -8,7 +8,7 @@ import importlib.resources
import os import os
import re import re
import tempfile import tempfile
from typing import Any, Dict, Optional from typing import Any
import yaml import yaml
@ -90,7 +90,7 @@ RESOURCES = [
] ]
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES: for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc) objects = getattr(run_config, rsrc)
if api not in impls: if api not in impls:
@ -197,7 +197,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
) from e ) from e
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None: def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary. """Add internal implementations (inspect and providers) to the implementations dictionary.
Args: Args:
@ -220,8 +220,8 @@ def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConf
# Produces a stack of providers for the given run config. Not all APIs may be # Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config. # asked for in the run config.
async def construct_stack( async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> Dict[Api, Any]: ) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry) impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
@ -244,7 +244,7 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
def run_config_from_adhoc_config_spec( def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
) -> StackRunConfig: ) -> StackRunConfig:
""" """
Create an adhoc distribution from a list of API providers. Create an adhoc distribution from a list of API providers.

View file

@ -6,7 +6,7 @@
import asyncio import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple from typing import Protocol
import pydantic import pydantic
@ -20,13 +20,13 @@ logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol): class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ... async def get_all(self) -> list[RoutableObjectWithProvider]: ...
async def initialize(self) -> None: ... async def initialize(self) -> None: ...
async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ...
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ... def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ...
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ... async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
@ -40,13 +40,13 @@ KEY_VERSION = "v8"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
def _get_registry_key_range() -> Tuple[str, str]: def _get_registry_key_range() -> tuple[str, str]:
"""Returns the start and end keys for the registry range query.""" """Returns the start and end keys for the registry range query."""
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}" start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
return start_key, f"{start_key}\xff" return start_key, f"{start_key}\xff"
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]: def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]:
"""Utility function to parse registry values into RoutableObjectWithProvider objects.""" """Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = [] all_objects = []
for value in values: for value in values:
@ -67,16 +67,16 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Disk registry does not have a cache # Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache") raise NotImplementedError("Disk registry does not have a cache")
async def get_all(self) -> List[RoutableObjectWithProvider]: async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range() start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key) values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values) return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier)) json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
if not json_str: if not json_str:
return None return None
@ -113,7 +113,7 @@ class DiskDistributionRegistry(DistributionRegistry):
class CachedDiskDistributionRegistry(DiskDistributionRegistry): class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore): def __init__(self, kvstore: KVStore):
super().__init__(kvstore) super().__init__(kvstore)
self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {} self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {}
self._initialized = False self._initialized = False
self._initialize_lock = asyncio.Lock() self._initialize_lock = asyncio.Lock()
self._cache_lock = asyncio.Lock() self._cache_lock = asyncio.Lock()
@ -147,15 +147,15 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def initialize(self) -> None: async def initialize(self) -> None:
await self._ensure_initialized() await self._ensure_initialized()
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
return self.cache.get((type, identifier), None) return self.cache.get((type, identifier), None)
async def get_all(self) -> List[RoutableObjectWithProvider]: async def get_all(self) -> list[RoutableObjectWithProvider]:
await self._ensure_initialized() await self._ensure_initialized()
async with self._locked_cache() as cache: async with self._locked_cache() as cache:
return list(cache.values()) return list(cache.values())
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
await self._ensure_initialized() await self._ensure_initialized()
cache_key = (type, identifier) cache_key = (type, identifier)
@ -189,7 +189,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def create_dist_registry( async def create_dist_registry(
metadata_store: Optional[KVStoreConfig], metadata_store: KVStoreConfig | None,
image_name: str, image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]: ) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata # instantiate kvstore for storing and retrieving distribution metadata

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import os import os
from typing import Optional
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient
@ -23,7 +22,7 @@ class LlamaStackApi:
}, },
) )
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]): def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
"""Run scoring on a single row""" """Run scoring on a single row"""
if not scoring_params: if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids} scoring_params = {fn_id: None for fn_id in scoring_function_ids}

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict from typing import Any
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
"""Redact sensitive information from config before printing.""" """Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"] sensitive_patterns = ["api_key", "api_token", "password", "secret"]
@ -18,7 +18,7 @@ def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
return [_redact_value(i) for i in v] return [_redact_value(i) for i in v]
return v return v
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]: def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
result = {} result = {}
for k, v in d.items(): for k, v in d.items():
if any(pattern in k.lower() for pattern in sensitive_patterns): if any(pattern in k.lower() for pattern in sensitive_patterns):
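To make the behaviour concrete, here is a standalone re-implementation of the same idea (the exact mask string and nesting rules in the repo may differ):

from typing import Any

SENSITIVE = ("api_key", "api_token", "password", "secret")

def redact(data: dict[str, Any]) -> dict[str, Any]:
    out: dict[str, Any] = {}
    for k, v in data.items():
        if any(pattern in k.lower() for pattern in SENSITIVE):
            out[k] = "********"          # hide the value, keep the key
        elif isinstance(v, dict):
            out[k] = redact(v)           # recurse into nested config sections
        else:
            out[k] = v
    return out

print(redact({"provider": "ollama", "config": {"api_key": "abc123", "url": "http://localhost:11434"}}))
# {'provider': 'ollama', 'config': {'api_key': '********', 'url': 'http://localhost:11434'}}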

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextvars import ContextVar from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar from typing import TypeVar
T = TypeVar("T") T = TypeVar("T")
def preserve_contexts_async_generator( def preserve_contexts_async_generator(
gen: AsyncGenerator[T, None], context_vars: List[ContextVar] gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]: ) -> AsyncGenerator[T, None]:
""" """
Wraps an async generator to preserve context variables across iterations. Wraps an async generator to preserve context variables across iterations.
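A minimal sketch of the idea (not the repo implementation): capture the values of the given `ContextVar`s up front and re-apply them every time the wrapped generator is resumed, so the generator keeps seeing the values that were in effect when it was wrapped:

import asyncio
from collections.abc import AsyncGenerator
from contextvars import ContextVar

request_id: ContextVar[str] = ContextVar("request_id", default="unset")

async def with_preserved_contexts(gen, context_vars: list[ContextVar]) -> AsyncGenerator:
    captured = [(var, var.get()) for var in context_vars]
    while True:
        for var, value in captured:
            var.set(value)               # restore before resuming the generator
        try:
            yield await gen.__anext__()
        except StopAsyncIteration:
            break

async def numbers() -> AsyncGenerator[str, None]:
    for i in range(3):
        yield f"{request_id.get()}:{i}"

async def main() -> None:
    request_id.set("req-42")
    async for line in with_preserved_contexts(numbers(), [request_id]):
        print(line)                      # req-42:0, req-42:1, req-42:2
        request_id.set("clobbered")      # consumer changes the var, generator is unaffected

asyncio.run(main())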

View file

@ -8,12 +8,11 @@ import inspect
import json import json
import logging import logging
from enum import Enum from enum import Enum
from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin from typing import Annotated, Any, Literal, Union, get_args, get_origin
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -21,7 +20,7 @@ log = logging.getLogger(__name__)
def is_list_of_primitives(field_type): def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types.""" """Check if a field type is a List of primitive types."""
origin = get_origin(field_type) origin = get_origin(field_type)
if origin is List or origin is list: if origin is list:
args = get_args(field_type) args = get_args(field_type)
if len(args) == 1 and args[0] in (int, float, str, bool): if len(args) == 1 and args[0] in (int, float, str, bool):
return True return True
@ -53,7 +52,7 @@ def get_non_none_type(field_type):
return next(arg for arg in get_args(field_type) if arg is not type(None)) return next(arg for arg in get_args(field_type) if arg is not type(None))
def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any): def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
validators = model.__pydantic_decorators__.field_validators validators = model.__pydantic_decorators__.field_validators
for _name, validator in validators.items(): for _name, validator in validators.items():
if field_name in validator.info.fields: if field_name in validator.info.fields:
@ -126,7 +125,7 @@ def prompt_for_discriminated_union(
# #
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of # doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage. # unit tests for coverage.
def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel: def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
""" """
Recursively prompt the user for configuration values based on a Pydantic BaseModel. Recursively prompt the user for configuration values based on a Pydantic BaseModel.

View file

@ -7,7 +7,6 @@
import logging import logging
import os import os
from logging.config import dictConfig from logging.config import dictConfig
from typing import Dict, Optional
from rich.console import Console from rich.console import Console
from rich.errors import MarkupError from rich.errors import MarkupError
@ -33,7 +32,7 @@ CATEGORIES = [
] ]
# Initialize category levels with default level # Initialize category levels with default level
_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES} _category_levels: dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
def config_to_category_levels(category: str, level: str): def config_to_category_levels(category: str, level: str):
@ -49,7 +48,7 @@ def config_to_category_levels(category: str, level: str):
Dict[str, int]: A dictionary mapping categories to their log levels. Dict[str, int]: A dictionary mapping categories to their log levels.
""" """
category_levels: Dict[str, int] = {} category_levels: dict[str, int] = {}
level_value = logging._nameToLevel.get(str(level).upper()) level_value = logging._nameToLevel.get(str(level).upper())
if level_value is None: if level_value is None:
logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.") logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")
@ -69,7 +68,7 @@ def config_to_category_levels(category: str, level: str):
return category_levels return category_levels
def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]: def parse_yaml_config(yaml_config: LoggingConfig) -> dict[str, int]:
""" """
Helper function to parse a yaml logging configuration found in the run.yaml Helper function to parse a yaml logging configuration found in the run.yaml
@ -86,7 +85,7 @@ def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
return category_levels return category_levels
def parse_environment_config(env_config: str) -> Dict[str, int]: def parse_environment_config(env_config: str) -> dict[str, int]:
""" """
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels. Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.
@ -131,7 +130,7 @@ class CustomRichHandler(RichHandler):
self.markup = original_markup self.markup = original_markup
def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None: def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
""" """
Configure logging based on the provided category log levels and an optional log file. Configure logging based on the provided category log levels and an optional log file.
@ -211,7 +210,7 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None
def get_logger( def get_logger(
name: str, category: str = "uncategorized", config: Optional[LoggingConfig] | None = None name: str, category: str = "uncategorized", config: LoggingConfig | None = None
) -> logging.LoggerAdapter: ) -> logging.LoggerAdapter:
""" """
Returns a logger with the specified name and category. Returns a logger with the specified name and category.

View file

@ -7,14 +7,14 @@
import concurrent.futures import concurrent.futures
import re import re
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any
import numpy as np import numpy as np
import torch import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]: def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> list[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size.""" """Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0: if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones # Read old MP shard and split it into smaller ones
@ -31,12 +31,12 @@ def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[in
def maybe_reshard_state_dict( def maybe_reshard_state_dict(
ckpt_paths: List[Path], ckpt_paths: list[Path],
n_kv_heads: int, n_kv_heads: int,
moe_num_experts: Optional[int] = None, moe_num_experts: int | None = None,
map_location: Union[str, torch.device] = "cpu", map_location: str | torch.device = "cpu",
mmap: bool = True, mmap: bool = True,
) -> Dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
if str(map_location) == "cpu": if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor) torch.set_default_tensor_type(torch.BFloat16Tensor)
else: else:
@ -97,18 +97,18 @@ _MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
def reshard_mp( def reshard_mp(
state_dicts: List[Dict[str, torch.Tensor]], state_dicts: list[dict[str, torch.Tensor]],
size: int, size: int,
rank: int, rank: int,
repeat_qk_qv: int = 1, repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
""" """
Reshard a list of state dicts into a single state dict given a change in MP size. Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank. key across all state dicts. Otherwise, we just slice it for the current MP rank.
""" """
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: def concat_or_chunk(tensors: list[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1: if len(tensors) > 1:
return torch.cat(tensors, dim=dim) return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone() return tensors[0].chunk(size, dim=dim)[rank].clone()
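A toy illustration of the two cases the docstring describes (assumed semantics, not the repo code): going to fewer shards concatenates, going to more shards slices out the current rank's chunk:

import torch

def concat_or_chunk(tensors: list[torch.Tensor], size: int, rank: int, dim: int) -> torch.Tensor:
    if len(tensors) > 1:
        return torch.cat(tensors, dim=dim)                  # merge several old shards into one
    return tensors[0].chunk(size, dim=dim)[rank].clone()    # split one old shard, keep our piece

a, b = torch.ones(2, 4), torch.zeros(2, 4)
print(concat_or_chunk([a, b], size=1, rank=0, dim=0).shape)  # torch.Size([4, 4])
print(concat_or_chunk([a], size=2, rank=1, dim=1).shape)     # torch.Size([2, 2])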
@ -144,7 +144,7 @@ def reshard_mp(
column_regex = re.compile("|".join(column_keys)) column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys)) row_regex = re.compile("|".join(row_keys))
output: Dict[str, torch.Tensor] = {} output: dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict. # Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts. # Assumes keys are the same across all state dicts.
@ -154,7 +154,7 @@ def reshard_mp(
return output return output
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]: def convert_moe_weights(state_dict: dict[str, Any], num_experts: int) -> dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys)) routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys()) keys = list(state_dict.keys())

View file

@ -7,10 +7,9 @@
import base64 import base64
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated
# The goal is that these set of types are relevant for all Llama models. # The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to # That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
@ -31,21 +30,21 @@ class BuiltinTool(Enum):
code_interpreter = "code_interpreter" code_interpreter = "code_interpreter"
Primitive = Union[str, int, float, bool, None] Primitive = str | int | float | bool | None
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]] RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]
class ToolCall(BaseModel): class ToolCall(BaseModel):
call_id: str call_id: str
tool_name: Union[BuiltinTool, str] tool_name: BuiltinTool | str
# Plan is to deprecate the Dict in favor of a JSON string # Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage # that is parsed on the client side instead of trying to manage
# the recursive type here. # the recursive type here.
# Making this a union so that client side can start prepping for this change. # Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field, # Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str # and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]] arguments: str | dict[str, RecursiveType]
arguments_json: Optional[str] = None arguments_json: str | None = None
@field_validator("tool_name", mode="before") @field_validator("tool_name", mode="before")
@classmethod @classmethod
@ -91,15 +90,15 @@ class StopReason(Enum):
class ToolParamDefinition(BaseModel): class ToolParamDefinition(BaseModel):
param_type: str param_type: str
description: Optional[str] = None description: str | None = None
required: Optional[bool] = True required: bool | None = True
default: Optional[Any] = None default: Any | None = None
class ToolDefinition(BaseModel): class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str] tool_name: BuiltinTool | str
description: Optional[str] = None description: str | None = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None parameters: dict[str, ToolParamDefinition] | None = None
@field_validator("tool_name", mode="before") @field_validator("tool_name", mode="before")
@classmethod @classmethod
@ -119,7 +118,7 @@ class RawMediaItem(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@field_serializer("data") @field_serializer("data")
def serialize_data(self, data: Optional[bytes], _info): def serialize_data(self, data: bytes | None, _info):
if data is None: if data is None:
return None return None
return base64.b64encode(data).decode("utf-8") return base64.b64encode(data).decode("utf-8")
@ -137,9 +136,9 @@ class RawTextItem(BaseModel):
text: str text: str
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")] RawContentItem = Annotated[RawTextItem | RawMediaItem, Field(discriminator="type")]
RawContent = str | RawContentItem | List[RawContentItem] RawContent = str | RawContentItem | list[RawContentItem]
class RawMessage(BaseModel): class RawMessage(BaseModel):
@ -147,17 +146,17 @@ class RawMessage(BaseModel):
content: RawContent content: RawContent
# This is for RAG but likely should be absorbed into content # This is for RAG but likely should be absorbed into content
context: Optional[RawContent] = None context: RawContent | None = None
# These are for the output message coming from the assistant # These are for the output message coming from the assistant
stop_reason: Optional[StopReason] = None stop_reason: StopReason | None = None
tool_calls: List[ToolCall] = Field(default_factory=list) tool_calls: list[ToolCall] = Field(default_factory=list)
class GenerationResult(BaseModel): class GenerationResult(BaseModel):
token: int token: int
text: str text: str
logprobs: Optional[List[float]] = None logprobs: list[float] | None = None
source: Literal["input"] | Literal["output"] source: Literal["input"] | Literal["output"]

View file

@ -6,7 +6,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional
class QuantizationScheme(Enum): class QuantizationScheme(Enum):
@ -15,8 +14,8 @@ class QuantizationScheme(Enum):
@dataclass @dataclass
class QuantizationArgs: class QuantizationArgs:
scheme: Optional[QuantizationScheme] = None scheme: QuantizationScheme | None = None
group_size: Optional[int] = None group_size: int | None = None
spinquant: bool = False spinquant: bool = False
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -39,10 +38,10 @@ class ModelArgs:
dim: int = 4096 dim: int = 4096
n_layers: int = 32 n_layers: int = 32
n_heads: int = 32 n_heads: int = 32
n_kv_heads: Optional[int] = None n_kv_heads: int | None = None
vocab_size: int = -1 vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None ffn_dim_multiplier: float | None = None
norm_eps: float = 1e-5 norm_eps: float = 1e-5
rope_theta: float = 500000 rope_theta: float = 500000
use_scaled_rope: bool = False use_scaled_rope: bool = False
@ -55,8 +54,8 @@ class ModelArgs:
vision_max_num_chunks: int = 4 vision_max_num_chunks: int = 4
vision_num_cross_attention_layers: int = -1 vision_num_cross_attention_layers: int = -1
quantization_args: Optional[QuantizationArgs] = None quantization_args: QuantizationArgs | None = None
lora_args: Optional[LoRAArgs] = None lora_args: LoRAArgs | None = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():

View file

@ -8,7 +8,6 @@ import io
import json import json
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image from PIL import Image as PIL_Image
@ -29,14 +28,14 @@ from .tool_utils import ToolUtils
@dataclass @dataclass
class VisionInput: class VisionInput:
mask: List[List[int]] mask: list[list[int]]
images: List[PIL_Image.Image] images: list[PIL_Image.Image]
@dataclass @dataclass
class LLMInput: class LLMInput:
tokens: List[int] tokens: list[int]
vision: Optional[VisionInput] = None vision: VisionInput | None = None
def role_str(role: Role) -> str: def role_str(role: Role) -> str:
@ -50,7 +49,7 @@ def role_str(role: Role) -> str:
class ChatFormat: class ChatFormat:
possible_headers: Dict[Role, str] possible_headers: dict[Role, str]
def __init__(self, tokenizer: Tokenizer): def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -58,7 +57,7 @@ class ChatFormat:
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role} self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
self.vision_token = self.tokenizer.special_tokens["<|image|>"] self.vision_token = self.tokenizer.special_tokens["<|image|>"]
def _encode_header(self, role: str) -> List[int]: def _encode_header(self, role: str) -> list[int]:
tokens = [] tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False)) tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
@ -70,7 +69,7 @@ class ChatFormat:
tokens, images = self._encode_content(content, bos=True) tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images) return self._model_input_from_tokens_images(tokens, images)
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]: def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[PIL_Image.Image]]:
tokens = [] tokens = []
images = [] images = []
@ -107,7 +106,7 @@ class ChatFormat:
def encode_message( def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[PIL_Image.Image]]: ) -> tuple[list[int], list[PIL_Image.Image]]:
tokens = self._encode_header(message.role) tokens = self._encode_header(message.role)
images = [] images = []
@ -145,8 +144,8 @@ class ChatFormat:
def encode_dialog_prompt( def encode_dialog_prompt(
self, self,
messages: List[RawMessage], messages: list[RawMessage],
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: ToolPromptFormat | None = None,
) -> LLMInput: ) -> LLMInput:
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
tokens = [] tokens = []
@ -163,7 +162,7 @@ class ChatFormat:
return self._model_input_from_tokens_images(tokens, images) return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages) # TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage: def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens) content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason) return self.decode_assistant_message_from_content(content, stop_reason)
@ -234,7 +233,7 @@ class ChatFormat:
tool_calls=tool_calls, tool_calls=tool_calls,
) )
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput: def _model_input_from_tokens_images(self, tokens: list[int], images: list[PIL_Image.Image]) -> LLMInput:
vision_input = None vision_input = None
if len(images) > 0: if len(images) > 0:
vision_input = VisionInput( vision_input = VisionInput(
@ -249,9 +248,9 @@ class ChatFormat:
def create_vision_mask( def create_vision_mask(
tokens: List[int], tokens: list[int],
vision_token: int, vision_token: int,
) -> List[List[int]]: ) -> list[list[int]]:
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token] vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
if len(vision_token_locations) == 0: if len(vision_token_locations) == 0:
return [] return []

View file

@ -15,8 +15,8 @@ import json
import os import os
import sys import sys
import time import time
from collections.abc import Callable, Generator
from pathlib import Path from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -41,8 +41,8 @@ class Llama3:
ckpt_dir: str, ckpt_dir: str,
max_seq_len: int, max_seq_len: int,
max_batch_size: int, max_batch_size: int,
world_size: Optional[int] = None, world_size: int | None = None,
quantization_mode: Optional[QuantizationMode] = None, quantization_mode: QuantizationMode | None = None,
seed: int = 1, seed: int = 1,
device: str = "cuda", device: str = "cuda",
): ):
@ -82,7 +82,7 @@ class Llama3:
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth")) ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}" assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})") print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f: with open(Path(ckpt_dir) / "params.json") as f:
params = json.loads(f.read()) params = json.loads(f.read())
model_args: ModelArgs = ModelArgs( model_args: ModelArgs = ModelArgs(
@ -154,15 +154,15 @@ class Llama3:
@torch.inference_mode() @torch.inference_mode()
def generate( def generate(
self, self,
llm_inputs: List[LLMInput], llm_inputs: list[LLMInput],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: int | None = None,
logprobs: bool = False, logprobs: bool = False,
echo: bool = False, echo: bool = False,
print_model_input: bool = False, print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> Generator[List[GenerationResult], None, None]: ) -> Generator[list[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1 max_gen_len = self.args.max_seq_len - 1
params = self.model.params params = self.model.params
@ -302,13 +302,13 @@ class Llama3:
def completion( def completion(
self, self,
contents: List[RawContent], contents: list[RawContent],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: int | None = None,
logprobs: bool = False, logprobs: bool = False,
echo: bool = False, echo: bool = False,
) -> Generator[List[GenerationResult], None, None]: ) -> Generator[list[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents] model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate( for result in self.generate(
model_inputs=model_inputs, model_inputs=model_inputs,
@ -324,14 +324,14 @@ class Llama3:
def chat_completion( def chat_completion(
self, self,
messages_batch: List[List[RawMessage]], messages_batch: list[list[RawMessage]],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: int | None = None,
logprobs: bool = False, logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False, echo: bool = False,
) -> Generator[List[GenerationResult], None, None]: ) -> Generator[list[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch] model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate( for result in self.generate(
model_inputs=model_inputs, model_inputs=model_inputs,

View file

@ -12,7 +12,6 @@
# the top-level of this source tree. # the top-level of this source tree.
from pathlib import Path from pathlib import Path
from typing import List, Optional
from termcolor import colored from termcolor import colored
@ -131,7 +130,7 @@ class LLama31Interface:
self.formatter = ChatFormat(self.tokenizer) self.formatter = ChatFormat(self.tokenizer)
self.tool_prompt_format = tool_prompt_format self.tool_prompt_format = tool_prompt_format
def get_tokens(self, messages: List[RawMessage]) -> List[int]: def get_tokens(self, messages: list[RawMessage]) -> list[int]:
model_input = self.formatter.encode_dialog_prompt( model_input = self.formatter.encode_dialog_prompt(
messages, messages,
self.tool_prompt_format, self.tool_prompt_format,
@ -149,10 +148,10 @@ class LLama31Interface:
def system_messages( def system_messages(
self, self,
builtin_tools: List[BuiltinTool], builtin_tools: list[BuiltinTool],
custom_tools: List[ToolDefinition], custom_tools: list[ToolDefinition],
instruction: Optional[str] = None, instruction: str | None = None,
) -> List[RawMessage]: ) -> list[RawMessage]:
messages = [] messages = []
default_gen = SystemDefaultGenerator() default_gen = SystemDefaultGenerator()
@ -194,8 +193,8 @@ class LLama31Interface:
self, self,
content: str, content: str,
stop_reason: StopReason, stop_reason: StopReason,
tool_call: Optional[ToolCall] = None, tool_call: ToolCall | None = None,
) -> List[RawMessage]: ) -> list[RawMessage]:
tool_calls = [] tool_calls = []
if tool_call: if tool_call:
tool_calls.append(tool_call) tool_calls.append(tool_call)
@ -208,7 +207,7 @@ class LLama31Interface:
) )
] ]
def user_message(self, content: str) -> List[RawMessage]: def user_message(self, content: str) -> list[RawMessage]:
return [RawMessage(role="user", content=content)] return [RawMessage(role="user", content=content)]
def display_message_as_tokens(self, message: RawMessage) -> None: def display_message_as_tokens(self, message: RawMessage) -> None:
@ -228,7 +227,7 @@ class LLama31Interface:
print("\n", end="") print("\n", end="")
def list_jinja_templates() -> List[Template]: def list_jinja_templates() -> list[Template]:
return TEMPLATES return TEMPLATES

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import math import math
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init import fairscale.nn.model_parallel.initialize as fs_init
import torch import torch
@ -80,7 +79,7 @@ def apply_rotary_emb(
xq: torch.Tensor, xq: torch.Tensor,
xk: torch.Tensor, xk: torch.Tensor,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_) freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@ -162,7 +161,7 @@ class Attention(nn.Module):
x: torch.Tensor, x: torch.Tensor,
start_pos: int, start_pos: int,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor], mask: torch.Tensor | None,
): ):
bsz, seqlen, _ = x.shape bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@ -204,7 +203,7 @@ class FeedForward(nn.Module):
dim: int, dim: int,
hidden_dim: int, hidden_dim: int,
multiple_of: int, multiple_of: int,
ffn_dim_multiplier: Optional[float], ffn_dim_multiplier: float | None,
): ):
super().__init__() super().__init__()
hidden_dim = int(2 * hidden_dim / 3) hidden_dim = int(2 * hidden_dim / 3)
@ -243,7 +242,7 @@ class TransformerBlock(nn.Module):
x: torch.Tensor, x: torch.Tensor,
start_pos: int, start_pos: int,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor], mask: torch.Tensor | None,
): ):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h)) out = h + self.feed_forward(self.ffn_norm(h))

View file

@ -14,7 +14,7 @@
import math import math
from collections import defaultdict from collections import defaultdict
from logging import getLogger from logging import getLogger
from typing import Any, Optional, Set, Tuple from typing import Any
import torch import torch
import torchvision.transforms as tv import torchvision.transforms as tv
@ -26,7 +26,7 @@ IMAGE_RES = 224
logger = getLogger() logger = getLogger()
class VariableSizeImageTransform(object): class VariableSizeImageTransform:
""" """
This class accepts images of any size and dynamically resizes, pads and chunks it This class accepts images of any size and dynamically resizes, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow. based on the image aspect ratio and the number of image chunks we allow.
@ -75,7 +75,7 @@ class VariableSizeImageTransform(object):
self.resample = tv.InterpolationMode.BILINEAR self.resample = tv.InterpolationMode.BILINEAR
@staticmethod @staticmethod
def get_factors(n: int) -> Set[int]: def get_factors(n: int) -> set[int]:
""" """
Calculate all factors of a given number, i.e. a divisor that leaves Calculate all factors of a given number, i.e. a divisor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}. no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
@ -145,9 +145,9 @@ class VariableSizeImageTransform(object):
@staticmethod @staticmethod
def get_max_res_without_distortion( def get_max_res_without_distortion(
image_size: Tuple[int, int], image_size: tuple[int, int],
target_size: Tuple[int, int], target_size: tuple[int, int],
) -> Tuple[int, int]: ) -> tuple[int, int]:
""" """
Determines the maximum resolution to which an image can be resized without distorting its Determines the maximum resolution to which an image can be resized without distorting its
aspect ratio, based on the target resolution. aspect ratio, based on the target resolution.
@ -198,8 +198,8 @@ class VariableSizeImageTransform(object):
def resize_without_distortion( def resize_without_distortion(
self, self,
image: torch.Tensor, image: torch.Tensor,
target_size: Tuple[int, int], target_size: tuple[int, int],
max_upscaling_size: Optional[int], max_upscaling_size: int | None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Used to resize an image to target_resolution, without distortion. Used to resize an image to target_resolution, without distortion.
@ -261,10 +261,10 @@ class VariableSizeImageTransform(object):
def get_best_fit( def get_best_fit(
self, self,
image_size: Tuple[int, int], image_size: tuple[int, int],
possible_resolutions: torch.Tensor, possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False, resize_to_max_canvas: bool = False,
) -> Tuple[int, int]: ) -> tuple[int, int]:
""" """
Determines the best canvas possible from a list of possible resolutions to, without distortion, Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to. resize an image to.
@ -364,7 +364,7 @@ class VariableSizeImageTransform(object):
max_num_chunks: int, max_num_chunks: int,
normalize_img: bool = True, normalize_img: bool = True,
resize_to_max_canvas: bool = False, resize_to_max_canvas: bool = False,
) -> Tuple[Any, Any]: ) -> tuple[Any, Any]:
""" """
Args: Args:
image (PIL.Image): Image to be resized. image (PIL.Image): Image to be resized.
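Two more patterns from the image-transform hunk, shown on a hypothetical `ImageGrid` class (a sketch, not the project's implementation): the redundant `(object)` base class is dropped, and `Set`/`Tuple` annotations become the builtin `set`/`tuple` generics.

```python
# Before:
#     from typing import Set, Tuple
#     class VariableSizeImageTransform(object):
#         @staticmethod
#         def get_factors(n: int) -> Set[int]: ...
# After: implicit `object` base removed, builtin generic annotations used.
class ImageGrid:
    @staticmethod
    def get_factors(n: int) -> set[int]:
        # All divisors of n, e.g. get_factors(12) -> {1, 2, 3, 4, 6, 12}.
        return {d for d in range(1, n + 1) if n % d == 0}

    @staticmethod
    def best_fit(image_size: tuple[int, int], canvas: tuple[int, int]) -> tuple[int, int]:
        # Hypothetical helper: clamp a (width, height) pair to a canvas.
        return (min(image_size[0], canvas[0]), min(image_size[1], canvas[1]))


print(ImageGrid.get_factors(12))                    # {1, 2, 3, 4, 6, 12}
print(ImageGrid.best_fit((640, 480), (224, 224)))   # (224, 224)
```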

View file

@ -6,8 +6,9 @@
import logging import logging
import math import math
from collections.abc import Callable
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init import fairscale.nn.model_parallel.initialize as fs_init
import torch import torch
@ -104,9 +105,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
kernel_size: Union[int, Tuple[int, int]], kernel_size: int | tuple[int, int],
stride: Union[int, Tuple[int, int]], stride: int | tuple[int, int],
bias: Optional[bool] = False, bias: bool | None = False,
) -> None: ) -> None:
super().__init__() super().__init__()
if isinstance(kernel_size, int): if isinstance(kernel_size, int):
@ -390,13 +391,13 @@ class VisionEncoder(nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool = True, strict: bool = True,
missing_keys: List[str] = None, missing_keys: list[str] = None,
unexpected_keys: List[str] = None, unexpected_keys: list[str] = None,
error_msgs: List[str] = None, error_msgs: list[str] = None,
return_state_dict: bool = False, return_state_dict: bool = False,
) -> None: ) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding") orig_pos_embed = state_dict.get(prefix + "positional_embedding")
@ -641,7 +642,7 @@ class FeedForward(nn.Module):
dim: int, dim: int,
hidden_dim: int, hidden_dim: int,
multiple_of: int, multiple_of: int,
ffn_dim_multiplier: Optional[float], ffn_dim_multiplier: float | None,
): ):
""" """
Initialize the FeedForward module. Initialize the FeedForward module.
@ -983,7 +984,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
xattn_mask: torch.Tensor, xattn_mask: torch.Tensor,
full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
xattn_cache: torch.Tensor, xattn_cache: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
_attn_out = self.attention( _attn_out = self.attention(
@ -1144,7 +1145,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
def _init_fusion_schedule( def _init_fusion_schedule(
self, self,
num_layers: int, num_layers: int,
) -> List[int]: ) -> list[int]:
llama_layers = list(range(self.n_llama_layers)) llama_layers = list(range(self.n_llama_layers))
# uniformly spread the layers # uniformly spread the layers
@ -1231,7 +1232,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_dtype, text_dtype,
vision_tokens, vision_tokens,
cross_attention_masks, cross_attention_masks,
) -> Tuple[Tensor, Tensor]: ) -> tuple[Tensor, Tensor]:
assert vision_tokens is not None, "Vision tokens must be provided" assert vision_tokens is not None, "Vision tokens must be provided"
vision_seqlen = vision_tokens.shape[3] vision_seqlen = vision_tokens.shape[3]
assert vision_tokens.shape[1] == cross_attention_masks.shape[2], ( assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (
@ -1280,11 +1281,11 @@ class CrossAttentionTransformer(torch.nn.Module):
def compute_vision_tokens_masks( def compute_vision_tokens_masks(
self, self,
batch_images: List[List[PIL_Image.Image]], batch_images: list[list[PIL_Image.Image]],
batch_masks: List[List[List[int]]], batch_masks: list[list[list[int]]],
total_len: int, total_len: int,
device: torch.device, device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False skip_vision_encoder = False
assert len(batch_images) == len(batch_masks), "Images and masks must have the same length" assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"
@ -1371,11 +1372,11 @@ class CrossAttentionTransformer(torch.nn.Module):
def _stack_images( def _stack_images(
images: List[List[PIL_Image.Image]], images: list[list[PIL_Image.Image]],
max_num_chunks: int, max_num_chunks: int,
image_res: int, image_res: int,
max_num_images: int, max_num_images: int,
) -> Tuple[torch.Tensor, List[int]]: ) -> tuple[torch.Tensor, list[int]]:
""" """
Takes a list of list of images and stacks them into a tensor. Takes a list of list of images and stacks them into a tensor.
This function is needed since images can be of completely This function is needed since images can be of completely
@ -1400,8 +1401,8 @@ def _stack_images(
def _pad_masks( def _pad_masks(
all_masks: List[List[List[int]]], all_masks: list[list[list[int]]],
all_num_chunks: List[List[int]], all_num_chunks: list[list[int]],
total_len: int, total_len: int,
max_num_chunks: int, max_num_chunks: int,
) -> torch.Tensor: ) -> torch.Tensor:
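The vision-encoder hunk also moves `Callable` from `typing` to `collections.abc` and rewrites `Union[int, Tuple[int, int]]` as `int | tuple[int, int]`. A small sketch under hypothetical names (`normalize_kernel` is not in the diff), loosely mirroring the int-or-pair handling of the patch-embedding layer:

```python
from collections.abc import Callable  # preferred over the deprecated typing.Callable


# Before:
#     from typing import Callable, Optional, Tuple, Union
#     def normalize_kernel(kernel_size: Union[int, Tuple[int, int]],
#                          hook: Optional[Callable[[int, int], None]] = None) -> Tuple[int, int]: ...
# After:
def normalize_kernel(
    kernel_size: int | tuple[int, int],
    hook: Callable[[int, int], None] | None = None,
) -> tuple[int, int]:
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if hook is not None:
        hook(*kernel_size)
    return kernel_size


print(normalize_kernel(3))       # (3, 3)
print(normalize_kernel((3, 5)))  # (3, 5)
```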

View file

@ -12,7 +12,7 @@
# the top-level of this source tree. # the top-level of this source tree.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List from typing import Any
from jinja2 import Template from jinja2 import Template
@ -20,7 +20,7 @@ from jinja2 import Template
@dataclass @dataclass
class PromptTemplate: class PromptTemplate:
template: str template: str
data: Dict[str, Any] data: dict[str, Any]
def render(self): def render(self):
template = Template(self.template) template = Template(self.template)
@ -35,5 +35,5 @@ class PromptTemplateGeneratorBase:
def gen(self, *args, **kwargs) -> PromptTemplate: def gen(self, *args, **kwargs) -> PromptTemplate:
raise NotImplementedError() raise NotImplementedError()
def data_examples(self) -> List[Any]: def data_examples(self) -> list[Any]:
raise NotImplementedError() raise NotImplementedError()

View file

@ -13,7 +13,7 @@
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from typing import Any, List, Optional from typing import Any
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BuiltinTool, BuiltinTool,
@ -39,12 +39,12 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
}, },
) )
def data_examples(self) -> List[Any]: def data_examples(self) -> list[Any]:
return [None] return [None]
class BuiltinToolGenerator(PromptTemplateGeneratorBase): class BuiltinToolGenerator(PromptTemplateGeneratorBase):
def _tool_breakdown(self, tools: List[ToolDefinition]): def _tool_breakdown(self, tools: list[ToolDefinition]):
builtin_tools, custom_tools = [], [] builtin_tools, custom_tools = [], []
for dfn in tools: for dfn in tools:
if isinstance(dfn.tool_name, BuiltinTool): if isinstance(dfn.tool_name, BuiltinTool):
@ -54,7 +54,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
return builtin_tools, custom_tools return builtin_tools, custom_tools
def gen(self, tools: List[ToolDefinition]) -> PromptTemplate: def gen(self, tools: list[ToolDefinition]) -> PromptTemplate:
builtin_tools, custom_tools = self._tool_breakdown(tools) builtin_tools, custom_tools = self._tool_breakdown(tools)
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
@ -75,7 +75,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
}, },
) )
def data_examples(self) -> List[List[ToolDefinition]]: def data_examples(self) -> list[list[ToolDefinition]]:
return [ return [
# builtin tools # builtin tools
[ [
@ -91,7 +91,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
class JsonCustomToolGenerator(PromptTemplateGeneratorBase): class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
Answer the user's question by making use of the following functions if needed. Answer the user's question by making use of the following functions if needed.
@ -137,7 +137,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
{"custom_tools": [t.model_dump() for t in custom_tools]}, {"custom_tools": [t.model_dump() for t in custom_tools]},
) )
def data_examples(self) -> List[List[ToolDefinition]]: def data_examples(self) -> list[list[ToolDefinition]]:
return [ return [
[ [
ToolDefinition( ToolDefinition(
@ -161,7 +161,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase): class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
You have access to the following functions: You have access to the following functions:
@ -199,7 +199,7 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
{"custom_tools": [t.model_dump() for t in custom_tools]}, {"custom_tools": [t.model_dump() for t in custom_tools]},
) )
def data_examples(self) -> List[List[ToolDefinition]]: def data_examples(self) -> list[list[ToolDefinition]]:
return [ return [
[ [
ToolDefinition( ToolDefinition(
@ -238,14 +238,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
""".strip("\n") """.strip("\n")
) )
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate( return PromptTemplate(
system_prompt, system_prompt,
{"function_description": self._gen_function_description(custom_tools)}, {"function_description": self._gen_function_description(custom_tools)},
) )
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
Here is a list of functions in JSON format that you can invoke. Here is a list of functions in JSON format that you can invoke.
@ -291,7 +291,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"tools": [t.model_dump() for t in custom_tools]}, {"tools": [t.model_dump() for t in custom_tools]},
).render() ).render()
def data_examples(self) -> List[List[ToolDefinition]]: def data_examples(self) -> list[list[ToolDefinition]]:
return [ return [
[ [
ToolDefinition( ToolDefinition(

View file

@ -12,7 +12,6 @@
# the top-level of this source tree. # the top-level of this source tree.
import textwrap import textwrap
from typing import Optional
from .base import PromptTemplate, PromptTemplateGeneratorBase from .base import PromptTemplate, PromptTemplateGeneratorBase
@ -21,8 +20,8 @@ class ToolResponseGenerator(PromptTemplateGeneratorBase):
def gen( def gen(
self, self,
status: str, status: str,
stdout: Optional[str] = None, stdout: str | None = None,
stderr: Optional[str] = None, stderr: str | None = None,
): ):
assert status in [ assert status in [
"success", "success",

View file

@ -6,7 +6,7 @@
# type: ignore # type: ignore
import os import os
from typing import Any, Dict, List, Optional, cast from typing import Any, cast
import torch import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@ -37,9 +37,9 @@ def swiglu_wrapper(
def convert_to_quantized_model( def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer, model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str, checkpoint_dir: str,
quantization_mode: Optional[str] = None, quantization_mode: str | None = None,
fp8_activation_scale_ub: Optional[float] = 1200.0, fp8_activation_scale_ub: float | None = 1200.0,
device: Optional[torch.device] = None, device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer: ) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed: if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device) return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
@ -52,8 +52,8 @@ def convert_to_quantized_model(
def convert_to_fp8_quantized_model( def convert_to_fp8_quantized_model(
model: Transformer, model: Transformer,
checkpoint_dir: str, checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0, fp8_activation_scale_ub: float | None = 1200.0,
device: Optional[torch.device] = None, device: torch.device | None = None,
) -> Transformer: ) -> Transformer:
# Move weights to GPU with quantization # Move weights to GPU with quantization
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt") fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
@ -122,8 +122,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision: torch.dtype = torch.float32, precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32,
# LoRA parameters # LoRA parameters
lora_rank: Optional[int] = None, lora_rank: int | None = None,
lora_scale: Optional[float] = None, lora_scale: float | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
in_features, in_features,
@ -134,8 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision=precision, precision=precision,
scales_precision=scales_precision, scales_precision=scales_precision,
) )
self.lora_scale: Optional[float] = None self.lora_scale: float | None = None
self.adaptor: Optional[nn.Sequential] = None self.adaptor: nn.Sequential | None = None
if lora_rank is not None: if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA." assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685 # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@ -147,13 +147,13 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
"""A hook to load the quantized weights from the state dict.""" """A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict: if prefix + "zeros" not in state_dict:
@ -191,13 +191,13 @@ class Int8WeightEmbedding(torch.nn.Embedding):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict.""" """A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight") weights = state_dict.pop(prefix + "weight")
@ -221,13 +221,13 @@ class Int8WeightLinear(torch.nn.Linear):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
"""A hook to load the quantized linear weight and scales from the state dict.""" """A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight") weights = state_dict.pop(prefix + "weight")
@ -238,8 +238,8 @@ class Int8WeightLinear(torch.nn.Linear):
def _prepare_model_int4_weight_int8_dynamic_activation( def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module, model: torch.nn.Module,
group_size: int, group_size: int,
lora_rank: Optional[int], lora_rank: int | None,
lora_scale: Optional[float], lora_scale: float | None,
): ):
"""Prepare the model for int4 weight and int8 dynamic activation quantization. """Prepare the model for int4 weight and int8 dynamic activation quantization.
@ -265,7 +265,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
) )
del module del module
setattr(model, module_name, quantized_module) setattr(model, module_name, quantized_module)
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)): elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
quantized_module = Int8DynActInt4WeightLinearLoRA( quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features, in_features=module.in_features,
out_features=module.out_features, out_features=module.out_features,
@ -286,7 +286,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model( def convert_to_int4_quantized_model(
model: Transformer | CrossAttentionTransformer, model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str, checkpoint_dir: str,
device: Optional[torch.device] = None, device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer: ) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model.""" """Convert the model to int4 quantized model."""
model_args = model.params model_args = model.params
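Note the `isinstance` change in this hunk: on Python 3.10+, `isinstance` accepts a PEP 604 union directly, so the tuple of classes is no longer needed. A runnable sketch of that pattern with generic types (the `describe` helper is illustrative only):

```python
# Before: isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear))
# After:  isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear)
def describe(value: object) -> str:
    if isinstance(value, int | float | bool) or value is None:
        return f"scalar: {value!r}"
    if isinstance(value, list | tuple):
        return f"sequence of {len(value)} items"
    return type(value).__name__


print(describe(3.5))        # scalar: 3.5
print(describe([1, 2, 3]))  # sequence of 3 items
print(describe("text"))     # str
```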

View file

@ -5,18 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import os import os
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal, Literal,
Optional,
Sequence,
Union,
cast, cast,
) )
@ -44,7 +37,7 @@ class Tokenizer:
Tokenizing and encoding/decoding text using the Tiktoken tokenizer. Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
""" """
special_tokens: Dict[str, int] special_tokens: dict[str, int]
num_reserved_special_tokens = 256 num_reserved_special_tokens = 256
@ -116,9 +109,9 @@ class Tokenizer:
*, *,
bos: bool, bos: bool,
eos: bool, eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, allowed_special: Literal["all"] | Set[str] | None = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (), disallowed_special: Literal["all"] | Collection[str] = (),
) -> List[int]: ) -> list[int]:
""" """
Encodes a string into a list of token IDs. Encodes a string into a list of token IDs.
@ -151,7 +144,7 @@ class Tokenizer:
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
) )
) )
t: List[int] = [] t: list[int] = []
for substr in substrs: for substr in substrs:
t.extend( t.extend(
self.model.encode( self.model.encode(
@ -177,7 +170,7 @@ class Tokenizer:
str: The decoded string. str: The decoded string.
""" """
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t)) return self.model.decode(cast(list[int], t))
@staticmethod @staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]: def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:

View file

@ -6,7 +6,6 @@
import json import json
import re import re
from typing import Optional, Tuple
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -172,7 +171,7 @@ class ToolUtils:
return match is not None return match is not None
@staticmethod @staticmethod
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]: def maybe_extract_builtin_tool_call(message_body: str) -> tuple[str, str] | None:
# Find the first match in the text # Find the first match in the text
match = re.search(BUILTIN_TOOL_PATTERN, message_body) match = re.search(BUILTIN_TOOL_PATTERN, message_body)
@ -185,7 +184,7 @@ class ToolUtils:
return None return None
@staticmethod @staticmethod
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]: def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
# NOTE: Custom function tool calls are still experimental # NOTE: Custom function tool calls are still experimental
# Sometimes, response is of the form # Sometimes, response is of the form
# {"type": "function", "name": "function_name", "parameters": {...} # {"type": "function", "name": "function_name", "parameters": {...}
@ -252,7 +251,7 @@ class ToolUtils:
def format_value(value: RecursiveType) -> str: def format_value(value: RecursiveType) -> str:
if isinstance(value, str): if isinstance(value, str):
return f'"{value}"' return f'"{value}"'
elif isinstance(value, (int, float, bool)) or value is None: elif isinstance(value, int | float | bool) or value is None:
return str(value) return str(value)
elif isinstance(value, list): elif isinstance(value, list):
return f"[{', '.join(format_value(v) for v in value)}]" return f"[{', '.join(format_value(v) for v in value)}]"

View file

@ -12,7 +12,6 @@
# the top-level of this source tree. # the top-level of this source tree.
import textwrap import textwrap
from typing import List
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
@ -73,7 +72,7 @@ def wolfram_alpha_response():
) )
def usecases() -> List[UseCase | str]: def usecases() -> list[UseCase | str]:
return [ return [
textwrap.dedent( textwrap.dedent(
""" """

View file

@ -12,7 +12,6 @@
# the top-level of this source tree. # the top-level of this source tree.
import textwrap import textwrap
from typing import List
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
@ -74,7 +73,7 @@ def wolfram_alpha_response():
) )
def usecases() -> List[UseCase | str]: def usecases() -> list[UseCase | str]:
return [ return [
textwrap.dedent( textwrap.dedent(
""" """

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -15,8 +14,8 @@ class QuantizationScheme(Enum):
class QuantizationArgs(BaseModel): class QuantizationArgs(BaseModel):
scheme: Optional[QuantizationScheme] = None scheme: QuantizationScheme | None = None
group_size: Optional[int] = None group_size: int | None = None
spinquant: bool = False spinquant: bool = False
@ -58,32 +57,32 @@ class ModelArgs(BaseModel):
dim: int = -1 dim: int = -1
n_layers: int = -1 n_layers: int = -1
n_heads: int = -1 n_heads: int = -1
n_kv_heads: Optional[int] = None n_kv_heads: int | None = None
head_dim: Optional[int] = None head_dim: int | None = None
vocab_size: int = -1 vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None ffn_dim_multiplier: float | None = None
ffn_exp: Optional[float] = None ffn_exp: float | None = None
norm_eps: float = 1e-5 norm_eps: float = 1e-5
attention_chunk_size: Optional[int] = None attention_chunk_size: int | None = None
rope_theta: float = 500000 rope_theta: float = 500000
use_scaled_rope: bool = False use_scaled_rope: bool = False
rope_scaling_factor: Optional[float] = None rope_scaling_factor: float | None = None
rope_high_freq_factor: Optional[float] = None rope_high_freq_factor: float | None = None
nope_layer_interval: Optional[int] = None # No position encoding in every n layers nope_layer_interval: int | None = None # No position encoding in every n layers
use_qk_norm: bool = False use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context) # Set to True to enable inference-time temperature tuning (useful for very long context)
attn_temperature_tuning: bool = False attn_temperature_tuning: bool = False
floor_scale: float = 8192.0 floor_scale: float = 8192.0
attn_scale: float = 0.1 attn_scale: float = 0.1
vision_args: Optional[VisionArgs] = None vision_args: VisionArgs | None = None
moe_args: Optional[MoEArgs] = None moe_args: MoEArgs | None = None
quantization_args: Optional[QuantizationArgs] = None quantization_args: QuantizationArgs | None = None
lora_args: Optional[LoRAArgs] = None lora_args: LoRAArgs | None = None
max_batch_size: int = 32 max_batch_size: int = 32
max_seq_len: int = 2048 max_seq_len: int = 2048

View file

@ -8,7 +8,6 @@ import io
import json import json
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch import torch
from PIL import Image as PIL_Image from PIL import Image as PIL_Image
@ -46,10 +45,10 @@ def role_str(role: Role) -> str:
class TransformedImage: class TransformedImage:
image_tiles: torch.Tensor image_tiles: torch.Tensor
# is the aspect ratio needed anywhere? # is the aspect ratio needed anywhere?
aspect_ratio: Tuple[int, int] aspect_ratio: tuple[int, int]
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image: def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA": if image.mode == "RGBA":
image.load() # for png.split() image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg) new_img = PIL_Image.new("RGB", image.size, bg)
@ -59,12 +58,12 @@ def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255
class ChatFormat: class ChatFormat:
possible_headers: Dict[Role, str] possible_headers: dict[Role, str]
def __init__( def __init__(
self, self,
tokenizer: Tokenizer, tokenizer: Tokenizer,
vision_args: Optional[VisionArgs] = None, vision_args: VisionArgs | None = None,
max_num_chunks: int = 16, max_num_chunks: int = 16,
): ):
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -81,7 +80,7 @@ class ChatFormat:
vision_args.image_size.width, vision_args.image_size.height vision_args.image_size.width, vision_args.image_size.height
) )
def _encode_header(self, role: str) -> List[int]: def _encode_header(self, role: str) -> list[int]:
tokens = [] tokens = []
tokens.append(self.tokenizer.special_tokens["<|header_start|>"]) tokens.append(self.tokenizer.special_tokens["<|header_start|>"])
@ -98,7 +97,7 @@ class ChatFormat:
def _encode_image( def _encode_image(
self, self,
transformed_image: TransformedImage, transformed_image: TransformedImage,
) -> List[int]: ) -> list[int]:
assert self.vision_args is not None, "The model is not vision-enabled" assert self.vision_args is not None, "The model is not vision-enabled"
image_tensor = transformed_image.image_tiles image_tensor = transformed_image.image_tiles
@ -140,7 +139,7 @@ class ChatFormat:
return tokens return tokens
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]: def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[TransformedImage]]:
tokens = [] tokens = []
tranformed_images = [] tranformed_images = []
@ -189,7 +188,7 @@ class ChatFormat:
def encode_message( def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[TransformedImage]]: ) -> tuple[list[int], list[TransformedImage]]:
tokens = self._encode_header(message.role) tokens = self._encode_header(message.role)
images = [] images = []
@ -223,7 +222,7 @@ class ChatFormat:
def encode_dialog_prompt( def encode_dialog_prompt(
self, self,
messages: List[RawMessage], messages: list[RawMessage],
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> LLMInput: ) -> LLMInput:
tokens = [] tokens = []
@ -240,7 +239,7 @@ class ChatFormat:
return self._model_input_from_tokens_images(tokens, images) return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages) # TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage: def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens) content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason) return self.decode_assistant_message_from_content(content, stop_reason)
@ -312,7 +311,7 @@ class ChatFormat:
tool_calls=tool_calls, tool_calls=tool_calls,
) )
def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput: def _model_input_from_tokens_images(self, tokens: list[int], images: list[TransformedImage]) -> LLMInput:
return LLMInput( return LLMInput(
tokens=tokens, tokens=tokens,
images=[x.image_tiles for x in images] if len(images) > 0 else None, images=[x.image_tiles for x in images] if len(images) > 0 else None,

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union
import torch import torch
@ -30,7 +29,7 @@ class LLMInput:
tokens: torch.Tensor tokens: torch.Tensor
# images are already pre-processed (resized, tiled, etc.) # images are already pre-processed (resized, tiled, etc.)
images: Optional[List[torch.Tensor]] = None images: list[torch.Tensor] | None = None
@dataclass @dataclass
@ -45,8 +44,8 @@ class TransformerInput:
# tokens_position defines the position of the tokens in each batch, # tokens_position defines the position of the tokens in each batch,
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
# - when it is an int, the start position are the same for all batches # - when it is an int, the start position are the same for all batches
tokens_position: Union[torch.Tensor, int] tokens_position: torch.Tensor | int
image_embedding: Optional[MaskedEmbedding] = None image_embedding: MaskedEmbedding | None = None
@dataclass @dataclass

View file

@ -11,7 +11,7 @@
# top-level folder for each specific model found within the models/ directory at # top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree. # the top-level of this source tree.
from typing import Any, Dict, List from typing import Any
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
@ -36,13 +36,13 @@ class FeedForward(nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
if prefix + "mlp.fc1_weight" in state_dict: if prefix + "mlp.fc1_weight" in state_dict:
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0) w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)

View file

@ -10,8 +10,8 @@ import json
import os import os
import sys import sys
import time import time
from collections.abc import Callable, Generator
from pathlib import Path from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -38,8 +38,8 @@ class Llama4:
ckpt_dir: str, ckpt_dir: str,
max_seq_len: int, max_seq_len: int,
max_batch_size: int, max_batch_size: int,
world_size: Optional[int] = None, world_size: int | None = None,
quantization_mode: Optional[QuantizationMode] = None, quantization_mode: QuantizationMode | None = None,
seed: int = 1, seed: int = 1,
): ):
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
@ -63,7 +63,7 @@ class Llama4:
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth")) ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}" assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})") print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f: with open(Path(ckpt_dir) / "params.json") as f:
params = json.loads(f.read()) params = json.loads(f.read())
model_args: ModelArgs = ModelArgs( model_args: ModelArgs = ModelArgs(
@ -117,15 +117,15 @@ class Llama4:
@torch.inference_mode() @torch.inference_mode()
def generate( def generate(
self, self,
llm_inputs: List[LLMInput], llm_inputs: list[LLMInput],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: int | None = None,
logprobs: bool = False, logprobs: bool = False,
echo: bool = False, echo: bool = False,
print_model_input: bool = False, print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> Generator[List[GenerationResult], None, None]: ) -> Generator[list[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1 max_gen_len = self.model.args.max_seq_len - 1
@ -245,13 +245,13 @@ class Llama4:
def completion( def completion(
self, self,
contents: List[RawContent], contents: list[RawContent],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: int | None = None,
logprobs: bool = False, logprobs: bool = False,
echo: bool = False, echo: bool = False,
) -> Generator[List[GenerationResult], None, None]: ) -> Generator[list[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_content(c) for c in contents] llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate( for result in self.generate(
llm_inputs=llm_inputs, llm_inputs=llm_inputs,
@ -267,13 +267,13 @@ class Llama4:
def chat_completion( def chat_completion(
self, self,
messages_batch: List[List[RawMessage]], messages_batch: list[list[RawMessage]],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: int | None = None,
logprobs: bool = False, logprobs: bool = False,
echo: bool = False, echo: bool = False,
) -> Generator[List[GenerationResult], None, None]: ) -> Generator[list[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch] llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate( for result in self.generate(
llm_inputs=llm_inputs, llm_inputs=llm_inputs,
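Besides the annotation changes, this hunk drops the redundant `"r"` mode from `open()` and imports `Callable`/`Generator` from `collections.abc`. A self-contained sketch of both (paths and helper names here are hypothetical, not the checkpoint layout used by the project):

```python
import json
from collections.abc import Generator
from pathlib import Path
from tempfile import TemporaryDirectory


# Before:
#     from typing import Generator, List, Optional
#     with open(Path(ckpt_dir) / "params.json", "r") as f: ...
# After: "r" is the default mode and Generator comes from collections.abc.
def read_params(ckpt_dir: str) -> dict:
    with open(Path(ckpt_dir) / "params.json") as f:
        return json.load(f)


def count_up(limit: int | None = None) -> Generator[int, None, None]:
    i = 0
    while limit is None or i < limit:
        yield i
        i += 1


with TemporaryDirectory() as d:
    (Path(d) / "params.json").write_text('{"dim": 4096}')
    print(read_params(d))   # {'dim': 4096}
print(list(count_up(3)))    # [0, 1, 2]
```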

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import math import math
from typing import Any, Dict, List, Optional, Tuple from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init import fairscale.nn.model_parallel.initialize as fs_init
import torch import torch
@ -89,7 +89,7 @@ def apply_rotary_emb(
xq: torch.Tensor, xq: torch.Tensor,
xk: torch.Tensor, xk: torch.Tensor,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_) freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@ -174,13 +174,13 @@ class Attention(nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
if prefix + "wqkv.weight" in state_dict: if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight") wqkv = state_dict.pop(prefix + "wqkv.weight")
@ -200,7 +200,7 @@ class Attention(nn.Module):
x: torch.Tensor, x: torch.Tensor,
start_pos: int, start_pos: int,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: torch.Tensor | None = None,
): ):
bsz, seqlen, _ = x.shape bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@ -288,13 +288,13 @@ class TransformerBlock(nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
if prefix + "attention.wqkv.layer_norm_weight" in state_dict: if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight") state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
@ -318,8 +318,8 @@ class TransformerBlock(nn.Module):
x: torch.Tensor, x: torch.Tensor,
start_pos: int, start_pos: int,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
global_attn_mask: Optional[torch.Tensor], global_attn_mask: torch.Tensor | None,
local_attn_mask: Optional[torch.Tensor], local_attn_mask: torch.Tensor | None,
): ):
# The iRoPE architecture uses global attention mask for NoPE layers or # The iRoPE architecture uses global attention mask for NoPE layers or
# if chunked local attention is not used # if chunked local attention is not used
@ -374,13 +374,13 @@ class Transformer(nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
if prefix + "rope.freqs" in state_dict: if prefix + "rope.freqs" in state_dict:
state_dict.pop(prefix + "rope.freqs") state_dict.pop(prefix + "rope.freqs")

View file

@ -6,7 +6,7 @@
# ruff: noqa: N806 # ruff: noqa: N806
# pyre-strict # pyre-strict
from typing import Any, Dict, List from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init import fairscale.nn.model_parallel.initialize as fs_init
import torch import torch
@ -63,13 +63,13 @@ class Experts(nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
self.prefix = prefix self.prefix = prefix
if prefix + "moe_w_in_eD_F" in state_dict: if prefix + "moe_w_in_eD_F" in state_dict:
@ -158,13 +158,13 @@ class MoE(torch.nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool, strict: bool,
missing_keys: List[str], missing_keys: list[str],
unexpected_keys: List[str], unexpected_keys: list[str],
error_msgs: List[str], error_msgs: list[str],
) -> None: ) -> None:
if prefix + "w_in_shared_FD.weight" in state_dict: if prefix + "w_in_shared_FD.weight" in state_dict:
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight") state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
@ -210,5 +210,5 @@ class MoE(torch.nn.Module):
def divide_exact(numerator: int, denominator: int) -> int: def divide_exact(numerator: int, denominator: int) -> int:
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
return numerator // denominator return numerator // denominator
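The `.format()`-to-f-string rewrite above, spelled out as a runnable snippet (this is the same `divide_exact` helper shown in the hunk):

```python
# Before: assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
# After: an equivalent f-string, which pyupgrade prefers for simple substitutions.
def divide_exact(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator


print(divide_exact(12, 4))  # 3
```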

View file

@ -13,7 +13,6 @@
import math import math
from collections import defaultdict from collections import defaultdict
from typing import Optional, Set, Tuple
import torch import torch
import torchvision.transforms as tv import torchvision.transforms as tv
@ -52,7 +51,7 @@ class ResizeNormalizeImageTransform:
return self.tv_transform(image) return self.tv_transform(image)
class VariableSizeImageTransform(object): class VariableSizeImageTransform:
""" """
This class accepts images of any size and dynamically resizes, pads and chunks it This class accepts images of any size and dynamically resizes, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow. based on the image aspect ratio and the number of image chunks we allow.
@ -100,7 +99,7 @@ class VariableSizeImageTransform(object):
self.resample = tv.InterpolationMode.BILINEAR self.resample = tv.InterpolationMode.BILINEAR
@staticmethod @staticmethod
def get_factors(n: int) -> Set[int]: def get_factors(n: int) -> set[int]:
""" """
Calculate all factors of a given number, i.e. a divisor that leaves Calculate all factors of a given number, i.e. a divisor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}. no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
@ -170,9 +169,9 @@ class VariableSizeImageTransform(object):
@staticmethod @staticmethod
def get_max_res_without_distortion( def get_max_res_without_distortion(
image_size: Tuple[int, int], image_size: tuple[int, int],
target_size: Tuple[int, int], target_size: tuple[int, int],
) -> Tuple[int, int]: ) -> tuple[int, int]:
""" """
Determines the maximum resolution to which an image can be resized to without distorting its Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution. aspect ratio, based on the target resolution.
@ -223,8 +222,8 @@ class VariableSizeImageTransform(object):
def resize_without_distortion( def resize_without_distortion(
self, self,
image: torch.Tensor, image: torch.Tensor,
target_size: Tuple[int, int], target_size: tuple[int, int],
max_upscaling_size: Optional[int], max_upscaling_size: int | None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Used to resize an image to target_resolution, without distortion. Used to resize an image to target_resolution, without distortion.
@ -289,10 +288,10 @@ class VariableSizeImageTransform(object):
def get_best_fit( def get_best_fit(
self, self,
image_size: Tuple[int, int], image_size: tuple[int, int],
possible_resolutions: torch.Tensor, possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False, resize_to_max_canvas: bool = False,
) -> Tuple[int, int]: ) -> tuple[int, int]:
""" """
Determines the best canvas possible from a list of possible resolutions to, without distortion, Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to. resize an image to.
@ -392,7 +391,7 @@ class VariableSizeImageTransform(object):
max_num_chunks: int, max_num_chunks: int,
normalize_img: bool = True, normalize_img: bool = True,
resize_to_max_canvas: bool = False, resize_to_max_canvas: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]: ) -> tuple[torch.Tensor, tuple[int, int]]:
""" """
Args: Args:
image (PIL.Image): Image to be resized. image (PIL.Image): Image to be resized.

View file

@ -12,7 +12,6 @@
# the top-level of this source tree. # the top-level of this source tree.
import textwrap import textwrap
from typing import List, Optional
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import ( from llama_stack.models.llama.llama3.prompt_templates.base import (
@ -67,14 +66,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
""".strip("\n") """.strip("\n")
) )
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate( return PromptTemplate(
system_prompt, system_prompt,
{"function_description": self._gen_function_description(custom_tools)}, {"function_description": self._gen_function_description(custom_tools)},
) )
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
Here is a list of functions in JSON format that you can invoke. Here is a list of functions in JSON format that you can invoke.
@ -120,7 +119,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"tools": [t.model_dump() for t in custom_tools]}, {"tools": [t.model_dump() for t in custom_tools]},
).render() ).render()
def data_examples(self) -> List[List[ToolDefinition]]: def data_examples(self) -> list[list[ToolDefinition]]:
return [ return [
[ [
ToolDefinition( ToolDefinition(

View file

@ -7,7 +7,6 @@
import textwrap import textwrap
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import List
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator, PythonListCustomToolGenerator,
@ -23,7 +22,7 @@ from ..prompt_format import (
THIS_DIR = Path(__file__).parent THIS_DIR = Path(__file__).parent
def usecases(base_model: bool = False) -> List[UseCase | str]: def usecases(base_model: bool = False) -> list[UseCase | str]:
with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f: with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
img_small_dog = f.read() img_small_dog = f.read()
with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f: with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:

View file

@ -6,7 +6,7 @@
import logging import logging
import os import os
from typing import Callable, Optional from collections.abc import Callable
import torch import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@ -45,8 +45,8 @@ def experts_batched_swiglu_wrapper(
def convert_to_quantized_model( def convert_to_quantized_model(
model: Transformer, model: Transformer,
checkpoint_dir: str, checkpoint_dir: str,
quantization_mode: Optional[str] = None, quantization_mode: str | None = None,
fp8_activation_scale_ub: Optional[float] = 1200.0, fp8_activation_scale_ub: float | None = 1200.0,
use_rich_progress: bool = True, use_rich_progress: bool = True,
) -> Transformer: ) -> Transformer:
from ...quantize_impls import ( from ...quantize_impls import (
@ -213,7 +213,7 @@ def logging_callbacks(
) )
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting") task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
def update_status(message: Optional[str], completed: Optional[int] = None) -> None: def update_status(message: str | None, completed: int | None = None) -> None:
if use_rich_progress: if use_rich_progress:
if message is not None: if message is not None:
progress.update(task_id, status=message) progress.update(task_id, status=message)

View file

@ -5,18 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import os import os
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal, Literal,
Optional,
Sequence,
Union,
cast, cast,
) )
@ -114,7 +107,7 @@ class Tokenizer:
Tokenizing and encoding/decoding text using the Tiktoken tokenizer. Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
""" """
special_tokens: Dict[str, int] special_tokens: dict[str, int]
num_reserved_special_tokens = 2048 num_reserved_special_tokens = 2048
@ -182,9 +175,9 @@ class Tokenizer:
*, *,
bos: bool, bos: bool,
eos: bool, eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, allowed_special: Literal["all"] | Set[str] | None = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (), disallowed_special: Literal["all"] | Collection[str] = (),
) -> List[int]: ) -> list[int]:
""" """
Encodes a string into a list of token IDs. Encodes a string into a list of token IDs.
@ -217,7 +210,7 @@ class Tokenizer:
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
) )
) )
t: List[int] = [] t: list[int] = []
for substr in substrs: for substr in substrs:
t.extend( t.extend(
self.model.encode( self.model.encode(
@ -243,7 +236,7 @@ class Tokenizer:
str: The decoded string. str: The decoded string.
""" """
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t)) return self.model.decode(cast(list[int], t))
@staticmethod @staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]: def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:


@ -5,7 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import math import math
from typing import Any, Callable, Dict, List from collections.abc import Callable
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -136,13 +137,13 @@ class VisionEmbeddings(torch.nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool = True, strict: bool = True,
missing_keys: List[str] = None, missing_keys: list[str] = None,
unexpected_keys: List[str] = None, unexpected_keys: list[str] = None,
error_msgs: List[str] = None, error_msgs: list[str] = None,
return_state_dict: bool = False, return_state_dict: bool = False,
) -> None: ) -> None:
original_sd = self.state_dict() original_sd = self.state_dict()
@ -163,7 +164,7 @@ class VisionEmbeddings(torch.nn.Module):
# each image is a tensor of shape [num_tiles, C, H, W] # each image is a tensor of shape [num_tiles, C, H, W]
def forward( def forward(
self, self,
image_batch: List[List[torch.Tensor]], image_batch: list[list[torch.Tensor]],
image_mask: torch.Tensor, image_mask: torch.Tensor,
h_ref: torch.Tensor, h_ref: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
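
One thing pyupgrade does not touch: parameters such as `missing_keys: list[str] = None` above keep a `None` default that contradicts the annotation. This PR leaves them as-is; under a strict type checker the spelling would be `list[str] | None = None` plus a guard. A hedged sketch, not part of this change:

```python
def load_hook(state_dict: dict[str, object],
              missing_keys: list[str] | None = None) -> None:
    # Guard the None default instead of annotating a bare list[str]; this keeps
    # mypy/pyright happy without changing runtime behaviour.
    missing_keys = [] if missing_keys is None else missing_keys
    missing_keys.extend(k for k in ("positional_embedding",) if k not in state_dict)
```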


@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from collections.abc import Callable
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init import fairscale.nn.model_parallel.initialize as fs_init
import torch import torch
@ -42,9 +43,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
kernel_size: Union[int, Tuple[int, int]], kernel_size: int | tuple[int, int],
stride: Union[int, Tuple[int, int]], stride: int | tuple[int, int],
bias: Optional[bool] = False, bias: bool | None = False,
) -> None: ) -> None:
super().__init__() super().__init__()
if isinstance(kernel_size, int): if isinstance(kernel_size, int):
@ -134,15 +135,15 @@ class _TransformerBlock(nn.Module):
def attention( def attention(
self, self,
x: torch.Tensor, x: torch.Tensor,
freq_cis: Optional[torch.Tensor] = None, freq_cis: torch.Tensor | None = None,
): ):
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis) return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
mask: Optional[torch.Tensor] = None, mask: torch.Tensor | None = None,
freq_cis: Optional[torch.Tensor] = None, freq_cis: torch.Tensor | None = None,
): ):
_gate_attn = 1 if not self.gated else self.gate_attn.tanh() _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh() _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
@ -210,8 +211,8 @@ class PackingIndex:
class VisionEncoder(nn.Module): class VisionEncoder(nn.Module):
def __init__( def __init__(
self, self,
image_size: Tuple[int, int], image_size: tuple[int, int],
patch_size: Tuple[int, int], patch_size: tuple[int, int],
dim: int, dim: int,
layers: int, layers: int,
heads: int, heads: int,
@ -299,13 +300,13 @@ class VisionEncoder(nn.Module):
def load_hook( def load_hook(
self, self,
state_dict: Dict[str, Any], state_dict: dict[str, Any],
prefix: str, prefix: str,
local_metadata: Dict[str, Any], local_metadata: dict[str, Any],
strict: bool = True, strict: bool = True,
missing_keys: List[str] = None, missing_keys: list[str] = None,
unexpected_keys: List[str] = None, unexpected_keys: list[str] = None,
error_msgs: List[str] = None, error_msgs: list[str] = None,
return_state_dict: bool = False, return_state_dict: bool = False,
) -> None: ) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding") orig_pos_embed = state_dict.get(prefix + "positional_embedding")
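
The `kernel_size: int | tuple[int, int]` and `stride: int | tuple[int, int]` parameters above follow the usual normalize-at-entry pattern; a minimal sketch with a hypothetical helper:

```python
def _as_pair(value: int | tuple[int, int]) -> tuple[int, int]:
    # Accept either a single int or an explicit (h, w) pair, as the conv-patch
    # constructor above does; isinstance() on builtins needs no typing import.
    return (value, value) if isinstance(value, int) else value

assert _as_pair(3) == (3, 3)
assert _as_pair((3, 5)) == (3, 5)
```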


@ -14,7 +14,6 @@
import json import json
import textwrap import textwrap
from pathlib import Path from pathlib import Path
from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -44,7 +43,7 @@ class TextCompletionContent(BaseModel):
class UseCase(BaseModel): class UseCase(BaseModel):
title: str = "" title: str = ""
description: str = "" description: str = ""
dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list) dialogs: list[list[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
notes: str = "" notes: str = ""
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
max_gen_len: int = 512 max_gen_len: int = 512
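
As in the `UseCase` model above, mutable defaults on pydantic fields stay behind `Field(default_factory=list)`; only the generic parameters change to builtins. A small sketch with a hypothetical model:

```python
from pydantic import BaseModel, Field

class UseCaseSketch(BaseModel):
    title: str = ""
    # default_factory avoids sharing one list across instances; the annotation
    # mixes builtin generics and PEP 604 unions directly.
    dialogs: list[list[str] | str] = Field(default_factory=list)

u = UseCaseSketch()
u.dialogs.append("hello")
assert UseCaseSketch().dialogs == []  # each instance gets its own list
```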


@ -7,7 +7,6 @@
# type: ignore # type: ignore
import collections import collections
import logging import logging
from typing import Optional, Tuple, Type, Union
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -27,7 +26,7 @@ class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters # TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly. # with our custom Fp8Weights instance. Do this properly.
@property @property
def __class__(self) -> Type[nn.parameter.Parameter]: def __class__(self) -> type[nn.parameter.Parameter]:
return nn.Parameter return nn.Parameter
@property @property
@ -51,7 +50,7 @@ class Int4ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters # TODO: Ugly trick so torch allows us to replace parameters
# with our custom Int4Weights instance. Do this properly. # with our custom Int4Weights instance. Do this properly.
@property @property
def __class__(self) -> Type[nn.parameter.Parameter]: def __class__(self) -> type[nn.parameter.Parameter]:
return nn.Parameter return nn.Parameter
@property @property
@ -74,7 +73,7 @@ class Int4Weights(
def int4_row_quantize( def int4_row_quantize(
x: torch.Tensor, x: torch.Tensor,
group_size: int = 128, group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_bit = 4 # Number of target bits. n_bit = 4 # Number of target bits.
to_quant = x.reshape(-1, group_size).to(torch.float) to_quant = x.reshape(-1, group_size).to(torch.float)
@ -115,8 +114,8 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:
def bmm_nt( def bmm_nt(
x: Tensor, x: Tensor,
w: Union[Fp8RowwiseWeights, Int4Weights], w: Fp8RowwiseWeights | Int4Weights,
num_tokens: Optional[Tensor] = None, num_tokens: Tensor | None = None,
) -> Tensor: ) -> Tensor:
if isinstance(w, Fp8ScaledWeights): if isinstance(w, Fp8ScaledWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub) xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
@ -129,10 +128,10 @@ def bmm_nt(
def ffn_swiglu( def ffn_swiglu(
x: Tensor, x: Tensor,
w1: Union[Fp8RowwiseWeights, Int4Weights], w1: Fp8RowwiseWeights | Int4Weights,
w3: Union[Fp8RowwiseWeights, Int4Weights], w3: Fp8RowwiseWeights | Int4Weights,
w2: Union[Fp8RowwiseWeights, Int4Weights], w2: Fp8RowwiseWeights | Int4Weights,
num_tokens: Optional[Tensor] = None, num_tokens: Tensor | None = None,
is_memory_bounded: bool = False, is_memory_bounded: bool = False,
) -> Tensor: ) -> Tensor:
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or ( if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
@ -158,7 +157,7 @@ def ffn_swiglu(
def quantize_fp8( def quantize_fp8(
w: Tensor, w: Tensor,
fp8_activation_scale_ub: float, fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None, output_device: torch.device | None = None,
) -> Fp8RowwiseWeights: ) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor. """Quantize [n, k] weight tensor.
@ -184,7 +183,7 @@ def quantize_fp8(
@torch.inference_mode() @torch.inference_mode()
def quantize_int4( def quantize_int4(
w: Tensor, w: Tensor,
output_device: Optional[torch.device] = None, output_device: torch.device | None = None,
) -> Int4Weights: ) -> Int4Weights:
"""Quantize [n, k/2] weight tensor. """Quantize [n, k/2] weight tensor.
@ -213,7 +212,7 @@ def load_fp8(
w: Tensor, w: Tensor,
w_scale: Tensor, w_scale: Tensor,
fp8_activation_scale_ub: float, fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None, output_device: torch.device | None = None,
) -> Fp8RowwiseWeights: ) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor. """Load FP8 [n, k] weight tensor.
@ -239,7 +238,7 @@ def load_int4(
w: Tensor, w: Tensor,
scale: Tensor, scale: Tensor,
zero_point: Tensor, zero_point: Tensor,
output_device: Optional[torch.device] = None, output_device: torch.device | None = None,
) -> Int4Weights: ) -> Int4Weights:
"""Load INT4 [n, k/2] weight tensor. """Load INT4 [n, k/2] weight tensor.
@ -256,9 +255,9 @@ def load_int4(
def fc_dynamic( def fc_dynamic(
x: Tensor, x: Tensor,
w: Union[Fp8RowwiseWeights, Int4Weights], w: Fp8RowwiseWeights | Int4Weights,
activation_scale_ub: Optional[Tensor] = None, activation_scale_ub: Tensor | None = None,
num_tokens: Optional[Tensor] = None, num_tokens: Tensor | None = None,
is_memory_bounded: bool = False, is_memory_bounded: bool = False,
) -> Tensor: ) -> Tensor:
""" """
@ -275,11 +274,11 @@ def fc_dynamic(
def ffn_swiglu_dynamic( def ffn_swiglu_dynamic(
x: Tensor, x: Tensor,
w1: Union[Fp8RowwiseWeights, Int4Weights], w1: Fp8RowwiseWeights | Int4Weights,
w3: Union[Fp8RowwiseWeights, Int4Weights], w3: Fp8RowwiseWeights | Int4Weights,
w2: Union[Fp8RowwiseWeights, Int4Weights], w2: Fp8RowwiseWeights | Int4Weights,
activation_scale_ub: Optional[Tensor] = None, activation_scale_ub: Tensor | None = None,
num_tokens: Optional[Tensor] = None, num_tokens: Tensor | None = None,
is_memory_bounded: bool = False, is_memory_bounded: bool = False,
) -> Tensor: ) -> Tensor:
assert x.dim() == 3 or x.dim() == 2 assert x.dim() == 3 or x.dim() == 2
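
Two related conveniences show up in the quantization hunks: `Type[X]` becomes the builtin `type[X]`, and because the weight parameters are now plain unions, `isinstance` can take the `A | B` form directly on Python 3.10+. A sketch with stand-in classes:

```python
class Fp8Weights: ...
class Int4Weights: ...

def param_class() -> type[Fp8Weights]:
    # builtin type[...] replaces typing.Type[...]
    return Fp8Weights

def bmm_nt(w: Fp8Weights | Int4Weights) -> str:
    # isinstance() accepts X | Y unions on 3.10+, equivalent to an (X, Y) tuple.
    if isinstance(w, Fp8Weights | Int4Weights):
        return "quantized"
    raise TypeError("unsupported weight type")

print(bmm_nt(Int4Weights()))  # -> quantized
```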


@ -6,7 +6,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import List, Optional
from .sku_types import ( from .sku_types import (
CheckpointQuantizationFormat, CheckpointQuantizationFormat,
@ -19,14 +18,14 @@ LLAMA2_VOCAB_SIZE = 32000
LLAMA3_VOCAB_SIZE = 128256 LLAMA3_VOCAB_SIZE = 128256
def resolve_model(descriptor: str) -> Optional[Model]: def resolve_model(descriptor: str) -> Model | None:
for m in all_registered_models(): for m in all_registered_models():
if descriptor in (m.descriptor(), m.huggingface_repo): if descriptor in (m.descriptor(), m.huggingface_repo):
return m return m
return None return None
def all_registered_models() -> List[Model]: def all_registered_models() -> list[Model]:
return ( return (
llama2_family() llama2_family()
+ llama3_family() + llama3_family()
@ -38,48 +37,48 @@ def all_registered_models() -> List[Model]:
) )
def llama2_family() -> List[Model]: def llama2_family() -> list[Model]:
return [ return [
*llama2_base_models(), *llama2_base_models(),
*llama2_instruct_models(), *llama2_instruct_models(),
] ]
def llama3_family() -> List[Model]: def llama3_family() -> list[Model]:
return [ return [
*llama3_base_models(), *llama3_base_models(),
*llama3_instruct_models(), *llama3_instruct_models(),
] ]
def llama3_1_family() -> List[Model]: def llama3_1_family() -> list[Model]:
return [ return [
*llama3_1_base_models(), *llama3_1_base_models(),
*llama3_1_instruct_models(), *llama3_1_instruct_models(),
] ]
def llama3_2_family() -> List[Model]: def llama3_2_family() -> list[Model]:
return [ return [
*llama3_2_base_models(), *llama3_2_base_models(),
*llama3_2_instruct_models(), *llama3_2_instruct_models(),
] ]
def llama3_3_family() -> List[Model]: def llama3_3_family() -> list[Model]:
return [ return [
*llama3_3_instruct_models(), *llama3_3_instruct_models(),
] ]
def llama4_family() -> List[Model]: def llama4_family() -> list[Model]:
return [ return [
*llama4_base_models(), *llama4_base_models(),
*llama4_instruct_models(), *llama4_instruct_models(),
] ]
def llama4_base_models() -> List[Model]: def llama4_base_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama4_scout_17b_16e, core_model_id=CoreModelId.llama4_scout_17b_16e,
@ -98,7 +97,7 @@ def llama4_base_models() -> List[Model]:
] ]
def llama4_instruct_models() -> List[Model]: def llama4_instruct_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama4_scout_17b_16e_instruct, core_model_id=CoreModelId.llama4_scout_17b_16e_instruct,
@ -126,7 +125,7 @@ def llama4_instruct_models() -> List[Model]:
] ]
def llama2_base_models() -> List[Model]: def llama2_base_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama2_7b, core_model_id=CoreModelId.llama2_7b,
@ -185,7 +184,7 @@ def llama2_base_models() -> List[Model]:
] ]
def llama3_base_models() -> List[Model]: def llama3_base_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama3_8b, core_model_id=CoreModelId.llama3_8b,
@ -226,7 +225,7 @@ def llama3_base_models() -> List[Model]:
] ]
def llama3_1_base_models() -> List[Model]: def llama3_1_base_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama3_1_8b, core_model_id=CoreModelId.llama3_1_8b,
@ -324,7 +323,7 @@ def llama3_1_base_models() -> List[Model]:
] ]
def llama3_2_base_models() -> List[Model]: def llama3_2_base_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama3_2_1b, core_model_id=CoreModelId.llama3_2_1b,
@ -407,7 +406,7 @@ def llama3_2_base_models() -> List[Model]:
] ]
def llama2_instruct_models() -> List[Model]: def llama2_instruct_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama2_7b_chat, core_model_id=CoreModelId.llama2_7b_chat,
@ -466,7 +465,7 @@ def llama2_instruct_models() -> List[Model]:
] ]
def llama3_instruct_models() -> List[Model]: def llama3_instruct_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama3_8b_instruct, core_model_id=CoreModelId.llama3_8b_instruct,
@ -507,7 +506,7 @@ def llama3_instruct_models() -> List[Model]:
] ]
def llama3_1_instruct_models() -> List[Model]: def llama3_1_instruct_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama3_1_8b_instruct, core_model_id=CoreModelId.llama3_1_8b_instruct,
@ -635,7 +634,7 @@ def arch_args_3b() -> dict:
} }
def llama3_2_quantized_models() -> List[Model]: def llama3_2_quantized_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama3_2_1b_instruct, core_model_id=CoreModelId.llama3_2_1b_instruct,
@ -704,7 +703,7 @@ def llama3_2_quantized_models() -> List[Model]:
] ]
def llama3_2_instruct_models() -> List[Model]: def llama3_2_instruct_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama3_2_1b_instruct, core_model_id=CoreModelId.llama3_2_1b_instruct,
@ -766,7 +765,7 @@ def llama3_2_instruct_models() -> List[Model]:
] ]
def llama3_3_instruct_models() -> List[Model]: def llama3_3_instruct_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama3_3_70b_instruct, core_model_id=CoreModelId.llama3_3_70b_instruct,
@ -790,7 +789,7 @@ def llama3_3_instruct_models() -> List[Model]:
@lru_cache @lru_cache
def safety_models() -> List[Model]: def safety_models() -> list[Model]:
return [ return [
Model( Model(
core_model_id=CoreModelId.llama_guard_4_12b, core_model_id=CoreModelId.llama_guard_4_12b,
@ -919,7 +918,7 @@ def safety_models() -> List[Model]:
@dataclass @dataclass
class LlamaDownloadInfo: class LlamaDownloadInfo:
folder: str folder: str
files: List[str] files: list[str]
pth_size: int pth_size: int


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
@ -159,13 +159,13 @@ def model_family(model_id) -> ModelFamily:
class Model(BaseModel): class Model(BaseModel):
core_model_id: CoreModelId core_model_id: CoreModelId
description: str description: str
huggingface_repo: Optional[str] = None huggingface_repo: str | None = None
arch_args: Dict[str, Any] arch_args: dict[str, Any]
variant: str = "" variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int pth_file_count: int
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: dict[str, Any] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields # silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
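
Worth noting for anyone introspecting models like the one above: `str | None` produces a `types.UnionType` at runtime rather than a `typing.Union`, so code that dispatches on `typing.get_origin` generally has to accept both origins. A small illustrative check (Python 3.10+):

```python
import types
import typing

assert typing.get_origin(str | None) is types.UnionType
assert typing.get_origin(typing.Optional[str]) is typing.Union
# Dispatch that should treat both spellings the same therefore tests for both:
assert typing.get_origin(str | None) in (typing.Union, types.UnionType)
```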


@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Any, List, Optional, Protocol from typing import Any, Protocol
from urllib.parse import urlparse from urllib.parse import urlparse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -65,7 +65,7 @@ class DatasetsProtocolPrivate(Protocol):
class ScoringFunctionsProtocolPrivate(Protocol): class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ... async def list_scoring_functions(self) -> list[ScoringFn]: ...
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ... async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
@ -88,24 +88,24 @@ class ProviderSpec(BaseModel):
..., ...,
description="Fully-qualified classname of the config for this provider", description="Fully-qualified classname of the config for this provider",
) )
api_dependencies: List[Api] = Field( api_dependencies: list[Api] = Field(
default_factory=list, default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality", description="Higher-level API surfaces may depend on other providers to provide their functionality",
) )
optional_api_dependencies: List[Api] = Field( optional_api_dependencies: list[Api] = Field(
default_factory=list, default_factory=list,
) )
deprecation_warning: Optional[str] = Field( deprecation_warning: str | None = Field(
default=None, default=None,
description="If this provider is deprecated, specify the warning message here", description="If this provider is deprecated, specify the warning message here",
) )
deprecation_error: Optional[str] = Field( deprecation_error: str | None = Field(
default=None, default=None,
description="If this provider is deprecated and does NOT work, specify the error message here", description="If this provider is deprecated and does NOT work, specify the error message here",
) )
# used internally by the resolver; this is a hack for now # used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list) deps__: list[str] = Field(default_factory=list)
@property @property
def is_sample(self) -> bool: def is_sample(self) -> bool:
@ -131,25 +131,25 @@ Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation - `get_adapter_impl(config, deps)`: returns the adapter implementation
""", """,
) )
pip_packages: List[str] = Field( pip_packages: list[str] = Field(
default_factory=list, default_factory=list,
description="The pip dependencies needed for this implementation", description="The pip dependencies needed for this implementation",
) )
config_class: str = Field( config_class: str = Field(
description="Fully-qualified classname of the config for this provider", description="Fully-qualified classname of the config for this provider",
) )
provider_data_validator: Optional[str] = Field( provider_data_validator: str | None = Field(
default=None, default=None,
) )
@json_schema_type @json_schema_type
class InlineProviderSpec(ProviderSpec): class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field( pip_packages: list[str] = Field(
default_factory=list, default_factory=list,
description="The pip dependencies needed for this implementation", description="The pip dependencies needed for this implementation",
) )
container_image: Optional[str] = Field( container_image: str | None = Field(
default=None, default=None,
description=""" description="""
The container image to use for this implementation. If one is provided, pip_packages will be ignored. The container image to use for this implementation. If one is provided, pip_packages will be ignored.
@ -164,14 +164,14 @@ Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation - `get_provider_impl(config, deps)`: returns the local implementation
""", """,
) )
provider_data_validator: Optional[str] = Field( provider_data_validator: str | None = Field(
default=None, default=None,
) )
class RemoteProviderConfig(BaseModel): class RemoteProviderConfig(BaseModel):
host: str = "localhost" host: str = "localhost"
port: Optional[int] = None port: int | None = None
protocol: str = "http" protocol: str = "http"
@property @property
@ -197,7 +197,7 @@ API responses, specify the adapter here.
) )
@property @property
def container_image(self) -> Optional[str]: def container_image(self) -> str | None:
return None return None
@property @property
@ -205,16 +205,16 @@ API responses, specify the adapter here.
return self.adapter.module return self.adapter.module
@property @property
def pip_packages(self) -> List[str]: def pip_packages(self) -> list[str]:
return self.adapter.pip_packages return self.adapter.pip_packages
@property @property
def provider_data_validator(self) -> Optional[str]: def provider_data_validator(self) -> str | None:
return self.adapter.provider_data_validator return self.adapter.provider_data_validator
def remote_provider_spec( def remote_provider_spec(
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
) -> RemoteProviderSpec: ) -> RemoteProviderSpec:
return RemoteProviderSpec( return RemoteProviderSpec(
api=api, api=api,


@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict from typing import Any
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]): async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl( impl = MetaReferenceAgentsImpl(


@ -10,8 +10,8 @@ import re
import secrets import secrets
import string import string
import uuid import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union
import httpx import httpx
@ -112,7 +112,7 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields, output_shields=agent_config.output_shields,
) )
def turn_to_messages(self, turn: Turn) -> List[Message]: def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = [] messages = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
return await self.storage.create_session(name) return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]: async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = [] messages = []
if self.agent_config.instructions != "": if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions)) messages.append(SystemMessage(content=self.agent_config.instructions))
@ -201,8 +201,8 @@ class ChatAgent(ShieldRunnerMixin):
async def _run_turn( async def _run_turn(
self, self,
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest], request: AgentTurnCreateRequest | AgentTurnResumeRequest,
turn_id: Optional[str] = None, turn_id: str | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported" assert request.stream is True, "Non-streaming not supported"
@ -321,10 +321,10 @@ class ChatAgent(ShieldRunnerMixin):
self, self,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
input_messages: List[Message], input_messages: list[Message],
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: Optional[List[Document]] = None, documents: list[Document] | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to # Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -374,8 +374,8 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper( async def run_multiple_shields_wrapper(
self, self,
turn_id: str, turn_id: str,
messages: List[Message], messages: list[Message],
shields: List[str], shields: list[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:
async with tracing.span("run_shields") as span: async with tracing.span("run_shields") as span:
@ -443,10 +443,10 @@ class ChatAgent(ShieldRunnerMixin):
self, self,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
input_messages: List[Message], input_messages: list[Message],
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: Optional[List[Document]] = None, documents: list[Document] | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# if document is passed in a turn, we parse the raw text of the document # if document is passed in a turn, we parse the raw text of the document
# and sent it as a user message # and sent it as a user message
@ -760,7 +760,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _initialize_tools( async def _initialize_tools(
self, self,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, toolgroups_for_turn: list[AgentToolGroup] | None = None,
) -> None: ) -> None:
toolgroup_to_args = {} toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []): for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
@ -847,7 +847,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_name_to_args, tool_name_to_args,
) )
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
"""Parse a toolgroup name into its components. """Parse a toolgroup name into its components.
Args: Args:
@ -921,7 +921,7 @@ async def get_raw_document_text(document: Document) -> str:
def _interpret_content_as_attachment( def _interpret_content_as_attachment(
content: str, content: str,
) -> Optional[Attachment]: ) -> Attachment | None:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match: if match:
snippet = match.group(1) snippet = match.group(1)
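
`AsyncGenerator` here (and `AsyncIterator` in later hunks) now comes from `collections.abc`; the `typing` aliases are deprecated since Python 3.9. A minimal sketch of an annotated async generator, with hypothetical names:

```python
import asyncio
from collections.abc import AsyncGenerator

async def stream_chunks(text: str, size: int = 4) -> AsyncGenerator[str, None]:
    # AsyncGenerator[YieldType, SendType]; SendType is None for a plain yield.
    for i in range(0, len(text), size):
        yield text[i : i + size]

async def main() -> None:
    async for chunk in stream_chunks("hello world"):
        print(chunk)

asyncio.run(main())
```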


@ -8,7 +8,7 @@ import json
import logging import logging
import shutil import shutil
import uuid import uuid
from typing import AsyncGenerator, List, Optional, Union from collections.abc import AsyncGenerator
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
Agent, Agent,
@ -142,16 +142,11 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,
messages: List[ messages: list[UserMessage | ToolResponseMessage],
Union[ toolgroups: list[AgentToolGroup] | None = None,
UserMessage, documents: list[Document] | None = None,
ToolResponseMessage, stream: bool | None = False,
] tool_config: ToolConfig | None = None,
],
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = AgentTurnCreateRequest( request = AgentTurnCreateRequest(
agent_id=agent_id, agent_id=agent_id,
@ -180,8 +175,8 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
tool_responses: List[ToolResponse], tool_responses: list[ToolResponse],
stream: Optional[bool] = False, stream: bool | None = False,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = AgentTurnResumeRequest( request = AgentTurnResumeRequest(
agent_id=agent_id, agent_id=agent_id,
@ -219,7 +214,7 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_ids: Optional[List[str]] = None, turn_ids: list[str] | None = None,
) -> Session: ) -> Session:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id) session_info = await agent.storage.get_session_info(session_id)
@ -265,13 +260,13 @@ class MetaReferenceAgentsImpl(Agents):
async def create_openai_response( async def create_openai_response(
self, self,
input: Union[str, List[OpenAIResponseInputMessage]], input: str | list[OpenAIResponseInputMessage],
model: str, model: str,
previous_response_id: Optional[str] = None, previous_response_id: str | None = None,
store: Optional[bool] = True, store: bool | None = True,
stream: Optional[bool] = False, stream: bool | None = False,
temperature: Optional[float] = None, temperature: float | None = None,
tools: Optional[List[OpenAIResponseInputTool]] = None, tools: list[OpenAIResponseInputTool] | None = None,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response( return await self.openai_responses_impl.create_openai_response(
input, model, previous_response_id, store, stream, temperature, tools input, model, previous_response_id, store, stream, temperature, tools


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@ -16,7 +16,7 @@ class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig persistence_store: KVStoreConfig
@classmethod @classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return { return {
"persistence_store": SqliteKVStoreConfig.sample_run_config( "persistence_store": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__, __distro_dir__=__distro_dir__,


@ -6,7 +6,8 @@
import json import json
import uuid import uuid
from typing import AsyncIterator, List, Optional, Union, cast from collections.abc import AsyncIterator
from typing import cast
from openai.types.chat import ChatCompletionToolParam from openai.types.chat import ChatCompletionToolParam
@ -49,15 +50,15 @@ logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:" OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]: async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
messages: List[OpenAIMessageParam] = [] messages: list[OpenAIMessageParam] = []
for output_message in previous_response.output: for output_message in previous_response.output:
if isinstance(output_message, OpenAIResponseOutputMessage): if isinstance(output_message, OpenAIResponseOutputMessage):
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text)) messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
return messages return messages
async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]: async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
output_messages = [] output_messages = []
for choice in choices: for choice in choices:
output_content = "" output_content = ""
@ -101,22 +102,22 @@ class OpenAIResponsesImpl:
async def create_openai_response( async def create_openai_response(
self, self,
input: Union[str, List[OpenAIResponseInputMessage]], input: str | list[OpenAIResponseInputMessage],
model: str, model: str,
previous_response_id: Optional[str] = None, previous_response_id: str | None = None,
store: Optional[bool] = True, store: bool | None = True,
stream: Optional[bool] = False, stream: bool | None = False,
temperature: Optional[float] = None, temperature: float | None = None,
tools: Optional[List[OpenAIResponseInputTool]] = None, tools: list[OpenAIResponseInputTool] | None = None,
): ):
stream = False if stream is None else stream stream = False if stream is None else stream
messages: List[OpenAIMessageParam] = [] messages: list[OpenAIMessageParam] = []
if previous_response_id: if previous_response_id:
previous_response = await self.get_openai_response(previous_response_id) previous_response = await self.get_openai_response(previous_response_id)
messages.extend(await _previous_response_to_messages(previous_response)) messages.extend(await _previous_response_to_messages(previous_response))
# TODO: refactor this user_content parsing out into a separate method # TODO: refactor this user_content parsing out into a separate method
user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = "" user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
if isinstance(input, list): if isinstance(input, list):
user_content = [] user_content = []
for user_input in input: for user_input in input:
@ -179,7 +180,7 @@ class OpenAIResponsesImpl:
# dump and reload to map to our pydantic types # dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump()) chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: List[OpenAIResponseOutput] = [] output_messages: list[OpenAIResponseOutput] = []
if chat_response.choices[0].message.tool_calls: if chat_response.choices[0].message.tool_calls:
output_messages.extend( output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature) await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
@ -215,9 +216,9 @@ class OpenAIResponsesImpl:
return response return response
async def _convert_response_tools_to_chat_tools( async def _convert_response_tools_to_chat_tools(
self, tools: List[OpenAIResponseInputTool] self, tools: list[OpenAIResponseInputTool]
) -> List[ChatCompletionToolParam]: ) -> list[ChatCompletionToolParam]:
chat_tools: List[ChatCompletionToolParam] = [] chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools: for input_tool in tools:
# TODO: Handle other tool types # TODO: Handle other tool types
if input_tool.type == "web_search": if input_tool.type == "web_search":
@ -247,10 +248,10 @@ class OpenAIResponsesImpl:
model_id: str, model_id: str,
stream: bool, stream: bool,
chat_response: OpenAIChatCompletion, chat_response: OpenAIChatCompletion,
messages: List[OpenAIMessageParam], messages: list[OpenAIMessageParam],
temperature: float, temperature: float,
) -> List[OpenAIResponseOutput]: ) -> list[OpenAIResponseOutput]:
output_messages: List[OpenAIResponseOutput] = [] output_messages: list[OpenAIResponseOutput] = []
choice = chat_response.choices[0] choice = chat_response.choices[0]
# If the choice is not an assistant message, we don't need to execute any tools # If the choice is not an assistant message, we don't need to execute any tools
@ -314,7 +315,7 @@ class OpenAIResponsesImpl:
async def _execute_tool_call( async def _execute_tool_call(
self, self,
function: OpenAIChatCompletionToolCallFunction, function: OpenAIChatCompletionToolCallFunction,
) -> Optional[ToolInvocationResult]: ) -> ToolInvocationResult | None:
if not function.name: if not function.name:
return None return None
function_args = json.loads(function.arguments) if function.arguments else {} function_args = json.loads(function.arguments) if function.arguments else {}


@ -8,7 +8,6 @@ import json
import logging import logging
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -25,9 +24,9 @@ class AgentSessionInfo(BaseModel):
session_id: str session_id: str
session_name: str session_name: str
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: Optional[str] = None vector_db_id: str | None = None
started_at: datetime started_at: datetime
access_attributes: Optional[AccessAttributes] = None access_attributes: AccessAttributes | None = None
class AgentPersistence: class AgentPersistence:
@ -55,7 +54,7 @@ class AgentPersistence:
) )
return session_id return session_id
async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]: async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get( value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
) )
@ -78,7 +77,7 @@ class AgentPersistence:
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes()) return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]: async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods.""" """Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id) session_info = await self.get_session_info(session_id)
if not session_info: if not session_info:
@ -106,7 +105,7 @@ class AgentPersistence:
value=turn.model_dump_json(), value=turn.model_dump_json(),
) )
async def get_session_turns(self, session_id: str) -> List[Turn]: async def get_session_turns(self, session_id: str) -> list[Turn]:
if not await self.get_session_if_accessible(session_id): if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied") raise ValueError(f"Session {session_id} not found or access denied")
@ -125,7 +124,7 @@ class AgentPersistence:
turns.sort(key=lambda x: (x.completed_at or datetime.min)) turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]: async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
if not await self.get_session_if_accessible(session_id): if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied") raise ValueError(f"Session {session_id} not found or access denied")
@ -145,7 +144,7 @@ class AgentPersistence:
value=step.model_dump_json(), value=step.model_dump_json(),
) )
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
if not await self.get_session_if_accessible(session_id): if not await self.get_session_if_accessible(session_id):
return None return None
@ -163,7 +162,7 @@ class AgentPersistence:
value=str(num_infer_iters), value=str(num_infer_iters),
) )
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]: async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
if not await self.get_session_if_accessible(session_id): if not await self.get_session_if_accessible(session_id):
return None return None


@ -6,7 +6,6 @@
import asyncio import asyncio
import logging import logging
from typing import List
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
@ -25,14 +24,14 @@ class ShieldRunnerMixin:
def __init__( def __init__(
self, self,
safety_api: Safety, safety_api: Safety,
input_shields: List[str] = None, input_shields: list[str] = None,
output_shields: List[str] = None, output_shields: list[str] = None,
): ):
self.safety_api = safety_api self.safety_api = safety_api
self.input_shields = input_shields self.input_shields = input_shields
self.output_shields = output_shields self.output_shields = output_shields
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None: async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str): async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"): async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield( return await self.safety_api.run_shield(
