chore: enable pyupgrade fixes (#1806)

# What does this PR do?

This PR modernizes the code base.

Schema reflection code needed a minor adjustment to handle `types.UnionType`
and `collections.abc.AsyncIterator`. (Both are the preferred spellings in
recent Python releases.)
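
(For reference, the `schema.py` hunks themselves are not part of the excerpt
below. A minimal sketch of the kind of handling this requires, built on the
standard `typing.get_origin`/`typing.get_args` helpers; the helper name and
return shape are illustrative, not the actual reflection code:)

```python
import collections.abc
import types
import typing


def normalize_annotation(tp):
    # Illustrative helper, not the actual schema.py code: classify an
    # annotation so schema generation sees one canonical form.
    origin = typing.get_origin(tp)
    # PEP 604 unions (`str | None`) have origin types.UnionType, while
    # Optional[str] / Union[str, int] have origin typing.Union; treat
    # both as a union over their arguments.
    if origin is types.UnionType or origin is typing.Union:
        return ("union", typing.get_args(tp))
    # AsyncIterator annotations may now come from collections.abc
    # instead of typing; both report the same origin.
    if origin is collections.abc.AsyncIterator:
        return ("async_iterator", typing.get_args(tp))
    return ("plain", (tp,))


assert normalize_annotation(str | None)[0] == "union"
assert normalize_annotation(typing.Optional[str])[0] == "union"
assert normalize_annotation(collections.abc.AsyncIterator[int])[0] == "async_iterator"
```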

Note to reviewers: almost all changes here were generated automatically by
pyupgrade. Some additional unused imports were also cleaned up. The only
changes worth noting are under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where reflection code was updated to
deal with the "newer" types.
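
For a sense of what the automated rewrites look like: the bulk of the diff
applies the standard pyupgrade transformations for Python >= 3.10, i.e. PEP 585
builtin generics and PEP 604 union syntax. A made-up before/after pair (the
signatures are illustrative, not taken from the diff):

```python
from typing import Any, Dict, List, Optional, Union

# Before: pre-3.10 typing aliases, removed throughout this diff.
def get_value(key: str) -> Optional[str]: ...
def query(params: Optional[Dict[str, Any]] = None) -> Union[str, List[str]]: ...

# After: builtin generics (PEP 585) and `X | Y` unions (PEP 604).
def get_value(key: str) -> str | None: ...
def query(params: dict[str, Any] | None = None) -> str | list[str]: ...
```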

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Author: Ihar Hrachyshka
Date: 2025-05-01 17:23:50 -04:00 (committed by GitHub)
Commit: 9e6561a1ec (parent: ffe3d0b2cd)
319 changed files with 2843 additions and 3033 deletions


@@ -3,54 +3,53 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel, Field
class BedrockBaseConfig(BaseModel):
aws_access_key_id: Optional[str] = Field(
aws_access_key_id: str | None = Field(
default=None,
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
)
aws_secret_access_key: Optional[str] = Field(
aws_secret_access_key: str | None = Field(
default=None,
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
)
aws_session_token: Optional[str] = Field(
aws_session_token: str | None = Field(
default=None,
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
)
region_name: Optional[str] = Field(
region_name: str | None = Field(
default=None,
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
"Default use environment variable: AWS_DEFAULT_REGION",
)
profile_name: Optional[str] = Field(
profile_name: str | None = Field(
default=None,
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
)
total_max_attempts: Optional[int] = Field(
total_max_attempts: int | None = Field(
default=None,
description="An integer representing the maximum number of attempts that will be made for a single request, "
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
)
retry_mode: Optional[str] = Field(
retry_mode: str | None = Field(
default=None,
description="A string representing the type of retries Boto3 will perform."
"Default use environment variable: AWS_RETRY_MODE",
)
connect_timeout: Optional[float] = Field(
connect_timeout: float | None = Field(
default=60,
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
"The default is 60 seconds.",
)
read_timeout: Optional[float] = Field(
read_timeout: float | None = Field(
default=60,
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.",
)
session_ttl: Optional[int] = Field(
session_ttl: int | None = Field(
default=3600,
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
)


@@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from typing import Any
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
@@ -85,16 +85,16 @@ def get_valid_schemas(api_str: str):
def validate_dataset_schema(
dataset_schema: Dict[str, Any],
expected_schemas: List[Dict[str, Any]],
dataset_schema: dict[str, Any],
expected_schemas: list[dict[str, Any]],
):
if dataset_schema not in expected_schemas:
raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
def validate_row_schema(
input_row: Dict[str, Any],
expected_schemas: List[Dict[str, Any]],
input_row: dict[str, Any],
expected_schemas: list[dict[str, Any]],
):
for schema in expected_schemas:
if all(key in input_row for key in schema):


@@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
def paginate_records(
records: List[Dict[str, Any]],
records: list[dict[str, Any]],
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:


@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import * # noqa: F403
@@ -22,7 +20,7 @@ def is_supported_safety_model(model: Model) -> bool:
]
def supported_inference_models() -> List[Model]:
def supported_inference_models() -> list[Model]:
return [
m
for m in all_registered_models()


@@ -5,7 +5,7 @@
# the root directory of this source tree.
import logging
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
@@ -31,10 +31,10 @@ class SentenceTransformerEmbeddingMixin:
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)


@@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import litellm
@@ -64,7 +65,7 @@ class LiteLLMOpenAIMixin(
def __init__(
self,
model_entries,
api_key_from_config: Optional[str],
api_key_from_config: str | None,
provider_data_api_key_field: str,
openai_compat_api_base: str | None = None,
):
@@ -97,26 +98,26 @@ class LiteLLMOpenAIMixin(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError("LiteLLM does not support completion requests")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
@@ -243,10 +244,10 @@ class LiteLLMOpenAIMixin(
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
@@ -261,24 +262,24 @@ class LiteLLMOpenAIMixin(
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
@@ -309,29 +310,29 @@ class LiteLLMOpenAIMixin(
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
@@ -365,21 +366,21 @@ class LiteLLMOpenAIMixin(
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")


@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, Field
@@ -20,13 +20,13 @@ from llama_stack.providers.utils.inference import (
# more closer to the Model class.
class ProviderModelEntry(BaseModel):
provider_model_id: str
aliases: List[str] = Field(default_factory=list)
llama_model: Optional[str] = None
aliases: list[str] = Field(default_factory=list)
llama_model: str | None = None
model_type: ModelType = ModelType.llm
metadata: Dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
def get_huggingface_repo(model_descriptor: str) -> str | None:
for model in all_registered_models():
if model.descriptor() == model_descriptor:
return model.huggingface_repo
@@ -34,7 +34,7 @@ def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
def build_hf_repo_model_entry(
provider_model_id: str, model_descriptor: str, additional_aliases: Optional[List[str]] = None
provider_model_id: str, model_descriptor: str, additional_aliases: list[str] | None = None
) -> ProviderModelEntry:
aliases = [
get_huggingface_repo(model_descriptor),
@@ -58,7 +58,7 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_entries: List[ProviderModelEntry]):
def __init__(self, model_entries: list[ProviderModelEntry]):
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for entry in model_entries:
@@ -72,11 +72,11 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
def get_provider_model_id(self, identifier: str) -> Optional[str]:
def get_provider_model_id(self, identifier: str) -> str | None:
return self.alias_to_provider_id_map.get(identifier, None)
# TODO: why keep a separate llama model mapping?
def get_llama_model(self, provider_model_id: str) -> Optional[str]:
def get_llama_model(self, provider_model_id: str) -> str | None:
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
async def register_model(self, model: Model) -> Model:


@@ -8,16 +8,9 @@ import logging
import time
import uuid
import warnings
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Dict,
Iterable,
List,
Optional,
Union,
)
from openai import AsyncStream
@@ -141,24 +134,24 @@ class OpenAICompatCompletionChoiceDelta(BaseModel):
class OpenAICompatLogprobs(BaseModel):
text_offset: Optional[List[int]] = None
text_offset: list[int] | None = None
token_logprobs: Optional[List[float]] = None
token_logprobs: list[float] | None = None
tokens: Optional[List[str]] = None
tokens: list[str] | None = None
top_logprobs: Optional[List[Dict[str, float]]] = None
top_logprobs: list[dict[str, float]] | None = None
class OpenAICompatCompletionChoice(BaseModel):
finish_reason: Optional[str] = None
text: Optional[str] = None
delta: Optional[OpenAICompatCompletionChoiceDelta] = None
logprobs: Optional[OpenAICompatLogprobs] = None
finish_reason: str | None = None
text: str | None = None
delta: OpenAICompatCompletionChoiceDelta | None = None
logprobs: OpenAICompatLogprobs | None = None
class OpenAICompatCompletionResponse(BaseModel):
choices: List[OpenAICompatCompletionChoice]
choices: list[OpenAICompatCompletionChoice]
def get_sampling_strategy_options(params: SamplingParams) -> dict:
@@ -217,8 +210,8 @@ def get_stop_reason(finish_reason: str) -> StopReason:
def convert_openai_completion_logprobs(
logprobs: Optional[OpenAICompatLogprobs],
) -> Optional[List[TokenLogProbs]]:
logprobs: OpenAICompatLogprobs | None,
) -> list[TokenLogProbs] | None:
if not logprobs:
return None
if hasattr(logprobs, "top_logprobs"):
@@ -235,7 +228,7 @@ def convert_openai_completion_logprobs(
return None
def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None):
if logprobs is None:
return None
if isinstance(logprobs, float):
@@ -562,7 +555,7 @@ class UnparseableToolCall(BaseModel):
async def convert_message_to_openai_dict_new(
message: Message | Dict,
message: Message | dict,
) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
@@ -591,14 +584,10 @@ async def convert_message_to_openai_dict_new(
# List[...] -> List[...]
async def _convert_message_content(
content: InterleavedContent,
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
async def impl(
content_: InterleavedContent,
) -> Union[
str,
OpenAIChatCompletionContentPartParam,
List[OpenAIChatCompletionContentPartParam],
]:
) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content_, str):
return content_
@@ -670,7 +659,7 @@ async def convert_message_to_openai_dict_new(
def convert_tool_call(
tool_call: ChatCompletionMessageToolCall,
) -> Union[ToolCall, UnparseableToolCall]:
) -> ToolCall | UnparseableToolCall:
"""
Convert a ChatCompletionMessageToolCall tool call to either a
ToolCall or UnparseableToolCall. Returns an UnparseableToolCall
@@ -846,7 +835,7 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
}.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig:
tool_config = ToolConfig()
if tool_choice:
try:
@@ -857,7 +846,7 @@ def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[st
return tool_config
def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
lls_tools = []
if not tools:
return lls_tools
@@ -903,8 +892,8 @@ def _convert_openai_request_response_format(
def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
tool_calls: list[OpenAIChatCompletionMessageToolCall],
) -> list[ToolCall]:
"""
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
@@ -940,7 +929,7 @@ def _convert_openai_tool_calls(
def _convert_openai_logprobs(
logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
) -> list[TokenLogProbs] | None:
"""
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
@@ -973,9 +962,9 @@ def _convert_openai_logprobs(
def _convert_openai_sampling_params(
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_tokens: int | None = None,
temperature: float | None = None,
top_p: float | None = None,
) -> SamplingParams:
sampling_params = SamplingParams()
@@ -998,8 +987,8 @@ def _convert_openai_sampling_params(
def openai_messages_to_messages(
messages: List[OpenAIChatCompletionMessage],
) -> List[Message]:
messages: list[OpenAIChatCompletionMessage],
) -> list[Message]:
"""
Convert a list of OpenAIChatCompletionMessage into a list of Message.
"""
@@ -1027,7 +1016,7 @@ def openai_messages_to_messages(
return converted_messages
def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam]):
if isinstance(content, str):
return content
elif isinstance(content, list):
@@ -1273,24 +1262,24 @@ class OpenAICompletionToLlamaStackMixin:
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
if stream:
raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
@@ -1342,29 +1331,29 @@ class OpenAIChatCompletionToLlamaStackMixin:
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIChatCompletionMessage],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: list[OpenAIChatCompletionMessage],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
messages = openai_messages_to_messages(messages)
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
@@ -1403,7 +1392,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
async def _process_stream_response(
self,
model: str,
outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
):
id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses:
@@ -1466,7 +1455,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
i = i + 1
async def _process_non_stream_response(
self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
choices = []
for outstanding_response in outstanding_responses:


@@ -9,7 +9,6 @@ import base64
import io
import json
import re
from typing import List, Optional, Tuple, Union
import httpx
from PIL import Image as PIL_Image
@@ -63,7 +62,7 @@ log = get_logger(name=__name__, category="inference")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: List[RawMessage]
messages: list[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest):
@@ -93,8 +92,8 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> s
async def convert_request_to_raw(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
request: ChatCompletionRequest | CompletionRequest,
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
if isinstance(request, ChatCompletionRequest):
messages = []
for m in request.messages:
@@ -170,18 +169,18 @@ def content_has_media(content: InterleavedContent):
return _has_media_content(content)
def messages_have_media(messages: List[Message]):
def messages_have_media(messages: list[Message]):
return any(content_has_media(m.content) for m in messages)
def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
if isinstance(request, ChatCompletionRequest):
return messages_have_media(request.messages)
else:
return content_has_media(request.content)
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
image = media.image
if image.url and image.url.uri.startswith("http"):
async with httpx.AsyncClient() as client:
@@ -228,7 +227,7 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
async def completion_request_to_prompt_model_input_info(
request: CompletionRequest,
) -> Tuple[str, int]:
) -> tuple[str, int]:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
@@ -265,7 +264,7 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
async def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, llama_model: str
) -> Tuple[str, int]:
) -> tuple[str, int]:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
@@ -284,7 +283,7 @@ async def chat_completion_request_to_model_input_info(
def chat_completion_request_to_messages(
request: ChatCompletionRequest,
llama_model: str,
) -> List[Message]:
) -> list[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For eg. for llama_3_1, add system message with the appropriate tools or
add user messsage for custom tools, etc.
@@ -323,7 +322,7 @@ def chat_completion_request_to_messages(
return messages
def response_format_prompt(fmt: Optional[ResponseFormat]):
def response_format_prompt(fmt: ResponseFormat | None):
if not fmt:
return None
@@ -337,7 +336,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
def augment_messages_for_tools_llama_3_1(
request: ChatCompletionRequest,
) -> List[Message]:
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
@@ -406,7 +405,7 @@ def augment_messages_for_tools_llama_3_1(
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> List[Message]:
) -> list[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
@@ -457,7 +456,7 @@ def augment_messages_for_tools_llama(
return messages
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
if tool_choice == ToolChoice.auto:
return ""
elif tool_choice == ToolChoice.required:


@@ -5,15 +5,15 @@
# the root directory of this source tree.
from datetime import datetime
from typing import List, Optional, Protocol
from typing import Protocol
class KVStore(Protocol):
# TODO: make the value type bytes instead of str
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: ...
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ...
async def get(self, key: str) -> Optional[str]: ...
async def get(self, key: str) -> str | None: ...
async def delete(self, key: str) -> None: ...
async def range(self, start_key: str, end_key: str) -> List[str]: ...
async def range(self, start_key: str, end_key: str) -> list[str]: ...


@@ -6,10 +6,9 @@
import re
from enum import Enum
from typing import Literal, Optional, Union
from typing import Annotated, Literal
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
@@ -22,7 +21,7 @@ class KVStoreType(Enum):
class CommonConfig(BaseModel):
namespace: Optional[str] = Field(
namespace: str | None = Field(
default=None,
description="All keys will be prefixed with this namespace",
)
@@ -69,7 +68,7 @@ class PostgresKVStoreConfig(CommonConfig):
port: int = 5432
db: str = "llamastack"
user: str
password: Optional[str] = None
password: str | None = None
table_name: str = "llamastack_kvstore"
@classmethod
@@ -108,7 +107,7 @@ class MongoDBKVStoreConfig(CommonConfig):
port: int = 27017
db: str = "llamastack"
user: str = None
password: Optional[str] = None
password: str | None = None
collection_name: str = "llamastack_kvstore"
@classmethod
@@ -126,6 +125,6 @@ class MongoDBKVStoreConfig(CommonConfig):
KVStoreConfig = Annotated[
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig],
RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig,
Field(discriminator="type", default=KVStoreType.sqlite.value),
]


@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from .api import KVStore
from .config import KVStoreConfig, KVStoreType
@@ -21,13 +20,13 @@ class InmemoryKVStoreImpl(KVStore):
async def initialize(self) -> None:
pass
async def get(self, key: str) -> Optional[str]:
async def get(self, key: str) -> str | None:
return self._store.get(key)
async def set(self, key: str, value: str) -> None:
self._store[key] = value
async def range(self, start_key: str, end_key: str) -> List[str]:
async def range(self, start_key: str, end_key: str) -> list[str]:
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]


@@ -6,7 +6,6 @@
import logging
from datetime import datetime
from typing import List, Optional
from pymongo import AsyncMongoClient
@@ -43,12 +42,12 @@ class MongoDBKVStoreImpl(KVStore):
return key
return f"{self.config.namespace}:{key}"
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
update_query = {"$set": {"value": value, "expiration": expiration}}
await self.collection.update_one({"key": key}, update_query, upsert=True)
async def get(self, key: str) -> Optional[str]:
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
query = {"key": key}
result = await self.collection.find_one(query, {"value": 1, "_id": 0})
@@ -58,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> List[str]:
async def range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {


@@ -6,7 +6,6 @@
import logging
from datetime import datetime
from typing import List, Optional
import psycopg2
from psycopg2.extras import DictCursor
@@ -54,7 +53,7 @@ class PostgresKVStoreImpl(KVStore):
return key
return f"{self.config.namespace}:{key}"
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
f"""
@@ -66,7 +65,7 @@ class PostgresKVStoreImpl(KVStore):
(key, value, expiration),
)
async def get(self, key: str) -> Optional[str]:
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
self.cursor.execute(
f"""
@@ -86,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
(key,),
)
async def range(self, start_key: str, end_key: str) -> List[str]:
async def range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)


@@ -5,7 +5,6 @@
# the root directory of this source tree.
from datetime import datetime
from typing import List, Optional
from redis.asyncio import Redis
@@ -25,13 +24,13 @@ class RedisKVStoreImpl(KVStore):
return key
return f"{self.config.namespace}:{key}"
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
if expiration:
await self.redis.expireat(key, expiration)
async def get(self, key: str) -> Optional[str]:
async def get(self, key: str) -> str | None:
key = self._namespaced_key(key)
value = await self.redis.get(key)
if value is None:
@@ -43,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> List[str]:
async def range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
cursor = 0


@@ -6,7 +6,6 @@
import os
from datetime import datetime
from typing import List, Optional
import aiosqlite
@@ -33,7 +32,7 @@ class SqliteKVStoreImpl(KVStore):
)
await db.commit()
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
@@ -41,7 +40,7 @@ class SqliteKVStoreImpl(KVStore):
)
await db.commit()
async def get(self, key: str) -> Optional[str]:
async def get(self, key: str) -> str | None:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor:
row = await cursor.fetchone()
@@ -55,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def range(self, start_key: str, end_key: str) -> List[str]:
async def range(self, start_key: str, end_key: str) -> list[str]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",


@@ -9,7 +9,7 @@ import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any
from urllib.parse import unquote
import httpx
@@ -94,7 +94,7 @@ def content_from_data(data_url: str) -> str:
return ""
def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
def concat_interleaved_content(content: list[InterleavedContent]) -> InterleavedContent:
"""concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list"""
ret = []
@@ -141,7 +141,7 @@ async def content_from_doc(doc: RAGDocument) -> str:
return interleaved_content_as_str(doc.content)
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> List[Chunk]:
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
@@ -165,7 +165,7 @@ def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap
class EmbeddingIndex(ABC):
@abstractmethod
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
raise NotImplementedError()
@abstractmethod
@@ -185,7 +185,7 @@ class VectorDBWithIndex:
async def insert_chunks(
self,
chunks: List[Chunk],
chunks: list[Chunk],
) -> None:
embeddings_response = await self.inference_api.embeddings(
self.vector_db.embedding_model, [x.content for x in chunks]
@@ -197,7 +197,7 @@ class VectorDBWithIndex:
async def query_chunks(
self,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
if params is None:
params = {}


@@ -8,9 +8,10 @@ import abc
import asyncio
import functools
import threading
from collections.abc import Callable, Coroutine, Iterable
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
from typing import Any, TypeAlias
from pydantic import BaseModel
@@ -38,7 +39,7 @@ class JobArtifact(BaseModel):
name: str
# TODO: uri should be a reference to /files API; revisit when /files is implemented
uri: str | None = None
metadata: Dict[str, Any]
metadata: dict[str, Any]
JobHandler = Callable[
@@ -46,7 +47,7 @@ JobHandler = Callable[
]
LogMessage: TypeAlias = Tuple[datetime, str]
LogMessage: TypeAlias = tuple[datetime, str]
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
@@ -60,7 +61,7 @@ class Job:
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
@property
def handler(self) -> JobHandler:


@@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import statistics
from typing import Any, Dict, List
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import AggregationFunctionType
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
def aggregate_accuracy(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
num_correct = sum(result["score"] for result in scoring_results)
avg_score = num_correct / len(scoring_results)
@@ -21,14 +21,14 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any
}
def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
def aggregate_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
return {
"average": sum(result["score"] for result in scoring_results if result["score"] is not None)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}
def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
def aggregate_weighted_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
return {
"weighted_average": sum(
result["score"] * result["weight"]
@@ -40,14 +40,14 @@ def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[
def aggregate_categorical_count(
scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:
scoring_results: list[ScoringResultRow],
) -> dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(set(scores))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
def aggregate_median(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
scores = [r["score"] for r in scoring_results if r["score"] is not None]
median = statistics.median(scores) if scores else None
return {"median": median}
@@ -64,8 +64,8 @@ AGGREGATION_FUNCTIONS = {
def aggregate_metrics(
scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
) -> Dict[str, Any]:
scoring_results: list[ScoringResultRow], metrics: list[AggregationFunctionType]
) -> dict[str, Any]:
agg_results = {}
for metric in metrics:
if metric not in AGGREGATION_FUNCTIONS:


@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
@@ -28,28 +28,28 @@ class BaseScoringFn(ABC):
@abstractmethod
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
scoring_results: list[ScoringResultRow],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> dict[str, Any]:
raise NotImplementedError()
@abstractmethod
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> List[ScoringResultRow]:
input_rows: list[dict[str, Any]],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> list[ScoringResultRow]:
raise NotImplementedError()
@@ -65,7 +65,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
def __str__(self) -> str:
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
def get_supported_scoring_fn_defs(self) -> list[ScoringFn]:
return list(self.supported_fn_defs_registry.values())
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
@@ -81,18 +81,18 @@ class RegisteredBaseScoringFn(BaseScoringFn):
@abstractmethod
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
raise NotImplementedError()
async def aggregate(
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
scoring_results: list[ScoringResultRow],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> dict[str, Any]:
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
if params is None:
@@ -107,8 +107,8 @@ class RegisteredBaseScoringFn(BaseScoringFn):
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> List[ScoringResultRow]:
input_rows: list[dict[str, Any]],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> list[ScoringResultRow]:
return [await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows]


@@ -5,8 +5,8 @@
# the root directory of this source tree.
import contextlib
import signal
from collections.abc import Iterator
from types import FrameType
from typing import Iterator, Optional
class TimeoutError(Exception):
@@ -15,7 +15,7 @@ class TimeoutError(Exception):
@contextlib.contextmanager
def time_limit(seconds: float) -> Iterator[None]:
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
def signal_handler(signum: int, frame: FrameType | None) -> None:
raise TimeoutError("Timed out!")
signal.setitimer(signal.ITIMER_REAL, seconds)


@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
@@ -17,10 +16,10 @@ class TelemetryDatasetMixin:
async def save_spans_to_dataset(
self,
attribute_filters: List[QueryCondition],
attributes_to_save: List[str],
attribute_filters: list[QueryCondition],
attributes_to_save: list[str],
dataset_id: str,
max_depth: Optional[int] = None,
max_depth: int | None = None,
) -> None:
if self.datasetio_api is None:
raise RuntimeError("DatasetIO API not available")
@@ -48,9 +47,9 @@ class TelemetryDatasetMixin:
async def query_spans(
self,
attribute_filters: List[QueryCondition],
attributes_to_return: List[str],
max_depth: Optional[int] = None,
attribute_filters: list[QueryCondition],
attributes_to_return: list[str],
max_depth: int | None = None,
) -> QuerySpansResponse:
traces = await self.query_traces(attribute_filters=attribute_filters)
spans = []


@@ -6,7 +6,7 @@
import json
from datetime import datetime
from typing import Dict, List, Optional, Protocol
from typing import Protocol
import aiosqlite
@@ -16,18 +16,18 @@ from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithStatus, Tra
class TraceStore(Protocol):
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
attribute_filters: list[QueryCondition] | None = None,
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> list[Trace]: ...
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> Dict[str, SpanWithStatus]: ...
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> dict[str, SpanWithStatus]: ...
class SQLiteTraceStore(TraceStore):
@@ -36,11 +36,11 @@ class SQLiteTraceStore(TraceStore):
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
attribute_filters: list[QueryCondition] | None = None,
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> list[Trace]:
def build_where_clause() -> tuple[str, list]:
if not attribute_filters:
return "", []
@@ -112,9 +112,9 @@ class SQLiteTraceStore(TraceStore):
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> Dict[str, SpanWithStatus]:
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> dict[str, SpanWithStatus]:
# Build the attributes selection
attributes_select = "s.attributes"
if attributes_to_return:


@@ -7,8 +7,9 @@
import asyncio
import inspect
import json
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from typing import Any, TypeVar
from pydantic import BaseModel
@@ -25,13 +26,13 @@ def _prepare_for_json(value: Any) -> str:
"""Serialize a single value into JSON-compatible format."""
if value is None:
return ""
elif isinstance(value, (str, int, float, bool)):
elif isinstance(value, str | int | float | bool):
return value
elif hasattr(value, "_name_"):
return value._name_
elif isinstance(value, BaseModel):
return json.loads(value.model_dump_json())
elif isinstance(value, (list, tuple, set)):
elif isinstance(value, list | tuple | set):
return [_prepare_for_json(item) for item in value]
elif isinstance(value, dict):
return {str(k): _prepare_for_json(v) for k, v in value.items()}
@@ -43,7 +44,7 @@ def _prepare_for_json(value: Any) -> str:
return str(value)
def trace_protocol(cls: Type[T]) -> Type[T]:
def trace_protocol(cls: type[T]) -> type[T]:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.


@@ -10,9 +10,10 @@ import logging
import queue
import random
import threading
from collections.abc import Callable
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
from typing import Any
from llama_stack.apis.telemetry import (
LogSeverity,
@@ -106,13 +107,13 @@ class BackgroundLogger:
class TraceContext:
spans: List[Span] = []
spans: list[Span] = []
def __init__(self, logger: BackgroundLogger, trace_id: str):
self.logger = logger
self.trace_id = trace_id
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
def push_span(self, name: str, attributes: dict[str, Any] = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_span_id(),
@@ -168,7 +169,7 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
@@ -246,7 +247,7 @@ class TelemetryHandler(logging.Handler):
class SpanContextManager:
def __init__(self, name: str, attributes: Dict[str, Any] = None):
def __init__(self, name: str, attributes: dict[str, Any] = None):
self.name = name
self.attributes = attributes
self.span = None
@@ -316,11 +317,11 @@ class SpanContextManager:
return wrapper
def span(name: str, attributes: Dict[str, Any] = None):
def span(name: str, attributes: dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Optional[Span]:
def get_current_span() -> Span | None:
global CURRENT_TRACE_CONTEXT
if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")