mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -7,7 +7,6 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
|
@ -22,7 +21,7 @@ logger = get_logger(name=__name__, category="auth")
|
|||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
access_attributes: AccessAttributes | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
@ -44,7 +43,7 @@ class AuthResponse(BaseModel):
|
|||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
message: str | None = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
@ -52,9 +51,9 @@ class AuthResponse(BaseModel):
|
|||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
params: dict[str, list[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
@ -76,14 +75,14 @@ class AuthProviderConfig(BaseModel):
|
|||
"""Base configuration for authentication providers."""
|
||||
|
||||
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
||||
config: Dict[str, str] = Field(..., description="Provider-specific configuration")
|
||||
config: dict[str, str] = Field(..., description="Provider-specific configuration")
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
"""Validate a token and return access attributes."""
|
||||
pass
|
||||
|
||||
|
@ -96,7 +95,7 @@ class AuthProvider(ABC):
|
|||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
||||
|
||||
def __init__(self, config: Dict[str, str]):
|
||||
def __init__(self, config: dict[str, str]):
|
||||
self.api_server_url = config["api_server_url"]
|
||||
self.ca_cert_path = config.get("ca_cert_path")
|
||||
self._client = None
|
||||
|
@ -120,7 +119,7 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
self._client = ApiClient(configuration)
|
||||
return self._client
|
||||
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
"""Validate a Kubernetes token and return access attributes."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
|
@ -166,11 +165,11 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
class CustomAuthProvider(AuthProvider):
|
||||
"""Custom authentication provider that uses an external endpoint."""
|
||||
|
||||
def __init__(self, config: Dict[str, str]):
|
||||
def __init__(self, config: dict[str, str]):
|
||||
self.endpoint = config["endpoint"]
|
||||
self._client = None
|
||||
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
"""Validate a token using the custom authentication endpoint."""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Authentication endpoint not configured")
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -29,7 +28,7 @@ def toolgroup_protocol_map():
|
|||
}
|
||||
|
||||
|
||||
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
|
|
|
@ -15,7 +15,7 @@ import warnings
|
|||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Annotated, Any
|
||||
|
||||
import yaml
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
|
@ -24,7 +24,6 @@ from fastapi.exceptions import RequestValidationError
|
|||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
|
@ -91,7 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
|
||||
|
||||
|
||||
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
|
||||
def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
|
||||
if isinstance(exc, ValidationError):
|
||||
exc = RequestValidationError(exc.errors())
|
||||
|
||||
|
@ -315,7 +314,7 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def main(args: Optional[argparse.Namespace] = None):
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
|
@ -385,7 +384,7 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
logger_config = None
|
||||
with open(config_file, "r") as fp:
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
|
@ -517,7 +516,7 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
uvicorn.run(**uvicorn_config)
|
||||
|
||||
|
||||
def extract_path_params(route: str) -> List[str]:
|
||||
def extract_path_params(route: str) -> list[str]:
|
||||
segments = route.split("/")
|
||||
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
|
||||
# to handle path params like {param:path}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue