chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle `types.UnionType`
and `collections.abc.AsyncIterator`. (Both are the preferred forms in recent
Python releases.)
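
For context, a minimal sketch (illustrative only, not the actual
`llama_stack/strong_typing/schema.py` change) of what handling both union
forms involves on Python 3.10+: `typing.get_origin` reports `typing.Union`
for `Union[X, Y]` but `types.UnionType` for `X | Y`, so reflection code has
to accept both.

```python
# Illustrative only -- not the code from llama_stack/strong_typing/schema.py.
import types
import typing
from collections.abc import AsyncIterator


def is_union(tp: object) -> bool:
    """True for both typing.Union[X, Y] and the PEP 604 form X | Y."""
    origin = typing.get_origin(tp)
    return origin is typing.Union or origin is types.UnionType


assert is_union(typing.Union[int, str])
assert is_union(int | None)  # PEP 604 unions are types.UnionType instances
# collections.abc.AsyncIterator round-trips through get_origin as itself:
assert typing.get_origin(AsyncIterator[int]) is AsyncIterator
```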

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth noting can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions


@ -4,20 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import datetime
from enum import Enum
from typing import (
Annotated,
Any,
AsyncIterator,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
@ -79,8 +69,8 @@ class StepCommon(BaseModel):
turn_id: str
step_id: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
started_at: datetime | None = None
completed_at: datetime | None = None
class StepType(Enum):
@ -120,8 +110,8 @@ class ToolExecutionStep(StepCommon):
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@json_schema_type
@ -132,7 +122,7 @@ class ShieldCallStep(StepCommon):
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
violation: Optional[SafetyViolation]
violation: SafetyViolation | None
@json_schema_type
@ -150,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):
Step = Annotated[
Union[
InferenceStep,
ToolExecutionStep,
ShieldCallStep,
MemoryRetrievalStep,
],
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
Field(discriminator="step_type"),
]
@ -166,18 +151,13 @@ class Turn(BaseModel):
turn_id: str
session_id: str
input_messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
steps: List[Step]
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
output_attachments: list[Attachment] | None = Field(default_factory=list)
started_at: datetime
completed_at: Optional[datetime] = None
completed_at: datetime | None = None
@json_schema_type
@ -186,34 +166,31 @@ class Session(BaseModel):
session_id: str
session_name: str
turns: List[Turn]
turns: list[Turn]
started_at: datetime
class AgentToolGroupWithArgs(BaseModel):
name: str
args: Dict[str, Any]
args: dict[str, Any]
AgentToolGroup = Union[
str,
AgentToolGroupWithArgs,
]
AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
tool_config: Optional[ToolConfig] = Field(default=None)
input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[ToolDef] | None = Field(default_factory=list)
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: Optional[int] = 10
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
@ -243,9 +220,9 @@ class AgentConfig(AgentConfigCommon):
model: str
instructions: str
name: Optional[str] = None
enable_session_persistence: Optional[bool] = False
response_format: Optional[ResponseFormat] = None
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
@json_schema_type
@ -257,16 +234,16 @@ class Agent(BaseModel):
@json_schema_type
class ListAgentsResponse(BaseModel):
data: List[Agent]
data: list[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: List[Session]
data: list[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
instructions: str | None = None
class AgentTurnResponseEventType(Enum):
@ -284,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
metadata: dict[str, Any] | None = Field(default_factory=dict)
@json_schema_type
@ -327,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
AgentTurnResponseEventPayload = Annotated[
Union[
AgentTurnResponseStepStartPayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnAwaitingInputPayload,
],
AgentTurnResponseStepStartPayload
| AgentTurnResponseStepProgressPayload
| AgentTurnResponseStepCompletePayload
| AgentTurnResponseTurnStartPayload
| AgentTurnResponseTurnCompletePayload
| AgentTurnResponseTurnAwaitingInputPayload,
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@ -363,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
messages: list[UserMessage | ToolResponseMessage]
documents: Optional[List[Document]] = None
toolgroups: Optional[List[AgentToolGroup]] = None
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = None
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
stream: bool | None = False
tool_config: ToolConfig | None = None
@json_schema_type
@ -382,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: List[ToolResponse]
stream: Optional[bool] = False
tool_responses: list[ToolResponse]
stream: bool | None = False
@json_schema_type
@ -429,17 +399,12 @@ class Agents(Protocol):
self,
agent_id: str,
session_id: str,
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
],
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
@ -463,9 +428,9 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
@ -538,7 +503,7 @@ class Agents(Protocol):
self,
session_id: str,
agent_id: str,
turn_ids: Optional[List[str]] = None,
turn_ids: list[str] | None = None,
) -> Session:
"""Retrieve an agent session by its ID.
@ -623,14 +588,14 @@ class Agents(Protocol):
@webmethod(route="/openai/v1/responses", method="POST")
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
input: str | list[OpenAIResponseInputMessage],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
temperature: Optional[float] = None,
tools: Optional[List[OpenAIResponseInputTool]] = None,
) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response.
:param input: Input message(s) to create the response.
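
The recurring pattern in this file is `Annotated[Union[...], Field(discriminator=...)]`
collapsing to `Annotated[A | B, Field(discriminator=...)]`. A self-contained
sanity check that this stays a working discriminated union (assuming
Pydantic v2; `Cat`/`Dog` are made-up stand-ins). Note also that fields like
`toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)` keep
their empty-list default; `| None` only widens the accepted type.

```python
# Hypothetical models, just to exercise the Annotated[... | ..., Field(...)] form.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class Cat(BaseModel):
    step_type: Literal["cat"] = "cat"


class Dog(BaseModel):
    step_type: Literal["dog"] = "dog"


Pet = Annotated[Cat | Dog, Field(discriminator="step_type")]

# The | union is discriminated exactly like the old Union[...] spelling:
assert isinstance(TypeAdapter(Pet).validate_python({"step_type": "dog"}), Dog)
```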


@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Literal, Optional, Union
from typing import Annotated, Literal
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
@ -25,7 +24,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
OpenAIResponseOutputMessageContent = Annotated[
Union[OpenAIResponseOutputMessageContentOutputText,],
OpenAIResponseOutputMessageContentOutputText,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
@ -34,7 +33,7 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
@json_schema_type
class OpenAIResponseOutputMessage(BaseModel):
id: str
content: List[OpenAIResponseOutputMessageContent]
content: list[OpenAIResponseOutputMessageContent]
role: Literal["assistant"] = "assistant"
status: str
type: Literal["message"] = "message"
@ -48,10 +47,7 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
OpenAIResponseOutput = Annotated[
Union[
OpenAIResponseOutputMessage,
OpenAIResponseOutputMessageWebSearchToolCall,
],
OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@ -60,18 +56,18 @@ register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@json_schema_type
class OpenAIResponseObject(BaseModel):
created_at: int
error: Optional[OpenAIResponseError] = None
error: OpenAIResponseError | None = None
id: str
model: str
object: Literal["response"] = "response"
output: List[OpenAIResponseOutput]
output: list[OpenAIResponseOutput]
parallel_tool_calls: bool = False
previous_response_id: Optional[str] = None
previous_response_id: str | None = None
status: str
temperature: Optional[float] = None
top_p: Optional[float] = None
truncation: Optional[str] = None
user: Optional[str] = None
temperature: float | None = None
top_p: float | None = None
truncation: str | None = None
user: str | None = None
@json_schema_type
@ -87,10 +83,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
OpenAIResponseObjectStream = Annotated[
Union[
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseCompleted,
],
OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"),
]
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@ -107,12 +100,12 @@ class OpenAIResponseInputMessageContentImage(BaseModel):
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
type: Literal["input_image"] = "input_image"
# TODO: handle file_id
image_url: Optional[str] = None
image_url: str | None = None
# TODO: handle file content types
OpenAIResponseInputMessageContent = Annotated[
Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@ -120,21 +113,21 @@ register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMess
@json_schema_type
class OpenAIResponseInputMessage(BaseModel):
content: Union[str, List[OpenAIResponseInputMessageContent]]
content: str | list[OpenAIResponseInputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Optional[Literal["message"]] = "message"
type: Literal["message"] | None = "message"
@json_schema_type
class OpenAIResponseInputToolWebSearch(BaseModel):
type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
# TODO: actually use search_context_size somewhere...
search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
# TODO: add user_location
OpenAIResponseInputTool = Annotated[
Union[OpenAIResponseInputToolWebSearch,],
OpenAIResponseInputToolWebSearch,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
@ -34,22 +34,22 @@ class BatchInference(Protocol):
async def completion(
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion(
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job: ...
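
One detail worth noting in signatures like
`tool_choice: ToolChoice | None = ToolChoice.auto`: `X | None` is purely a
type widening, so the default can remain a real value. A toy illustration
(names are made up):

```python
from enum import Enum


class ToolChoiceDemo(Enum):  # stand-in for the real ToolChoice enum
    auto = "auto"
    required = "required"


def chat(tool_choice: ToolChoiceDemo | None = ToolChoiceDemo.auto) -> str:
    # None explicitly disables tools; the default remains "auto".
    if tool_choice is None:
        return "tools disabled"
    return f"tool choice: {tool_choice.value}"


assert chat() == "tool choice: auto"
assert chat(None) == "tools disabled"
```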


@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -13,8 +13,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class CommonBenchmarkFields(BaseModel):
dataset_id: str
scoring_functions: List[str]
metadata: Dict[str, Any] = Field(
scoring_functions: list[str]
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@ -35,12 +35,12 @@ class Benchmark(CommonBenchmarkFields, Resource):
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
benchmark_id: str
provider_id: Optional[str] = None
provider_benchmark_id: Optional[str] = None
provider_id: str | None = None
provider_benchmark_id: str | None = None
class ListBenchmarksResponse(BaseModel):
data: List[Benchmark]
data: list[Benchmark]
@runtime_checkable
@ -59,8 +59,8 @@ class Benchmarks(Protocol):
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
scoring_functions: list[str],
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None: ...


@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
from typing import Annotated, Literal
from pydantic import BaseModel, Field, model_validator
@ -26,9 +26,9 @@ class _URLOrData(BaseModel):
:param data: base64 encoded image data as string
"""
url: Optional[URL] = None
url: URL | None = None
# data is a base64 encoded string, hint with contentEncoding=base64
data: Optional[str] = Field(contentEncoding="base64", default=None)
data: str | None = Field(contentEncoding="base64", default=None)
@model_validator(mode="before")
@classmethod
@ -64,13 +64,13 @@ class TextContentItem(BaseModel):
# other modalities can be added here
InterleavedContentItem = Annotated[
Union[ImageContentItem, TextContentItem],
ImageContentItem | TextContentItem,
Field(discriminator="type"),
]
register_schema(InterleavedContentItem, name="InterleavedContentItem")
# accept a single "str" as a special case since it is common
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
register_schema(InterleavedContent, name="InterleavedContent")
@ -100,13 +100,13 @@ class ToolCallDelta(BaseModel):
# you either send an in-progress tool call so the client can stream a long
# code generation or you send the final parsed tool call at the end of the
# stream
tool_call: Union[str, ToolCall]
tool_call: str | ToolCall
parse_status: ToolCallParseStatus
# streaming completions send a stream of ContentDeltas
ContentDelta = Annotated[
Union[TextDelta, ImageDelta, ToolCallDelta],
TextDelta | ImageDelta | ToolCallDelta,
Field(discriminator="type"),
]
register_schema(ContentDelta, name="ContentDelta")
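
Aliases like `InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]`
are the other reason the reflection fix was needed: the `|` expression
evaluates at import time to a `types.UnionType` instance rather than a
`typing.Union` special form, so schema reflection must recognize both. A
quick demonstration with stand-in members:

```python
import types
import typing

# Stand-in for the real alias; the members don't matter for the point.
InterleavedContentDemo = str | bytes | list[str]

assert isinstance(InterleavedContentDemo, types.UnionType)
assert typing.get_args(InterleavedContentDemo) == (str, bytes, list[str])
```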


@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel
@ -25,6 +25,6 @@ class RestAPIMethod(Enum):
class RestAPIExecutionConfig(BaseModel):
url: URL
method: RestAPIMethod
params: Optional[Dict[str, Any]] = None
headers: Optional[Dict[str, Any]] = None
body: Optional[Dict[str, Any]] = None
params: dict[str, Any] | None = None
headers: dict[str, Any] | None = None
body: dict[str, Any] | None = None


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any
from pydantic import BaseModel
@ -19,5 +19,5 @@ class PaginatedResponse(BaseModel):
:param has_more: Whether there are more items available after this set
"""
data: List[Dict[str, Any]]
data: list[dict[str, Any]]
has_more: bool


@ -5,7 +5,6 @@
# the root directory of this source tree.
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
@ -27,4 +26,4 @@ class Checkpoint(BaseModel):
epoch: int
post_training_job_id: str
path: str
training_metrics: Optional[PostTrainingMetric] = None
training_metrics: PostTrainingMetric | None = None


@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Union
from typing import Annotated, Literal
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
@ -73,18 +72,16 @@ class DialogType(BaseModel):
ParamType = Annotated[
Union[
StringType,
NumberType,
BooleanType,
ArrayType,
ObjectType,
JsonType,
UnionType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
StringType
| NumberType
| BooleanType
| ArrayType
| ObjectType
| JsonType
| UnionType
| ChatCompletionInputType
| CompletionInputType
| AgentTurnInputType,
Field(discriminator="type"),
]
register_schema(ParamType, name="ParamType")


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasets import Dataset
@ -24,8 +24,8 @@ class DatasetIO(Protocol):
async def iterrows(
self,
dataset_id: str,
start_index: Optional[int] = None,
limit: Optional[int] = None,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""Get a paginated list of rows from a dataset.
@ -44,4 +44,4 @@ class DatasetIO(Protocol):
...
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...


@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
@ -81,11 +81,11 @@ class RowsDataSource(BaseModel):
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
rows: list[dict[str, Any]]
DataSource = Annotated[
Union[URIDataSource, RowsDataSource],
URIDataSource | RowsDataSource,
Field(discriminator="type"),
]
register_schema(DataSource, name="DataSource")
@ -98,7 +98,7 @@ class CommonDatasetFields(BaseModel):
purpose: DatasetPurpose
source: DataSource
metadata: Dict[str, Any] = Field(
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
)
@ -122,7 +122,7 @@ class DatasetInput(CommonDatasetFields, BaseModel):
class ListDatasetsResponse(BaseModel):
data: List[Dataset]
data: list[Dataset]
class Datasets(Protocol):
@ -131,8 +131,8 @@ class Datasets(Protocol):
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
"""
Register a new dataset.


@ -5,7 +5,6 @@
# the root directory of this source tree.
from enum import Enum
from typing import Optional
from pydantic import BaseModel
@ -54,4 +53,4 @@ class Error(BaseModel):
status: int
title: str
detail: str
instance: Optional[str] = None
instance: str | None = None


@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
@ -29,7 +28,7 @@ class ModelCandidate(BaseModel):
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
system_message: Optional[SystemMessage] = None
system_message: SystemMessage | None = None
@json_schema_type
@ -43,7 +42,7 @@ class AgentCandidate(BaseModel):
config: AgentConfig
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
@ -57,11 +56,11 @@ class BenchmarkConfig(BaseModel):
"""
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
scoring_params: dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
num_examples: Optional[int] = Field(
num_examples: int | None = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
@ -76,9 +75,9 @@ class EvaluateResponse(BaseModel):
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
generations: list[dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
scores: dict[str, ScoringResult]
class Eval(Protocol):
@ -101,8 +100,8 @@ class Eval(Protocol):
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
from pydantic import BaseModel
@ -42,7 +42,7 @@ class ListBucketResponse(BaseModel):
:param data: List of FileResponse entries
"""
data: List[BucketResponse]
data: list[BucketResponse]
@json_schema_type
@ -74,7 +74,7 @@ class ListFileResponse(BaseModel):
:param data: List of FileResponse entries
"""
data: List[FileResponse]
data: list[FileResponse]
@runtime_checkable
@ -102,7 +102,7 @@ class Files(Protocol):
async def upload_content_to_session(
self,
upload_id: str,
) -> Optional[FileResponse]:
) -> FileResponse | None:
"""
Upload file content to an existing upload session.
On the server, request body will have the raw bytes that are uploaded.


@ -4,21 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from enum import Enum
from typing import (
Annotated,
Any,
AsyncIterator,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated, TypedDict
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
@ -47,8 +44,8 @@ class GreedySamplingStrategy(BaseModel):
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
temperature: float | None = Field(..., gt=0.0)
top_p: float | None = 0.95
@json_schema_type
@ -58,7 +55,7 @@ class TopKSamplingStrategy(BaseModel):
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@ -79,9 +76,9 @@ class SamplingParams(BaseModel):
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
max_tokens: int | None = 0
repetition_penalty: float | None = 1.0
stop: list[str] | None = None
class LogProbConfig(BaseModel):
@ -90,7 +87,7 @@ class LogProbConfig(BaseModel):
:param top_k: How many tokens (for each position) to return log probabilities for.
"""
top_k: Optional[int] = 0
top_k: int | None = 0
class QuantizationType(Enum):
@ -125,11 +122,11 @@ class Int4QuantizationConfig(BaseModel):
"""
type: Literal["int4_mixed"] = "int4_mixed"
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
scheme: str | None = "int4_weight_int8_dynamic_activation"
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
Field(discriminator="type"),
]
@ -145,7 +142,7 @@ class UserMessage(BaseModel):
role: Literal["user"] = "user"
content: InterleavedContent
context: Optional[InterleavedContent] = None
context: InterleavedContent | None = None
@json_schema_type
@ -190,16 +187,11 @@ class CompletionMessage(BaseModel):
role: Literal["assistant"] = "assistant"
content: InterleavedContent
stop_reason: StopReason
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
tool_calls: list[ToolCall] | None = Field(default_factory=list)
Message = Annotated[
Union[
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
],
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
Field(discriminator="role"),
]
register_schema(Message, name="Message")
@ -208,9 +200,9 @@ register_schema(Message, name="Message")
@json_schema_type
class ToolResponse(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
tool_name: BuiltinTool | str
content: InterleavedContent
metadata: Optional[Dict[str, Any]] = None
metadata: dict[str, Any] | None = None
@field_validator("tool_name", mode="before")
@classmethod
@ -243,7 +235,7 @@ class TokenLogProbs(BaseModel):
:param logprobs_by_token: Dictionary mapping tokens to their log probabilities
"""
logprobs_by_token: Dict[str, float]
logprobs_by_token: dict[str, float]
class ChatCompletionResponseEventType(Enum):
@ -271,8 +263,8 @@ class ChatCompletionResponseEvent(BaseModel):
event_type: ChatCompletionResponseEventType
delta: ContentDelta
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None
logprobs: list[TokenLogProbs] | None = None
stop_reason: StopReason | None = None
class ResponseFormatType(Enum):
@ -295,7 +287,7 @@ class JsonSchemaResponseFormat(BaseModel):
"""
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
json_schema: Dict[str, Any]
json_schema: dict[str, Any]
@json_schema_type
@ -307,11 +299,11 @@ class GrammarResponseFormat(BaseModel):
"""
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
bnf: Dict[str, Any]
bnf: dict[str, Any]
ResponseFormat = Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
JsonSchemaResponseFormat | GrammarResponseFormat,
Field(discriminator="type"),
]
register_schema(ResponseFormat, name="ResponseFormat")
@ -321,10 +313,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
class CompletionRequest(BaseModel):
model: str
content: InterleavedContent
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
response_format: ResponseFormat | None = None
stream: bool | None = False
logprobs: LogProbConfig | None = None
@json_schema_type
@ -338,7 +330,7 @@ class CompletionResponse(MetricResponseMixin):
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
@ -351,8 +343,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
"""
delta: str
stop_reason: Optional[StopReason] = None
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: StopReason | None = None
logprobs: list[TokenLogProbs] | None = None
class SystemMessageBehavior(Enum):
@ -383,9 +375,9 @@ class ToolConfig(BaseModel):
'{{function_definitions}}' to indicate where the function definitions should be inserted.
"""
tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
def model_post_init(self, __context: Any) -> None:
if isinstance(self.tool_choice, str):
@ -399,15 +391,15 @@ class ToolConfig(BaseModel):
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
tools: list[ToolDefinition] | None = Field(default_factory=list)
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
response_format: ResponseFormat | None = None
stream: bool | None = False
logprobs: LogProbConfig | None = None
@json_schema_type
@ -429,7 +421,7 @@ class ChatCompletionResponse(MetricResponseMixin):
"""
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
logprobs: list[TokenLogProbs] | None = None
@json_schema_type
@ -439,7 +431,7 @@ class EmbeddingsResponse(BaseModel):
:param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
embeddings: List[List[float]]
embeddings: list[list[float]]
@json_schema_type
@ -451,7 +443,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):
@json_schema_type
class OpenAIImageURL(BaseModel):
url: str
detail: Optional[str] = None
detail: str | None = None
@json_schema_type
@ -461,16 +453,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
OpenAIChatCompletionContentPartParam = Annotated[
Union[
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionContentPartImageParam,
],
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
@json_schema_type
@ -484,7 +473,7 @@ class OpenAIUserMessageParam(BaseModel):
role: Literal["user"] = "user"
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
name: str | None = None
@json_schema_type
@ -498,21 +487,21 @@ class OpenAISystemMessageParam(BaseModel):
role: Literal["system"] = "system"
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
name: str | None = None
@json_schema_type
class OpenAIChatCompletionToolCallFunction(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
name: str | None = None
arguments: str | None = None
@json_schema_type
class OpenAIChatCompletionToolCall(BaseModel):
index: Optional[int] = None
id: Optional[str] = None
index: int | None = None
id: str | None = None
type: Literal["function"] = "function"
function: Optional[OpenAIChatCompletionToolCallFunction] = None
function: OpenAIChatCompletionToolCallFunction | None = None
@json_schema_type
@ -526,9 +515,9 @@ class OpenAIAssistantMessageParam(BaseModel):
"""
role: Literal["assistant"] = "assistant"
content: Optional[OpenAIChatCompletionMessageContent] = None
name: Optional[str] = None
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
content: OpenAIChatCompletionMessageContent | None = None
name: str | None = None
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
@json_schema_type
@ -556,17 +545,15 @@ class OpenAIDeveloperMessageParam(BaseModel):
role: Literal["developer"] = "developer"
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
name: str | None = None
OpenAIMessageParam = Annotated[
Union[
OpenAIUserMessageParam,
OpenAISystemMessageParam,
OpenAIAssistantMessageParam,
OpenAIToolMessageParam,
OpenAIDeveloperMessageParam,
],
OpenAIUserMessageParam
| OpenAISystemMessageParam
| OpenAIAssistantMessageParam
| OpenAIToolMessageParam
| OpenAIDeveloperMessageParam,
Field(discriminator="role"),
]
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@ -580,14 +567,14 @@ class OpenAIResponseFormatText(BaseModel):
@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
name: str
description: Optional[str] = None
strict: Optional[bool] = None
description: str | None = None
strict: bool | None = None
# Pydantic BaseModel cannot be used with a schema param, since it already
# has one. And, we don't want to alias here because then have to handle
# that alias when converting to OpenAI params. So, to support schema,
# we use a TypedDict.
schema: Optional[Dict[str, Any]] = None
schema: dict[str, Any] | None = None
@json_schema_type
@ -602,11 +589,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):
OpenAIResponseFormatParam = Annotated[
Union[
OpenAIResponseFormatText,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatJSONObject,
],
OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
Field(discriminator="type"),
]
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@ -622,7 +605,7 @@ class OpenAITopLogProb(BaseModel):
"""
token: str
bytes: Optional[List[int]] = None
bytes: list[int] | None = None
logprob: float
@ -637,9 +620,9 @@ class OpenAITokenLogProb(BaseModel):
"""
token: str
bytes: Optional[List[int]] = None
bytes: list[int] | None = None
logprob: float
top_logprobs: List[OpenAITopLogProb]
top_logprobs: list[OpenAITopLogProb]
@json_schema_type
@ -650,8 +633,8 @@ class OpenAIChoiceLogprobs(BaseModel):
:param refusal: (Optional) The log probabilities for the tokens in the message
"""
content: Optional[List[OpenAITokenLogProb]] = None
refusal: Optional[List[OpenAITokenLogProb]] = None
content: list[OpenAITokenLogProb] | None = None
refusal: list[OpenAITokenLogProb] | None = None
@json_schema_type
@ -664,10 +647,10 @@ class OpenAIChoiceDelta(BaseModel):
:param tool_calls: (Optional) The tool calls of the delta
"""
content: Optional[str] = None
refusal: Optional[str] = None
role: Optional[str] = None
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
content: str | None = None
refusal: str | None = None
role: str | None = None
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
@json_schema_type
@ -683,7 +666,7 @@ class OpenAIChunkChoice(BaseModel):
delta: OpenAIChoiceDelta
finish_reason: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
logprobs: OpenAIChoiceLogprobs | None = None
@json_schema_type
@ -699,7 +682,7 @@ class OpenAIChoice(BaseModel):
message: OpenAIMessageParam
finish_reason: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
logprobs: OpenAIChoiceLogprobs | None = None
@json_schema_type
@ -714,7 +697,7 @@ class OpenAIChatCompletion(BaseModel):
"""
id: str
choices: List[OpenAIChoice]
choices: list[OpenAIChoice]
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
@ -732,7 +715,7 @@ class OpenAIChatCompletionChunk(BaseModel):
"""
id: str
choices: List[OpenAIChunkChoice]
choices: list[OpenAIChunkChoice]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
@ -748,10 +731,10 @@ class OpenAICompletionLogprobs(BaseModel):
:top_logprobs: (Optional) The top log probabilities for the tokens
"""
text_offset: Optional[List[int]] = None
token_logprobs: Optional[List[float]] = None
tokens: Optional[List[str]] = None
top_logprobs: Optional[List[Dict[str, float]]] = None
text_offset: list[int] | None = None
token_logprobs: list[float] | None = None
tokens: list[str] | None = None
top_logprobs: list[dict[str, float]] | None = None
@json_schema_type
@ -767,7 +750,7 @@ class OpenAICompletionChoice(BaseModel):
finish_reason: str
text: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
logprobs: OpenAIChoiceLogprobs | None = None
@json_schema_type
@ -782,7 +765,7 @@ class OpenAICompletion(BaseModel):
"""
id: str
choices: List[OpenAICompletionChoice]
choices: list[OpenAICompletionChoice]
created: int
model: str
object: Literal["text_completion"] = "text_completion"
@ -818,12 +801,12 @@ class EmbeddingTaskType(Enum):
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
batch: list[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
batch: list[ChatCompletionResponse]
@runtime_checkable
@ -843,11 +826,11 @@ class Inference(Protocol):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
"""Generate a completion for the given content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
@ -865,10 +848,10 @@ class Inference(Protocol):
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
raise NotImplementedError("Batch completion is not implemented")
@ -876,16 +859,16 @@ class Inference(Protocol):
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
@ -916,12 +899,12 @@ class Inference(Protocol):
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
raise NotImplementedError("Batch chat completion is not implemented")
@ -929,10 +912,10 @@ class Inference(Protocol):
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
@ -950,25 +933,25 @@ class Inference(Protocol):
self,
# Standard OpenAI completion parameters
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
@ -996,29 +979,29 @@ class Inference(Protocol):
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
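
The streaming return types in this file now use
`collections.abc.AsyncIterator` directly (the `typing` aliases have been
deprecated since Python 3.9). For reference, an async generator is all it
takes to satisfy such an annotation; this toy sketch mirrors the
`-> X | AsyncIterator[Y]` shape with made-up names:

```python
import asyncio
from collections.abc import AsyncIterator


async def stream_chunks(n: int) -> AsyncIterator[str]:
    # An async generator function satisfies an AsyncIterator annotation.
    for i in range(n):
        yield f"chunk-{i}"


async def main() -> None:
    async for chunk in stream_chunks(3):
        print(chunk)


asyncio.run(main())
```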


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
from pydantic import BaseModel
@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class RouteInfo(BaseModel):
route: str
method: str
provider_types: List[str]
provider_types: list[str]
@json_schema_type
@ -30,7 +30,7 @@ class VersionInfo(BaseModel):
class ListRoutesResponse(BaseModel):
data: List[RouteInfo]
data: list[RouteInfo]
@runtime_checkable


@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class CommonModelFields(BaseModel):
metadata: Dict[str, Any] = Field(
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this model",
)
@ -46,14 +46,14 @@ class Model(CommonModelFields, Resource):
class ModelInput(CommonModelFields):
model_id: str
provider_id: Optional[str] = None
provider_model_id: Optional[str] = None
model_type: Optional[ModelType] = ModelType.llm
provider_id: str | None = None
provider_model_id: str | None = None
model_type: ModelType | None = ModelType.llm
model_config = ConfigDict(protected_namespaces=())
class ListModelsResponse(BaseModel):
data: List[Model]
data: list[Model]
@json_schema_type
@ -73,7 +73,7 @@ class OpenAIModel(BaseModel):
class OpenAIListModelsResponse(BaseModel):
data: List[OpenAIModel]
data: list[OpenAIModel]
@runtime_checkable
@ -95,10 +95,10 @@ class Models(Protocol):
async def register_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model: ...
@webmethod(route="/models/{model_id:path}", method="DELETE")


@ -6,10 +6,9 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
@ -36,9 +35,9 @@ class DataConfig(BaseModel):
batch_size: int
shuffle: bool
data_format: DatasetFormat
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
validation_dataset_id: str | None = None
packed: bool | None = False
train_on_input: bool | None = False
@json_schema_type
@ -51,10 +50,10 @@ class OptimizerConfig(BaseModel):
@json_schema_type
class EfficiencyConfig(BaseModel):
enable_activation_checkpointing: Optional[bool] = False
enable_activation_offloading: Optional[bool] = False
memory_efficient_fsdp_wrap: Optional[bool] = False
fsdp_cpu_offload: Optional[bool] = False
enable_activation_checkpointing: bool | None = False
enable_activation_offloading: bool | None = False
memory_efficient_fsdp_wrap: bool | None = False
fsdp_cpu_offload: bool | None = False
@json_schema_type
@ -62,23 +61,23 @@ class TrainingConfig(BaseModel):
n_epochs: int
max_steps_per_epoch: int = 1
gradient_accumulation_steps: int = 1
max_validation_steps: Optional[int] = 1
data_config: Optional[DataConfig] = None
optimizer_config: Optional[OptimizerConfig] = None
efficiency_config: Optional[EfficiencyConfig] = None
dtype: Optional[str] = "bf16"
max_validation_steps: int | None = 1
data_config: DataConfig | None = None
optimizer_config: OptimizerConfig | None = None
efficiency_config: EfficiencyConfig | None = None
dtype: str | None = "bf16"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: List[str]
lora_attn_modules: list[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: Optional[bool] = False
quantize_base: Optional[bool] = False
use_dora: bool | None = False
quantize_base: bool | None = False
@json_schema_type
@ -88,7 +87,7 @@ class QATFinetuningConfig(BaseModel):
group_size: int
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@ -97,7 +96,7 @@ class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: List[str]
log_lines: list[str]
@json_schema_type
@ -131,8 +130,8 @@ class PostTrainingRLHFRequest(BaseModel):
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
hyperparam_search_config: dict[str, Any]
logger_config: dict[str, Any]
class PostTrainingJob(BaseModel):
@ -146,17 +145,17 @@ class PostTrainingJobStatusResponse(BaseModel):
job_uuid: str
status: JobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
scheduled_at: datetime | None = None
started_at: datetime | None = None
completed_at: datetime | None = None
resources_allocated: Optional[Dict[str, Any]] = None
resources_allocated: dict[str, Any] | None = None
checkpoints: List[Checkpoint] = Field(default_factory=list)
checkpoints: list[Checkpoint] = Field(default_factory=list)
class ListPostTrainingJobsResponse(BaseModel):
data: List[PostTrainingJob]
data: list[PostTrainingJob]
@json_schema_type
@ -164,7 +163,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
job_uuid: str
checkpoints: List[Checkpoint] = Field(default_factory=list)
checkpoints: list[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
@ -175,14 +174,14 @@ class PostTraining(Protocol):
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: Optional[str] = Field(
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str | None = Field(
default=None,
description="Model descriptor for training if not in provider config`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")
@ -192,8 +191,8 @@ class PostTraining(Protocol):
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET")


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
@ -17,12 +17,12 @@ class ProviderInfo(BaseModel):
api: str
provider_id: str
provider_type: str
config: Dict[str, Any]
config: dict[str, Any]
health: HealthResponse
class ListProvidersResponse(BaseModel):
data: List[ProviderInfo]
data: list[ProviderInfo]
@runtime_checkable
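
The providers hunk is the simplest case in the whole commit: PEP 585 (Python 3.9) made the builtin collections subscriptable, so the typing.Dict and typing.List imports can be dropped wholesale. The two spellings produce equal types at runtime, as this small check (not part of the commit) shows:

import typing

# Builtin generics compare equal to their typing-module counterparts,
# so dict[str, Any] is a drop-in replacement for typing.Dict[str, Any],
# including inside pydantic models.
assert dict[str, int] == typing.Dict[str, int]
assert list[float] == typing.List[float]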


@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -27,16 +27,16 @@ class SafetyViolation(BaseModel):
violation_level: ViolationLevel
# what message should you convey to the user
user_message: Optional[str] = None
user_message: str | None = None
# additional metadata (including specific violation codes) more for
# debugging, telemetry
metadata: Dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
violation: SafetyViolation | None = None
class ShieldStore(Protocol):
@ -52,6 +52,6 @@ class Safety(Protocol):
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse: ...
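
Worth noting in this hunk: pyupgrade rewrites spellings, not semantics. run_shield's params keeps its pre-existing mismatch (a dict[str, Any] annotation with a None default but no | None), because the original line never said Optional. For the fields that were Optional, the new pipe syntax builds the very same type object, as this sketch (not part of the commit) illustrates:

from typing import Optional, get_args

# Optional[str] -> str | None is a pure spelling change: both produce
# an equal union of str and NoneType.
assert (str | None) == Optional[str]
assert get_args(str | None) == (str, type(None))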


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value
ScoringResultRow = Dict[str, Any]
ScoringResultRow = dict[str, Any]
@json_schema_type
@ -24,15 +24,15 @@ class ScoringResult(BaseModel):
:param aggregated_results: Map of metric name to aggregated value
"""
score_rows: List[ScoringResultRow]
score_rows: list[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
aggregated_results: dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
dataset_id: Optional[str] = None
results: Dict[str, ScoringResult]
dataset_id: str | None = None
results: dict[str, ScoringResult]
@json_schema_type
@ -44,7 +44,7 @@ class ScoreResponse(BaseModel):
"""
# each key in the dict is a scoring function name
results: Dict[str, ScoringResult]
results: dict[str, ScoringResult]
class ScoringFunctionStore(Protocol):
@ -59,15 +59,15 @@ class Scoring(Protocol):
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]],
scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score", method="POST")
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None],
) -> ScoreResponse:
"""Score a list of rows.


@ -6,18 +6,14 @@
from enum import Enum
from typing import (
Annotated,
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
@ -46,12 +42,12 @@ class AggregationFunctionType(Enum):
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
prompt_template: str | None = None
judge_score_regexes: list[str] | None = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
aggregation_functions: list[AggregationFunctionType] | None = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@ -60,11 +56,11 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: Optional[List[str]] = Field(
parsing_regexes: list[str] | None = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
aggregation_functions: list[AggregationFunctionType] | None = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@ -73,33 +69,29 @@ class RegexParserScoringFnParams(BaseModel):
@json_schema_type
class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
aggregation_functions: list[AggregationFunctionType] | None = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
Field(discriminator="type"),
]
register_schema(ScoringFnParams, name="ScoringFnParams")
class CommonScoringFnFields(BaseModel):
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
description: str | None = None
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
params: Optional[ScoringFnParams] = Field(
params: ScoringFnParams | None = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@ -120,12 +112,12 @@ class ScoringFn(CommonScoringFnFields, Resource):
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None
provider_scoring_fn_id: Optional[str] = None
provider_id: str | None = None
provider_scoring_fn_id: str | None = None
class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn]
data: list[ScoringFn]
@runtime_checkable
@ -142,7 +134,7 @@ class ScoringFunctions(Protocol):
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None: ...
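
The import shuffle at the top of this file is the other recurring pattern: Annotated has been in the stdlib typing module since Python 3.9 (PEP 593), so the typing_extensions backport import goes away. A minimal sketch of Annotated plus Field in the new style, assuming pydantic v2:

from typing import Annotated

from pydantic import BaseModel, Field


class Judged(BaseModel):
    # Annotated attaches Field metadata (here a constraint) to the type.
    judge_model: Annotated[str, Field(min_length=1)]
    # Same shape as judge_score_regexes after the rewrite.
    score_regexes: list[str] | None = Field(default_factory=list)


assert Judged(judge_model="llama-guard").score_regexes == []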


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel
@ -14,7 +14,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class CommonShieldFields(BaseModel):
params: Optional[Dict[str, Any]] = None
params: dict[str, Any] | None = None
@json_schema_type
@ -34,12 +34,12 @@ class Shield(CommonShieldFields, Resource):
class ShieldInput(CommonShieldFields):
shield_id: str
provider_id: Optional[str] = None
provider_shield_id: Optional[str] = None
provider_id: str | None = None
provider_shield_id: str | None = None
class ListShieldsResponse(BaseModel):
data: List[Shield]
data: list[Shield]
@runtime_checkable
@ -55,7 +55,7 @@ class Shields(Protocol):
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield: ...


@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import Any, Protocol
from pydantic import BaseModel
@ -28,24 +28,24 @@ class FilteringFunction(Enum):
class SyntheticDataGenerationRequest(BaseModel):
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
dialogs: List[Message]
dialogs: list[Message]
filtering_function: FilteringFunction = FilteringFunction.none
model: Optional[str] = None
model: str | None = None
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
synthetic_data: List[Dict[str, Any]]
statistics: Optional[Dict[str, Any]] = None
synthetic_data: list[dict[str, Any]]
statistics: dict[str, Any] | None = None
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic-data-generation/generate")
def synthetic_data_generate(
self,
dialogs: List[Message],
dialogs: list[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: Optional[str] = None,
) -> Union[SyntheticDataGenerationResponse]: ...
model: str | None = None,
) -> SyntheticDataGenerationResponse: ...
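
This hunk also fixes a genuine oddity: the old return annotation Union[SyntheticDataGenerationResponse] was a one-armed union, which typing collapses to its lone member anyway, so the bare annotation is identical at runtime:

from typing import Union

# Per the typing docs, unions of a single argument vanish: Union[X] is X.
assert Union[int] is int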


@ -7,18 +7,14 @@
from datetime import datetime
from enum import Enum
from typing import (
Annotated,
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -37,11 +33,11 @@ class SpanStatus(Enum):
class Span(BaseModel):
span_id: str
trace_id: str
parent_span_id: Optional[str] = None
parent_span_id: str | None = None
name: str
start_time: datetime
end_time: Optional[datetime] = None
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
end_time: datetime | None = None
attributes: dict[str, Any] | None = Field(default_factory=dict)
def set_attribute(self, key: str, value: Any):
if self.attributes is None:
@ -54,7 +50,7 @@ class Trace(BaseModel):
trace_id: str
root_span_id: str
start_time: datetime
end_time: Optional[datetime] = None
end_time: datetime | None = None
@json_schema_type
@ -78,7 +74,7 @@ class EventCommon(BaseModel):
trace_id: str
span_id: str
timestamp: datetime
attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)
attributes: dict[str, Primitive] | None = Field(default_factory=dict)
@json_schema_type
@ -92,15 +88,15 @@ class UnstructuredLogEvent(EventCommon):
class MetricEvent(EventCommon):
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
metric: str # this would be an enum
value: Union[int, float]
value: int | float
unit: str
@json_schema_type
class MetricInResponse(BaseModel):
metric: str
value: Union[int, float]
unit: Optional[str] = None
value: int | float
unit: str | None = None
# This is a short term solution to allow inference API to return metrics
@ -124,7 +120,7 @@ class MetricInResponse(BaseModel):
class MetricResponseMixin(BaseModel):
metrics: Optional[List[MetricInResponse]] = None
metrics: list[MetricInResponse] | None = None
@json_schema_type
@ -137,7 +133,7 @@ class StructuredLogType(Enum):
class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
name: str
parent_span_id: Optional[str] = None
parent_span_id: str | None = None
@json_schema_type
@ -147,10 +143,7 @@ class SpanEndPayload(BaseModel):
StructuredLogPayload = Annotated[
Union[
SpanStartPayload,
SpanEndPayload,
],
SpanStartPayload | SpanEndPayload,
Field(discriminator="type"),
]
register_schema(StructuredLogPayload, name="StructuredLogPayload")
@ -163,11 +156,7 @@ class StructuredLogEvent(EventCommon):
Event = Annotated[
Union[
UnstructuredLogEvent,
MetricEvent,
StructuredLogEvent,
],
UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
Field(discriminator="type"),
]
register_schema(Event, name="Event")
@ -184,7 +173,7 @@ class EvalTrace(BaseModel):
@json_schema_type
class SpanWithStatus(Span):
status: Optional[SpanStatus] = None
status: SpanStatus | None = None
@json_schema_type
@ -203,15 +192,15 @@ class QueryCondition(BaseModel):
class QueryTracesResponse(BaseModel):
data: List[Trace]
data: list[Trace]
class QuerySpansResponse(BaseModel):
data: List[Span]
data: list[Span]
class QuerySpanTreeResponse(BaseModel):
data: Dict[str, SpanWithStatus]
data: dict[str, SpanWithStatus]
@runtime_checkable
@ -222,10 +211,10 @@ class Telemetry(Protocol):
@webmethod(route="/telemetry/traces", method="POST")
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
attribute_filters: list[QueryCondition] | None = None,
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> QueryTracesResponse: ...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
@ -238,23 +227,23 @@ class Telemetry(Protocol):
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> QuerySpanTreeResponse: ...
@webmethod(route="/telemetry/spans", method="POST")
async def query_spans(
self,
attribute_filters: List[QueryCondition],
attributes_to_return: List[str],
max_depth: Optional[int] = None,
attribute_filters: list[QueryCondition],
attributes_to_return: list[str],
max_depth: int | None = None,
) -> QuerySpansResponse: ...
@webmethod(route="/telemetry/spans/export", method="POST")
async def save_spans_to_dataset(
self,
attribute_filters: List[QueryCondition],
attributes_to_save: List[str],
attribute_filters: list[QueryCondition],
attributes_to_save: list[str],
dataset_id: str,
max_depth: Optional[int] = None,
max_depth: int | None = None,
) -> None: ...
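
The Span.attributes rewrite keeps an interesting combination: a | None annotation together with a default_factory. Omitting the field yields {} from the factory, but a caller may still pass attributes=None explicitly, which is exactly what the set_attribute guard above handles. A runnable sketch of just that shape, using a trimmed toy Span rather than the full telemetry model:

from typing import Any

from pydantic import BaseModel, Field


class Span(BaseModel):
    span_id: str
    # Omitted -> {} via the factory; an explicit None is still legal.
    attributes: dict[str, Any] | None = Field(default_factory=dict)

    def set_attribute(self, key: str, value: Any) -> None:
        if self.attributes is None:  # guards an explicitly passed None
            self.attributes = {}
        self.attributes[key] = value


span = Span(span_id="s1", attributes=None)
span.set_attribute("model", "llama-3")
assert span.attributes == {"model": "llama-3"}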


@ -5,10 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -29,13 +29,13 @@ class RAGDocument(BaseModel):
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
metadata: Dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class RAGQueryResult(BaseModel):
content: Optional[InterleavedContent] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
content: InterleavedContent | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
@ -59,10 +59,7 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
RAGQueryGeneratorConfig = Annotated[
Union[
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
],
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@ -83,7 +80,7 @@ class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
async def insert(
self,
documents: List[RAGDocument],
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
@ -94,8 +91,8 @@ class RAGToolRuntime(Protocol):
async def query(
self,
content: InterleavedContent,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent"""
...
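
One leftover in this file's imports: Annotated moves to typing, but Protocol and runtime_checkable stay imported from typing_extensions even though both have been in the stdlib since Python 3.8; presumably pyupgrade rewrites type spellings, not import sources, so the backport import survives. The stdlib forms behave identically, as this small sketch (not from the commit) shows:

from typing import Protocol, runtime_checkable


@runtime_checkable
class Store(Protocol):  # structural type: anything with this method
    def get(self, key: str) -> str: ...


class InMemoryStore:
    def get(self, key: str) -> str:
        return f"value:{key}"


# runtime_checkable lets isinstance() check the method surface.
assert isinstance(InMemoryStore(), Store)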


@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@ -24,7 +24,7 @@ class ToolParameter(BaseModel):
parameter_type: str
description: str
required: bool = Field(default=True)
default: Optional[Any] = None
default: Any | None = None
@json_schema_type
@ -40,39 +40,39 @@ class Tool(Resource):
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: List[ToolParameter]
metadata: Optional[Dict[str, Any]] = None
parameters: list[ToolParameter]
metadata: dict[str, Any] | None = None
@json_schema_type
class ToolDef(BaseModel):
name: str
description: Optional[str] = None
parameters: Optional[List[ToolParameter]] = None
metadata: Optional[Dict[str, Any]] = None
description: str | None = None
parameters: list[ToolParameter] | None = None
metadata: dict[str, Any] | None = None
@json_schema_type
class ToolGroupInput(BaseModel):
toolgroup_id: str
provider_id: str
args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None
args: dict[str, Any] | None = None
mcp_endpoint: URL | None = None
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
args: Optional[Dict[str, Any]] = None
mcp_endpoint: URL | None = None
args: dict[str, Any] | None = None
@json_schema_type
class ToolInvocationResult(BaseModel):
content: Optional[InterleavedContent] = None
error_message: Optional[str] = None
error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
content: InterleavedContent | None = None
error_message: str | None = None
error_code: int | None = None
metadata: dict[str, Any] | None = None
class ToolStore(Protocol):
@ -81,11 +81,11 @@ class ToolStore(Protocol):
class ListToolGroupsResponse(BaseModel):
data: List[ToolGroup]
data: list[ToolGroup]
class ListToolsResponse(BaseModel):
data: List[Tool]
data: list[Tool]
class ListToolDefsResponse(BaseModel):
@ -100,8 +100,8 @@ class ToolGroups(Protocol):
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
"""Register a tool group"""
...
@ -118,7 +118,7 @@ class ToolGroups(Protocol):
...
@webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group"""
...
@ -151,10 +151,10 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...
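
A mechanical artifact worth flagging in ToolParameter: Optional[Any] becomes Any | None, which is redundant (Any already admits None) but harmless; pyupgrade rewrites the spelling without simplifying the type. A trimmed toy version behaves the same either way:

from typing import Any

from pydantic import BaseModel


class ToolParameter(BaseModel):  # trimmed toy version
    name: str
    required: bool = True
    # Any | None is equivalent to plain Any; None was already allowed.
    default: Any | None = None


assert ToolParameter(name="query").default is None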


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Literal, Optional, Protocol, runtime_checkable
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel
@ -33,11 +33,11 @@ class VectorDBInput(BaseModel):
vector_db_id: str
embedding_model: str
embedding_dimension: int
provider_vector_db_id: Optional[str] = None
provider_vector_db_id: str | None = None
class ListVectorDBsResponse(BaseModel):
data: List[VectorDB]
data: list[VectorDB]
@runtime_checkable
@ -57,9 +57,9 @@ class VectorDBs(Protocol):
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: Optional[int] = 384,
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB: ...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")


@ -8,7 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@ -20,17 +20,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel):
content: InterleavedContent
metadata: Dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class QueryChunksResponse(BaseModel):
chunks: List[Chunk]
scores: List[float]
chunks: list[Chunk]
scores: list[float]
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
@runtime_checkable
@ -44,8 +44,8 @@ class VectorIO(Protocol):
async def insert_chunks(
self,
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None: ...
@webmethod(route="/vector-io/query", method="POST")
@ -53,5 +53,5 @@ class VectorIO(Protocol):
self,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ...
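
A closing compatibility note: none of these files add from __future__ import annotations, and pydantic evaluates field annotations at class-creation time, so the X | None spellings make Python 3.10 the effective runtime floor for these modules (the PEP 585 builtins alone would only need 3.9). A guard of this shape would surface the requirement early; it is a sketch, not something the commit adds:

import sys

# PEP 604 unions evaluated at runtime (e.g. in pydantic field
# annotations) require Python 3.10; fail fast on older interpreters.
if sys.version_info < (3, 10):
    raise RuntimeError(f"Python 3.10+ required, found {sys.version.split()[0]}")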