chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is to modernize the code base.

Schema reflection code needed a minor adjustment to handle PEP 604 union types
(`types.UnionType`) and collections.abc.AsyncIterator. (Both are preferred in recent
Python releases.)
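
The reflection change itself is not shown in this excerpt; the following is only a minimal sketch of the kind of handling involved, with `describe_type` being a hypothetical helper rather than anything from `llama_stack/strong_typing/schema.py`:

```python
# Illustrative sketch only -- not the actual schema reflection code.
# It shows why reflection needs extra cases once annotations use `X | Y`
# (types.UnionType) and collections.abc.AsyncIterator instead of the typing aliases.
import collections.abc
import types
import typing


def describe_type(t) -> str:
    origin = typing.get_origin(t)
    # `int | None` is a types.UnionType instance, not typing.Union, so a check
    # for typing.Union alone would miss PEP 604 unions.
    if origin is typing.Union or isinstance(t, types.UnionType):
        members = ", ".join(describe_type(arg) for arg in typing.get_args(t))
        return f"union[{members}]"
    # A parametrized collections.abc.AsyncIterator reports its origin as the abc
    # class rather than typing.AsyncIterator.
    if origin is collections.abc.AsyncIterator:
        (item,) = typing.get_args(t)
        return f"async_iterator[{describe_type(item)}]"
    return getattr(t, "__name__", str(t))


print(describe_type(int | None))                          # union[int, NoneType]
print(describe_type(collections.abc.AsyncIterator[str]))  # async_iterator[str]
```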

Note to reviewers: almost all changes here were generated automatically by
pyupgrade. Some additional unused imports were cleaned up as well. The only
changes worth noting are under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where reflection code was updated
to handle the "newer" types.
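
For orientation, the typical rewrites pyupgrade applies (and which make up the bulk of the hunks below) look like the standalone illustration here; it is not code from this repository, and the exact pyupgrade invocation (e.g. `--py310-plus`) is not shown in this PR:

```python
# Before: typing aliases.
from typing import AsyncGenerator, Dict, List, Optional, Union

def get_provider_impl_old(config, deps: Dict[str, object]): ...
def turn_to_messages_old(turn) -> List[str]: ...
def parse_name_old(name: str) -> Optional[str]: ...
def run_old(request: Union[str, int]) -> AsyncGenerator: ...

# After: builtin generics, PEP 604 unions, and collections.abc imports.
from collections.abc import AsyncGenerator

def get_provider_impl_new(config, deps: dict[str, object]): ...
def turn_to_messages_new(turn) -> list[str]: ...
def parse_name_new(name: str) -> str | None: ...
def run_new(request: str | int) -> AsyncGenerator: ...
```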

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka authored on 2025-05-01 17:23:50 -04:00, committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(

View file

@ -10,8 +10,8 @@ import re
import secrets
import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional, Union
import httpx
@ -112,7 +112,7 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields,
)
def turn_to_messages(self, turn: Turn) -> List[Message]:
def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
@ -201,8 +201,8 @@ class ChatAgent(ShieldRunnerMixin):
async def _run_turn(
self,
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
turn_id: Optional[str] = None,
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
turn_id: str | None = None,
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
@ -321,10 +321,10 @@ class ChatAgent(ShieldRunnerMixin):
self,
session_id: str,
turn_id: str,
input_messages: List[Message],
input_messages: list[Message],
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
documents: list[Document] | None = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -374,8 +374,8 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: List[Message],
shields: List[str],
messages: list[Message],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:
async with tracing.span("run_shields") as span:
@ -443,10 +443,10 @@ class ChatAgent(ShieldRunnerMixin):
self,
session_id: str,
turn_id: str,
input_messages: List[Message],
input_messages: list[Message],
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
documents: list[Document] | None = None,
) -> AsyncGenerator:
# if document is passed in a turn, we parse the raw text of the document
# and send it as a user message
@ -760,7 +760,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _initialize_tools(
self,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
toolgroups_for_turn: list[AgentToolGroup] | None = None,
) -> None:
toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
@ -847,7 +847,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_name_to_args,
)
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
"""Parse a toolgroup name into its components.
Args:
@ -921,7 +921,7 @@ async def get_raw_document_text(document: Document) -> str:
def _interpret_content_as_attachment(
content: str,
) -> Optional[Attachment]:
) -> Attachment | None:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
snippet = match.group(1)

View file

@ -8,7 +8,7 @@ import json
import logging
import shutil
import uuid
from typing import AsyncGenerator, List, Optional, Union
from collections.abc import AsyncGenerator
from llama_stack.apis.agents import (
Agent,
@ -142,16 +142,11 @@ class MetaReferenceAgentsImpl(Agents):
self,
agent_id: str,
session_id: str,
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
],
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
messages: list[UserMessage | ToolResponseMessage],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -180,8 +175,8 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
agent_id=agent_id,
@ -219,7 +214,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
agent_id: str,
session_id: str,
turn_ids: Optional[List[str]] = None,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
@ -265,13 +260,13 @@ class MetaReferenceAgentsImpl(Agents):
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
input: str | list[OpenAIResponseInputMessage],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
temperature: Optional[float] = None,
tools: Optional[List[OpenAIResponseInputTool]] = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input, model, previous_response_id, store, stream, temperature, tools

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -16,7 +16,7 @@ class MetaReferenceAgentsImplConfig(BaseModel):
persistence_store: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"persistence_store": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,

View file

@ -6,7 +6,8 @@
import json
import uuid
from typing import AsyncIterator, List, Optional, Union, cast
from collections.abc import AsyncIterator
from typing import cast
from openai.types.chat import ChatCompletionToolParam
@ -49,15 +50,15 @@ logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
messages: List[OpenAIMessageParam] = []
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
messages: list[OpenAIMessageParam] = []
for output_message in previous_response.output:
if isinstance(output_message, OpenAIResponseOutputMessage):
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
return messages
async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
output_messages = []
for choice in choices:
output_content = ""
@ -101,22 +102,22 @@ class OpenAIResponsesImpl:
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
input: str | list[OpenAIResponseInputMessage],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
temperature: Optional[float] = None,
tools: Optional[List[OpenAIResponseInputTool]] = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
stream = False if stream is None else stream
messages: List[OpenAIMessageParam] = []
messages: list[OpenAIMessageParam] = []
if previous_response_id:
previous_response = await self.get_openai_response(previous_response_id)
messages.extend(await _previous_response_to_messages(previous_response))
# TODO: refactor this user_content parsing out into a separate method
user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
if isinstance(input, list):
user_content = []
for user_input in input:
@ -179,7 +180,7 @@ class OpenAIResponsesImpl:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: List[OpenAIResponseOutput] = []
output_messages: list[OpenAIResponseOutput] = []
if chat_response.choices[0].message.tool_calls:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
@ -215,9 +216,9 @@ class OpenAIResponsesImpl:
return response
async def _convert_response_tools_to_chat_tools(
self, tools: List[OpenAIResponseInputTool]
) -> List[ChatCompletionToolParam]:
chat_tools: List[ChatCompletionToolParam] = []
self, tools: list[OpenAIResponseInputTool]
) -> list[ChatCompletionToolParam]:
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "web_search":
@ -247,10 +248,10 @@ class OpenAIResponsesImpl:
model_id: str,
stream: bool,
chat_response: OpenAIChatCompletion,
messages: List[OpenAIMessageParam],
messages: list[OpenAIMessageParam],
temperature: float,
) -> List[OpenAIResponseOutput]:
output_messages: List[OpenAIResponseOutput] = []
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
choice = chat_response.choices[0]
# If the choice is not an assistant message, we don't need to execute any tools
@ -314,7 +315,7 @@ class OpenAIResponsesImpl:
async def _execute_tool_call(
self,
function: OpenAIChatCompletionToolCallFunction,
) -> Optional[ToolInvocationResult]:
) -> ToolInvocationResult | None:
if not function.name:
return None
function_args = json.loads(function.arguments) if function.arguments else {}

View file

@ -8,7 +8,6 @@ import json
import logging
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel
@ -25,9 +24,9 @@ class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
# TODO: is this used anywhere?
vector_db_id: Optional[str] = None
vector_db_id: str | None = None
started_at: datetime
access_attributes: Optional[AccessAttributes] = None
access_attributes: AccessAttributes | None = None
class AgentPersistence:
@ -55,7 +54,7 @@ class AgentPersistence:
)
return session_id
async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
@ -78,7 +77,7 @@ class AgentPersistence:
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
@ -106,7 +105,7 @@ class AgentPersistence:
value=turn.model_dump_json(),
)
async def get_session_turns(self, session_id: str) -> List[Turn]:
async def get_session_turns(self, session_id: str) -> list[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
@ -125,7 +124,7 @@ class AgentPersistence:
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
@ -145,7 +144,7 @@ class AgentPersistence:
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
if not await self.get_session_if_accessible(session_id):
return None
@ -163,7 +162,7 @@ class AgentPersistence:
value=str(num_infer_iters),
)
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
if not await self.get_session_if_accessible(session_id):
return None

View file

@ -6,7 +6,6 @@
import asyncio
import logging
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
@ -25,14 +24,14 @@ class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
input_shields: List[str] = None,
output_shields: List[str] = None,
input_shields: list[str] = None,
output_shields: list[str] = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
self.output_shields = output_shields
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import LocalFSDatasetIOConfig
async def get_provider_impl(
config: LocalFSDatasetIOConfig,
_deps: Dict[str, Any],
_deps: dict[str, Any],
):
from .datasetio import LocalFSDatasetIOImpl

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -17,7 +17,7 @@ class LocalFSDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any
import pandas
@ -92,8 +92,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def iterrows(
self,
dataset_id: str,
start_index: Optional[int] = None,
limit: Optional[int] = None,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
@ -102,7 +102,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
records = dataset_impl.df.to_dict("records")
return paginate_records(records, start_index, limit)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
await dataset_impl.load()

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import Api
@ -12,7 +12,7 @@ from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
deps: Dict[Api, Any],
deps: dict[Api, Any],
):
from .eval import MetaReferenceEvalImpl

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -17,7 +17,7 @@ class MetaReferenceEvalConfig(BaseModel):
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List
from typing import Any
from tqdm import tqdm
@ -105,8 +105,8 @@ class MetaReferenceEvalImpl(
return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
create_response = await self.agents_api.create_agent(candidate.config)
agent_id = create_response.agent_id
@ -148,8 +148,8 @@ class MetaReferenceEvalImpl(
return generations
async def _run_model_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
@ -185,8 +185,8 @@ class MetaReferenceEvalImpl(
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
candidate = benchmark_config.eval_candidate

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import MetaReferenceInferenceConfig
async def get_provider_impl(
config: MetaReferenceInferenceConfig,
_deps: Dict[str, Any],
_deps: dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, field_validator
@ -17,11 +17,11 @@ class MetaReferenceInferenceConfig(BaseModel):
# the actual inference model id is determined by the model id in the request
# Note: you need to register the model before using it for inference
# models in the resource list in the run.yaml config will be registered automatically
model: Optional[str] = None
torch_seed: Optional[int] = None
model: str | None = None
torch_seed: int | None = None
max_seq_len: int = 4096
max_batch_size: int = 1
model_parallel_size: Optional[int] = None
model_parallel_size: int | None = None
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@ -30,9 +30,9 @@ class MetaReferenceInferenceConfig(BaseModel):
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None
checkpoint_dir: str | None = None
quantization: Optional[QuantizationConfig] = None
quantization: QuantizationConfig | None = None
@field_validator("model")
@classmethod
@ -55,7 +55,7 @@ class MetaReferenceInferenceConfig(BaseModel):
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
**kwargs,
) -> Dict[str, Any]:
) -> dict[str, Any]:
return {
"model": model,
"checkpoint_dir": checkpoint_dir,

View file

@ -5,7 +5,8 @@
# the root directory of this source tree.
import math
from typing import Generator, List, Optional, Tuple
from collections.abc import Generator
from typing import Optional
import torch
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
@ -39,7 +40,7 @@ Tokenizer = Llama4Tokenizer | Llama3Tokenizer
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
self.mask: torch.Tensor | None = None
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
@ -58,7 +59,7 @@ class LogitsProcessor:
def get_logits_processor(
tokenizer: Tokenizer,
vocab_size: int,
response_format: Optional[ResponseFormat],
response_format: ResponseFormat | None,
) -> Optional["LogitsProcessor"]:
if response_format is None:
return None
@ -76,7 +77,7 @@ def get_logits_processor(
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []
@ -158,7 +159,7 @@ class LlamaGenerator:
def completion(
self,
request_batch: List[CompletionRequestWithRawContent],
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
@ -167,7 +168,7 @@ class LlamaGenerator:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
yield from self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
@ -179,12 +180,11 @@ class LlamaGenerator:
self.args.vocab_size,
first_request.response_format,
),
):
yield result
)
def chat_completion(
self,
request_batch: List[ChatCompletionRequestWithRawContent],
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
@ -193,7 +193,7 @@ class LlamaGenerator:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
yield from self.inner_generator.generate(
llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
@ -208,5 +208,4 @@ class LlamaGenerator:
self.args.vocab_size,
first_request.response_format,
),
):
yield result
)
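
Aside (not part of the diff itself): the `for ... yield` to `yield from` rewrite in the two hunks above is behavior-preserving for plain iteration, as this standalone sketch shows:

```python
def numbers_old():
    # Pre-pyupgrade pattern: explicit loop that re-yields every item.
    for n in range(3):
        yield n


def numbers_new():
    # Post-pyupgrade pattern: delegate directly to the inner iterable.
    yield from range(3)


assert list(numbers_old()) == list(numbers_new()) == [0, 1, 2]
```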

View file

@ -6,7 +6,7 @@
import asyncio
import os
from typing import AsyncGenerator, List, Optional, Union
from collections.abc import AsyncGenerator
from pydantic import BaseModel
from termcolor import cprint
@ -184,11 +184,11 @@ class MetaReferenceInferenceImpl(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
@ -215,11 +215,11 @@ class MetaReferenceInferenceImpl(
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
@ -291,14 +291,14 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
@ -349,15 +349,15 @@ class MetaReferenceInferenceImpl(
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -395,13 +395,13 @@ class MetaReferenceInferenceImpl(
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
@ -436,15 +436,15 @@ class MetaReferenceInferenceImpl(
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: List[ChatCompletionRequest]
) -> List[ChatCompletionResponse]:
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False

View file

@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable, Generator
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Generator, List
from typing import Any
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -82,7 +83,7 @@ class LlamaModelParallelGenerator:
def completion(
self,
request_batch: List[CompletionRequestWithRawContent],
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
@ -90,7 +91,7 @@ class LlamaModelParallelGenerator:
def chat_completion(
self,
request_batch: List[ChatCompletionRequestWithRawContent],
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))

View file

@ -18,8 +18,9 @@ import os
import tempfile
import time
import uuid
from collections.abc import Callable, Generator
from enum import Enum
from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
from typing import Annotated, Literal
import torch
import zmq
@ -30,7 +31,6 @@ from fairscale.nn.model_parallel.initialize import (
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -69,15 +69,15 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Tuple[
task: tuple[
str,
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: List[GenerationResult]
result: list[GenerationResult]
class ExceptionResponse(BaseModel):
@ -85,15 +85,9 @@ class ExceptionResponse(BaseModel):
error: str
ProcessingMessage = Union[
ReadyRequest,
ReadyResponse,
EndSentinel,
CancelSentinel,
TaskRequest,
TaskResponse,
ExceptionResponse,
]
ProcessingMessage = (
ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
)
class ProcessingMessageWrapper(BaseModel):
@ -203,7 +197,7 @@ def maybe_get_work(sock: zmq.Socket):
return client_id, message
def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
if maybe_json is None:
return None
try:
@ -334,9 +328,9 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Tuple[
req: tuple[
str,
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
],
) -> Generator:
assert not self.running, "inference already running"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
@ -13,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps: Dict[str, Any],
_deps: dict[str, Any],
):
from .sentence_transformers import SentenceTransformersInferenceImpl

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {}

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, List, Optional, Union
from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
@ -60,46 +60,46 @@ class SentenceTransformersInferenceImpl(
self,
model_id: str,
content: str,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncGenerator]:
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator:
raise ValueError("Sentence transformers don't support completion")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel, Field
@ -42,7 +42,7 @@ class VLLMConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
import vllm
@ -55,8 +54,8 @@ def _merge_context_into_content(message: Message) -> Message: # type: ignore
def _llama_stack_tools_to_openai_tools(
tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
tools: list[ToolDefinition] | None = None,
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.

View file

@ -7,7 +7,7 @@
import json
import re
import uuid
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
@ -100,7 +100,7 @@ def _random_uuid_str() -> str:
def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat], # type: ignore
response_format: ResponseFormat | None, # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
"""
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
@ -131,9 +131,9 @@ def _response_format_to_guided_decoding_params(
def _convert_sampling_params(
sampling_params: Optional[SamplingParams],
response_format: Optional[ResponseFormat], # type: ignore
log_prob_config: Optional[LogProbConfig],
sampling_params: SamplingParams | None,
response_format: ResponseFormat | None, # type: ignore
log_prob_config: LogProbConfig | None,
) -> vllm.SamplingParams:
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
format."""
@ -370,11 +370,11 @@ class VLLMInferenceImpl(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
@ -403,25 +403,25 @@ class VLLMInferenceImpl(
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message], # type: ignore
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, # type: ignore
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message], # type: ignore
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None, # type: ignore
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
sampling_params = sampling_params or SamplingParams()
if model_id not in self.model_ids:
@ -605,7 +605,7 @@ class VLLMInferenceImpl(
async def _chat_completion_for_meta_llama(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
chat template instead of using vLLM's version of that template. The Llama Stack version
@ -701,7 +701,7 @@ class VLLMInferenceImpl(
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
# those chunks and output them at the end.
# This data structure holds the current set of partial tool calls.
index_to_tool_call: Dict[int, Dict] = dict()
index_to_tool_call: dict[int, dict] = dict()
# The Llama Stack event stream must always start with a start event. Use an empty one to
# simplify logic below

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import Api
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: Dict[Api, Any],
deps: dict[Api, Any],
):
from .post_training import TorchtunePostTrainingImpl

View file

@ -8,7 +8,7 @@ import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
from typing import Any
import torch
from safetensors.torch import save_file
@ -34,7 +34,7 @@ class TorchtuneCheckpointer:
model_id: str,
training_algorithm: str,
checkpoint_dir: str,
checkpoint_files: List[str],
checkpoint_files: list[str],
output_dir: str,
model_type: str,
):
@ -54,11 +54,11 @@ class TorchtuneCheckpointer:
# get ckpt paths
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
def load_checkpoint(self) -> Dict[str, Any]:
def load_checkpoint(self) -> dict[str, Any]:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str, Any] = {}
state_dict: dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
@ -82,7 +82,7 @@ class TorchtuneCheckpointer:
def save_checkpoint(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str | None = None,
@ -100,7 +100,7 @@ class TorchtuneCheckpointer:
def _save_meta_format_checkpoint(
self,
model_file_path: Path,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
adapter_only: bool = False,
) -> None:
model_file_path.mkdir(parents=True, exist_ok=True)
@ -168,7 +168,7 @@ class TorchtuneCheckpointer:
def _save_hf_format_checkpoint(
self,
model_file_path: Path,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
) -> None:
# the config.json file contains model params needed for state dict conversion
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
@ -179,7 +179,7 @@ class TorchtuneCheckpointer:
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
self.repo_id = None
if repo_id_path.exists():
with open(repo_id_path, "r") as json_file:
with open(repo_id_path) as json_file:
data = json.load(json_file)
self.repo_id = data.get("repo_id")

View file

@ -10,7 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Callable, Dict
from collections.abc import Callable
import torch
from pydantic import BaseModel
@ -35,7 +35,7 @@ class ModelConfig(BaseModel):
checkpoint_type: str
MODEL_CONFIGS: Dict[str, ModelConfig] = {
MODEL_CONFIGS: dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b,
tokenizer_type=llama3_tokenizer,
@ -48,7 +48,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
),
}
DATA_FORMATS: Dict[str, Transform] = {
DATA_FORMATS: dict[str, Transform] = {
"instruct": InputOutputToMessages,
"dialog": ShareGPTToMessages,
}

View file

@ -4,17 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel
class TorchtunePostTrainingConfig(BaseModel):
torch_seed: Optional[int] = None
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
torch_seed: int | None = None
checkpoint_format: Literal["meta", "huggingface"] | None = "meta"
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"checkpoint_format": "meta",
}

View file

@ -11,7 +11,8 @@
# LICENSE file in the root directory of this source tree.
import json
from typing import Any, Mapping
from collections.abc import Mapping
from typing import Any
from llama_stack.providers.utils.common.data_schema_validator import ColumnName

View file

@ -10,7 +10,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Mapping
from collections.abc import Mapping
from typing import Any
import numpy as np
from torch.utils.data import Dataset
@ -27,7 +28,7 @@ from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapte
class SFTDataset(Dataset):
def __init__(
self,
rows: List[Dict[str, Any]],
rows: list[dict[str, Any]],
message_transform: Transform,
model_transform: Transform,
dataset_type: str,
@ -40,11 +41,11 @@ class SFTDataset(Dataset):
def __len__(self):
return len(self._rows)
def __getitem__(self, index: int) -> Dict[str, Any]:
def __getitem__(self, index: int) -> dict[str, Any]:
sample = self._rows[index]
return self._prepare_sample(sample)
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
def _prepare_sample(self, sample: Mapping[str, Any]) -> dict[str, Any]:
if self._dataset_type == "instruct":
sample = llama_stack_instruct_to_torchtune_instruct(sample)
elif self._dataset_type == "dialog":

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
@ -75,11 +75,11 @@ class TorchtunePostTrainingImpl:
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
checkpoint_dir: str | None,
algorithm_config: AlgorithmConfig | None,
) -> PostTrainingJob:
if isinstance(algorithm_config, LoraFinetuningConfig):
@ -121,8 +121,8 @@ class TorchtunePostTrainingImpl:
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
@ -144,7 +144,7 @@ class TorchtunePostTrainingImpl:
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
match job.status:
@ -175,6 +175,6 @@ class TorchtunePostTrainingImpl:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -11,7 +11,7 @@ import time
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any
import torch
from torch import nn
@ -80,10 +80,10 @@ class LoraFinetuningSingleDevice:
config: TorchtunePostTrainingConfig,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
checkpoint_dir: str | None,
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
datasetio_api: DatasetIO,
datasets_api: Datasets,
@ -156,7 +156,7 @@ class LoraFinetuningSingleDevice:
self.datasets_api = datasets_api
async def load_checkpoint(self):
def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
try:
# List all files in the given directory
files = os.listdir(checkpoint_dir)
@ -250,8 +250,8 @@ class LoraFinetuningSingleDevice:
self,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
base_model_state_dict: dict[str, Any],
lora_weights_state_dict: dict[str, Any] | None = None,
) -> nn.Module:
self._lora_rank = self.algorithm_config.rank
self._lora_alpha = self.algorithm_config.alpha
@ -335,7 +335,7 @@ class LoraFinetuningSingleDevice:
tokenizer: Llama3Tokenizer,
shuffle: bool,
batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]:
) -> tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str):
return await self.datasetio_api.iterrows(
dataset_id=dataset_id,
@ -430,7 +430,7 @@ class LoraFinetuningSingleDevice:
checkpoint_format=self._checkpoint_format,
)
async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
# run model
@ -452,7 +452,7 @@ class LoraFinetuningSingleDevice:
return loss
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
"""
The core training loop.
"""
@ -464,7 +464,7 @@ class LoraFinetuningSingleDevice:
# training artifacts
checkpoints = []
memory_stats: Dict[str, Any] = {}
memory_stats: dict[str, Any] = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
@ -565,7 +565,7 @@ class LoraFinetuningSingleDevice:
return (memory_stats, checkpoints)
async def validation(self) -> Tuple[float, float]:
async def validation(self) -> tuple[float, float]:
total_loss = 0.0
total_tokens = 0
log.info("Starting validation...")

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import CodeScannerConfig
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import logging
from typing import Any, Dict, List
from typing import Any
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
@ -48,8 +48,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
class CodeScannerConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,16 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any
from pydantic import BaseModel
class LlamaGuardConfig(BaseModel):
excluded_categories: List[str] = []
excluded_categories: list[str] = []
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"excluded_categories": [],
}

View file

@ -6,7 +6,7 @@
import re
from string import Template
from typing import Any, Dict, List, Optional
from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import (
@ -149,8 +149,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
@ -177,7 +177,7 @@ class LlamaGuardShield:
self,
model: str,
inference_api: Inference,
excluded_categories: Optional[List[str]] = None,
excluded_categories: list[str] | None = None,
):
if excluded_categories is None:
excluded_categories = []
@ -193,7 +193,7 @@ class LlamaGuardShield:
self.inference_api = inference_api
self.excluded_categories = excluded_categories
def check_unsafe_response(self, response: str) -> Optional[str]:
def check_unsafe_response(self, response: str) -> str | None:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
@ -202,7 +202,7 @@ class LlamaGuardShield:
return None
def get_safety_categories(self) -> List[str]:
def get_safety_categories(self) -> list[str]:
excluded_categories = self.excluded_categories
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
excluded_categories = []
@ -218,7 +218,7 @@ class LlamaGuardShield:
return final_categories
def validate_messages(self, messages: List[Message]) -> None:
def validate_messages(self, messages: list[Message]) -> None:
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value:
@ -229,7 +229,7 @@ class LlamaGuardShield:
return messages
async def run(self, messages: List[Message]) -> RunShieldResponse:
async def run(self, messages: list[Message]) -> RunShieldResponse:
messages = self.validate_messages(messages)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
@ -247,10 +247,10 @@ class LlamaGuardShield:
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
conversation = []
most_recent_img = None
@ -284,7 +284,7 @@ class LlamaGuardShield:
return UserMessage(content=prompt)
def build_prompt(self, messages: List[Message]) -> str:
def build_prompt(self, messages: list[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
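
Throughout this file Optional[X] becomes X | None and Union[A, B] becomes A | B (PEP 604, Python 3.10+); the two spellings are equivalent for type checkers. A reduced sketch of the pattern:

import re

def check_unsafe_sketch(response: str) -> str | None:          # was Optional[str]
    match = re.match(r"^unsafe\n(.*)$", response)
    return match.group(1) if match else None

def categories_sketch(excluded: list[str] | None = None) -> list[str]:   # was Optional[List[str]]
    return list(excluded) if excluded else []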

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import PromptGuardConfig # noqa: F401
from .config import PromptGuardConfig
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel, field_validator
@ -26,7 +26,7 @@ class PromptGuardConfig(BaseModel):
return v
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"guard_type": "injection",
}

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import logging
from typing import Any, Dict, List
from typing import Any
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@ -49,8 +49,8 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
@ -81,7 +81,7 @@ class PromptGuardShield:
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
async def run(self, messages: List[Message]) -> RunShieldResponse:
async def run(self, messages: list[Message]) -> RunShieldResponse:
message = messages[-1]
text = interleaved_content_as_str(message.content)

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import Api
@ -12,7 +12,7 @@ from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: Dict[Api, Any],
deps: dict[Api, Any],
):
from .scoring import BasicScoringImpl

View file

@ -3,12 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
class BasicScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -66,7 +66,7 @@ class BasicScoringImpl(
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
async def list_scoring_functions(self) -> list[ScoringFn]:
scoring_fn_defs_list = [
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
]
@ -82,7 +82,7 @@ class BasicScoringImpl(
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
@ -107,8 +107,8 @@ class BasicScoringImpl(
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():

View file

@ -6,7 +6,7 @@
import json
import re
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
@ -17,7 +17,7 @@ from ..utils.bfcl.checker import ast_checker, is_empty_output
from .fn_defs.bfcl import bfcl
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
contain_func_call = False
error = None
error_type = None
@ -52,11 +52,11 @@ def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
}
def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
def gen_valid(x: dict[str, Any]) -> dict[str, float]:
return {"valid": x["valid"]}
def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
# If `test_category` is "irrelevance", the model is expected to output no function call.
# No function call means either the AST decoding fails (an error message is generated) or the decoded AST does not contain any function call (such as an empty list, `[]`).
@ -78,9 +78,9 @@ class BFCLScoringFn(RegisteredBaseScoringFn):
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "bfcl",
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "bfcl",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
score_result = postprocess(input_row, test_category)

View file

@ -6,7 +6,7 @@
import json
import re
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
@ -228,9 +228,9 @@ class DocVQAScoringFn(RegisteredBaseScoringFn):
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "docvqa",
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "docvqa",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
expected_answers = json.loads(input_row["expected_answer"])
generated_answer = input_row["generated_answer"]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
@ -26,9 +26,9 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "equality",
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "equality",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert "generated_answer" in input_row, "Generated answer not found in input row."

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
@ -28,9 +28,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
@ -28,9 +28,9 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
@ -28,9 +28,9 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
@ -26,9 +26,9 @@ class SubsetOfScoringFn(RegisteredBaseScoringFn):
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "subset_of",
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = "subset_of",
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]

View file

@ -11,8 +11,8 @@ import logging
import random
import re
import string
from collections.abc import Iterable, Sequence
from types import MappingProxyType
from typing import Dict, Iterable, List, Optional, Sequence, Union
import emoji
import langdetect
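
As with AsyncIterator in the schema code, abstract container types such as Iterable and Sequence are now imported from collections.abc; their typing counterparts have been deprecated aliases since Python 3.9. For example:

from collections.abc import Iterable, Sequence

def first_or_none(items: Sequence[str]) -> str | None:
    return items[0] if items else None

def line_lengths(lines: Iterable[str]) -> list[int]:
    return [len(line) for line in lines]
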
@ -1673,12 +1673,11 @@ def split_chinese_japanese_hindi(lines: str) -> Iterable[str]:
The separator for hindi is '।'
"""
for line in lines.splitlines():
for sent in re.findall(
yield from re.findall(
r"[^!?。\.\!\?\\\\n।]+[!?。\.\!\?\\\\n।]?",
line.strip(),
flags=re.U,
):
yield sent
)
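
The inner for/yield loop is also collapsed into yield from, which delegates directly to the iterator returned by re.findall. A reduced sketch of the same rewrite, with a simplified pattern instead of the CJK/Hindi one above:

import re

def sentences_old(lines: str):
    for line in lines.splitlines():
        for sent in re.findall(r"[^.!?]+[.!?]?", line.strip()):
            yield sent

def sentences_new(lines: str):
    for line in lines.splitlines():
        yield from re.findall(r"[^.!?]+[.!?]?", line.strip())
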
def count_words_cjk(text: str) -> int:
@ -1707,7 +1706,7 @@ def count_words_cjk(text: str) -> int:
return non_asian_words_cnt + asian_chars_cnt + emoji_cnt
@functools.lru_cache(maxsize=None)
@functools.cache
def _get_sentence_tokenizer():
return nltk.data.load("nltk:tokenizers/punkt/english.pickle")
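
functools.cache (added in Python 3.9) is an unbounded cache equivalent to lru_cache(maxsize=None), minus the LRU bookkeeping. Sketch with a stand-in loader:

import functools

@functools.cache            # same behaviour as @functools.lru_cache(maxsize=None)
def _expensive_load_sketch() -> str:
    # stand-in for the nltk punkt tokenizer load above
    return "loaded"
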
@ -1719,8 +1718,8 @@ def count_sentences(text):
return len(tokenized_sentences)
def get_langid(text: str, lid_path: Optional[str] = None) -> str:
line_langs: List[str] = []
def get_langid(text: str, lid_path: str | None = None) -> str:
line_langs: list[str] = []
lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4]
for line in lines:
@ -1741,7 +1740,7 @@ def generate_keywords(num_keywords):
"""Library of instructions"""
_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]]
_InstructionArgsDtype = dict[str, int | str | Sequence[str]] | None
_LANGUAGES = LANGUAGE_CODES
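
Module-level type aliases get the same treatment; because the alias is evaluated at import time, the | syntax here requires Python 3.10+. Roughly:

from collections.abc import Sequence

# was: Optional[Dict[str, Union[int, str, Sequence[str]]]]
_InstructionArgsSketch = dict[str, int | str | Sequence[str]] | None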

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import re
from typing import Sequence
from collections.abc import Sequence
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
@ -323,7 +323,7 @@ def _fix_a_slash_b(string: str) -> str:
try:
ia = int(a)
ib = int(b)
assert string == "{}/{}".format(ia, ib)
assert string == f"{ia}/{ib}"
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
return new_string
except (ValueError, AssertionError):
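
str.format calls with simple positional arguments are converted into f-strings, as in the assertion above. A tiny check of the equivalence:

ia, ib = 3, 4
assert "{}/{}".format(ia, ib) == f"{ia}/{ib}" == "3/4"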

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, Any],
deps: dict[Api, Any],
):
from .braintrust import BraintrustScoringImpl

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, List, Optional
from typing import Any
from autoevals.llm import Factuality
from autoevals.ragas import (
@ -132,7 +132,7 @@ class BraintrustScoringImpl(
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
async def list_scoring_functions(self) -> list[ScoringFn]:
scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
for f in scoring_fn_defs_list:
assert f.identifier.startswith("braintrust"), (
@ -159,7 +159,7 @@ class BraintrustScoringImpl(
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]],
scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
@ -181,9 +181,7 @@ class BraintrustScoringImpl(
results=res.results,
)
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
async def score_row(self, input_row: dict[str, Any], scoring_fn_identifier: str | None = None) -> ScoringResultRow:
validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
@ -203,8 +201,8 @@ class BraintrustScoringImpl(
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None],
) -> ScoreResponse:
await self.set_api_key()
res = {}

View file

@ -3,19 +3,19 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
class BraintrustScoringConfig(BaseModel):
openai_api_key: Optional[str] = Field(
openai_api_key: str | None = Field(
default=None,
description="The OpenAI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"openai_api_key": "${env.OPENAI_API_KEY:}",
}
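
Optional pydantic fields follow the same PEP 604 spelling; the default still has to be supplied explicitly via Field. Sketch (class name illustrative):

from pydantic import BaseModel, Field

class ScoringConfigSketch(BaseModel):            # illustrative, not the real config class
    openai_api_key: str | None = Field(          # was Optional[str]
        default=None,
        description="The OpenAI API Key",
    )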

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import Api
@ -12,7 +12,7 @@ from .config import LlmAsJudgeScoringConfig
async def get_provider_impl(
config: LlmAsJudgeScoringConfig,
deps: Dict[Api, Any],
deps: dict[Api, Any],
):
from .scoring import LlmAsJudgeScoringImpl

View file

@ -3,12 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
class LlmAsJudgeScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -50,7 +50,7 @@ class LlmAsJudgeScoringImpl(
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
async def list_scoring_functions(self) -> list[ScoringFn]:
scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
@ -66,7 +66,7 @@ class LlmAsJudgeScoringImpl(
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
@ -91,8 +91,8 @@ class LlmAsJudgeScoringImpl(
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow
@ -30,9 +30,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
input_row: dict[str, Any],
scoring_fn_identifier: str | None = None,
scoring_params: ScoringFnParams | None = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import Api
@ -13,7 +13,7 @@ from .config import TelemetryConfig, TelemetrySink
__all__ = ["TelemetryConfig", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]):
from .telemetry import TelemetryAdapter
impl = TelemetryAdapter(config, deps)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from typing import Any
from pydantic import BaseModel, Field, field_validator
@ -33,7 +33,7 @@ class TelemetryConfig(BaseModel):
default="",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
sinks: list[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
)
@ -50,7 +50,7 @@ class TelemetryConfig(BaseModel):
return v
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",

View file

@ -78,7 +78,7 @@ class ConsoleSpanProcessor(SpanProcessor):
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, (dict, list)):
if isinstance(message, dict | list):
message = json.dumps(message, indent=2)
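
On Python 3.10+ isinstance accepts a union type directly, so the tuple form and the | form above are interchangeable. For example:

import json

def render_sketch(message: object) -> str:
    if isinstance(message, dict | list):        # same check as isinstance(message, (dict, list))
        return json.dumps(message, indent=2)
    return str(message)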
severity_colors = {

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import threading
from typing import Any, Dict, List, Optional
from typing import Any
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -60,7 +60,7 @@ def is_tracing_enabled(tracer):
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
@ -231,10 +231,10 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
attribute_filters: list[QueryCondition] | None = None,
limit: int | None = 100,
offset: int | None = 0,
order_by: list[str] | None = None,
) -> QueryTracesResponse:
return QueryTracesResponse(
data=await self.trace_store.query_traces(
@ -254,8 +254,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
attributes_to_return: list[str] | None = None,
max_depth: int | None = None,
) -> QuerySpanTreeResponse:
return QuerySpanTreeResponse(
data=await self.trace_store.get_span_tree(

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: dict[str, Any]):
from .code_interpreter import CodeInterpreterToolRuntimeImpl
impl = CodeInterpreterToolRuntimeImpl(config)

View file

@ -18,7 +18,6 @@ from dataclasses import dataclass
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import List
from PIL import Image
@ -45,7 +44,7 @@ except:
"""
def generate_bwrap_command(bind_dirs: List[str]) -> str:
def generate_bwrap_command(bind_dirs: list[str]) -> str:
"""
Generate the bwrap command string for binding all
directories in the current directory read-only.
@ -71,7 +70,7 @@ class CodeExecutionContext:
@dataclass
class CodeExecutionRequest:
scripts: List[str]
scripts: list[str]
only_last_cell_stdouterr: bool = True
only_last_cell_fail: bool = True
seed: int = 0

View file

@ -9,7 +9,7 @@ import asyncio
import logging
import os
import tempfile
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
@ -46,7 +46,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
return
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
@ -64,7 +64,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
script = kwargs["code"]
# Use environment variable to control bwrap usage
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
class CodeInterpreterToolConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -15,7 +15,7 @@ def get_code_env_prefix() -> str:
global CODE_ENV_PREFIX
if CODE_ENV_PREFIX is None:
with open(CODE_ENV_PREFIX_FILE, "r") as f:
with open(CODE_ENV_PREFIX_FILE) as f:
CODE_ENV_PREFIX = f.read()
return CODE_ENV_PREFIX
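
open() already defaults to read-only text mode, so the explicit "r" argument is redundant and gets dropped. Equivalent sketch (path is a hypothetical stand-in):

def read_prefix_sketch(path: str) -> str:
    with open(path) as f:        # identical to open(path, "r")
        return f.read()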

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
class RagToolRuntimeConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -8,7 +8,7 @@ import asyncio
import logging
import secrets
import string
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import TypeAdapter
@ -74,7 +74,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def insert(
self,
documents: List[RAGDocument],
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
@ -101,8 +101,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def query(
self,
content: InterleavedContent,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
if not vector_db_ids:
return RAGQueryResult(content=None)
@ -123,7 +123,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
)
for vector_db_id in vector_db_ids
]
results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
@ -168,7 +168,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
)
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
@ -193,7 +193,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
db_path: str
@classmethod
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
return {"db_path": db_path}

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import FaissVectorIOConfig
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel):
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,

View file

@ -9,7 +9,7 @@ import base64
import io
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any
import faiss
import numpy as np
@ -84,7 +84,7 @@ class FaissIndex(EmbeddingIndex):
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
# Add dimension check
embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
if embedding_dim != self.index.d:
@ -159,7 +159,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
inference_api=self.inference_api,
)
async def list_vector_dbs(self) -> List[VectorDB]:
async def list_vector_dbs(self) -> list[VectorDB]:
return [i.vector_db for i in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
@ -176,8 +176,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def insert_chunks(
self,
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
index = self.cache.get(vector_db_id)
if index is None:
@ -189,7 +189,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
index = self.cache.get(vector_db_id)
if index is None:

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(config, deps[Api.inference])

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -16,5 +16,5 @@ class MilvusVectorIOConfig(BaseModel):
db_path: str
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"db_path": "${env.MILVUS_DB_PATH}"}

View file

@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
impl = QdrantVectorIOAdapter(config, deps[Api.inference])

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -17,7 +17,7 @@ class QdrantVectorIOConfig(BaseModel):
path: str
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
}

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.providers.datatypes import Api
from .config import SQLiteVectorIOConfig
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -13,7 +13,7 @@ class SQLiteVectorIOConfig(BaseModel):
db_path: str
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
}

View file

@ -10,7 +10,7 @@ import logging
import sqlite3
import struct
import uuid
from typing import Any, Dict, List, Optional
from typing import Any
import numpy as np
import sqlite_vec
@ -25,7 +25,7 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
logger = logging.getLogger(__name__)
def serialize_vector(vector: List[float]) -> bytes:
def serialize_vector(vector: list[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation."""
return struct.pack(f"{len(vector)}f", *vector)
@ -98,7 +98,7 @@ class SQLiteVecIndex(EmbeddingIndex):
await asyncio.to_thread(_drop_tables)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, batch_size: int = 500):
"""
Add new chunks along with their embeddings using batch inserts.
For each chunk, we insert its JSON into the metadata table and then insert its
@ -209,7 +209,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config, inference_api: Inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache: Dict[str, VectorDBWithIndex] = {}
self.cache: dict[str, VectorDBWithIndex] = {}
async def initialize(self) -> None:
def _setup_connection():
@ -264,7 +264,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> List[VectorDB]:
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
@ -286,7 +286,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
await asyncio.to_thread(_delete_vector_db_from_registry)
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
@ -294,7 +294,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
await self.cache[vector_db_id].insert_chunks(chunks)
async def query_chunks(
self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found")
@ -303,5 +303,5 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
hash_input = f"{document_id}:{chunk_text}".encode()
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
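
Likewise, str.encode() defaults to UTF-8, so dropping the explicit "utf-8" argument leaves the generated chunk IDs unchanged. A quick check of the equivalence:

import hashlib
import uuid

def generate_chunk_id_sketch(document_id: str, chunk_text: str) -> str:
    hash_input = f"{document_id}:{chunk_text}".encode()      # same bytes as .encode("utf-8")
    return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))

assert "abc".encode() == "abc".encode("utf-8")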