chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.store import DistributionRegistry
@ -23,7 +23,7 @@ from .routing_tables import (
async def get_routing_table_impl(
api: Api,
impls_by_provider_id: Dict[str, RoutedProtocol],
impls_by_provider_id: dict[str, RoutedProtocol],
_deps,
dist_registry: DistributionRegistry,
) -> Any:
@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,

View file

@ -6,12 +6,12 @@
import asyncio
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import (
URL,
@ -100,9 +100,9 @@ class VectorIORouter(VectorIO):
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: Optional[int] = 384,
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
@ -116,8 +116,8 @@ class VectorIORouter(VectorIO):
async def insert_chunks(
self,
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
@ -128,7 +128,7 @@ class VectorIORouter(VectorIO):
self,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -140,7 +140,7 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
telemetry: Telemetry | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
@ -160,10 +160,10 @@ class InferenceRouter(Inference):
async def register_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> None:
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
@ -176,7 +176,7 @@ class InferenceRouter(Inference):
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
@ -221,7 +221,7 @@ class InferenceRouter(Inference):
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricInResponse]:
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
@ -230,9 +230,9 @@ class InferenceRouter(Inference):
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
messages: list[Message] | InterleavedContent,
tool_prompt_format: ToolPromptFormat | None = None,
) -> int | None:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
@ -242,16 +242,16 @@ class InferenceRouter(Inference):
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = None,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
@ -351,12 +351,12 @@ class InferenceRouter(Inference):
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
messages_batch: list[list[Message]],
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@ -376,10 +376,10 @@ class InferenceRouter(Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -439,10 +439,10 @@ class InferenceRouter(Inference):
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@ -453,10 +453,10 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
@ -475,24 +475,24 @@ class InferenceRouter(Inference):
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@ -531,29 +531,29 @@ class InferenceRouter(Inference):
async def openai_chat_completion(
self,
model: str,
messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
@ -602,7 +602,7 @@ class InferenceRouter(Inference):
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_chat_completion(**params)
async def health(self) -> Dict[str, HealthResponse]:
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 0.5
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
@ -645,9 +645,9 @@ class SafetyRouter(Safety):
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
@ -655,8 +655,8 @@ class SafetyRouter(Safety):
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
@ -686,8 +686,8 @@ class DatasetIORouter(DatasetIO):
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
@ -702,8 +702,8 @@ class DatasetIORouter(DatasetIO):
async def iterrows(
self,
dataset_id: str,
start_index: Optional[int] = None,
limit: Optional[int] = None,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
@ -714,7 +714,7 @@ class DatasetIORouter(DatasetIO):
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
@ -741,7 +741,7 @@ class ScoringRouter(Scoring):
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
@ -762,8 +762,8 @@ class ScoringRouter(Scoring):
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
@ -808,8 +808,8 @@ class EvalRouter(Eval):
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
@ -863,8 +863,8 @@ class ToolRuntimeRouter(ToolRuntime):
async def query(
self,
content: InterleavedContent,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
@ -873,7 +873,7 @@ class ToolRuntimeRouter(ToolRuntime):
async def insert(
self,
documents: List[RAGDocument],
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
@ -904,7 +904,7 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
@ -912,7 +912,7 @@ class ToolRuntimeRouter(ToolRuntime):
)
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -7,7 +7,7 @@
import logging
import time
import uuid
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import TypeAdapter
@ -106,20 +106,20 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
raise ValueError(f"Unregister not supported for {api}")
Registry = Dict[str, List[RoutableObjectWithProvider]]
Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: Dict[str, RoutedProtocol],
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
@ -154,7 +154,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
@ -192,7 +192,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
@ -236,7 +236,7 @@ class CommonRoutingTableImpl(RoutingTable):
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
@ -277,10 +277,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def register_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
@ -328,9 +328,9 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
@ -368,9 +368,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: Optional[int] = 384,
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
@ -423,8 +423,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
@ -489,9 +489,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
@ -528,10 +528,10 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: List[str],
metadata: Optional[Dict[str, Any]] = None,
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
@ -556,7 +556,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
@ -578,8 +578,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)