mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -7,7 +7,7 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
|
@ -106,20 +106,20 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||
Registry = dict[str, list[RoutableObjectWithProvider]]
|
||||
|
||||
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> None:
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.dist_registry = dist_registry
|
||||
|
||||
async def initialize(self) -> None:
|
||||
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
for obj in objs:
|
||||
if cls is None:
|
||||
obj.provider_id = provider_id
|
||||
|
@ -154,7 +154,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
|
||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
|
@ -192,7 +192,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
||||
# Get from disk registry
|
||||
obj = await self.dist_registry.get(type, identifier)
|
||||
if not obj:
|
||||
|
@ -236,7 +236,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
||||
|
@ -277,10 +277,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
provider_model_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
|
@ -328,9 +328,9 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
provider_shield_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield:
|
||||
if provider_shield_id is None:
|
||||
provider_shield_id = shield_id
|
||||
|
@ -368,9 +368,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: Optional[int] = 384,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorDB:
|
||||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
|
@ -423,8 +423,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
dataset_id: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
dataset_id: str | None = None,
|
||||
) -> Dataset:
|
||||
if isinstance(source, dict):
|
||||
if source["type"] == "uri":
|
||||
|
@ -489,9 +489,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
scoring_fn_id: str,
|
||||
description: str,
|
||||
return_type: ParamType,
|
||||
provider_scoring_fn_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[ScoringFnParams] = None,
|
||||
provider_scoring_fn_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: ScoringFnParams | None = None,
|
||||
) -> None:
|
||||
if provider_scoring_fn_id is None:
|
||||
provider_scoring_fn_id = scoring_fn_id
|
||||
|
@ -528,10 +528,10 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
self,
|
||||
benchmark_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
provider_benchmark_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
scoring_functions: list[str],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
provider_benchmark_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
) -> None:
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
@ -556,7 +556,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
tools = await self.get_all_with_type("tool")
|
||||
if toolgroup_id:
|
||||
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
|
||||
|
@ -578,8 +578,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
self,
|
||||
toolgroup_id: str,
|
||||
provider_id: str,
|
||||
mcp_endpoint: Optional[URL] = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue