fix(mypy): resolve type issues in MongoDB, batches, and auth providers (#3933)

Fixes mypy type errors in provider utilities:
- MongoDB: Fix AsyncMongoClient parameters, use async iteration for
cursor
- Batches: Handle memoryview|bytes union for file decoding
- Auth: Add missing imports, validate JWKS URI, conditionally pass
parameters

Fixes 11 type errors. No functional changes.
This commit is contained in:
Ashwin Bharambe 2025-10-28 10:23:39 -07:00 committed by GitHub
parent 4a2ea278c5
commit 6ce59b5df8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 52 additions and 29 deletions

View file

@ -6,6 +6,7 @@
import ssl import ssl
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
from urllib.parse import parse_qs, urljoin, urlparse from urllib.parse import parse_qs, urljoin, urlparse
import httpx import httpx
@ -143,14 +144,21 @@ class OAuth2TokenAuthProvider(AuthProvider):
if self.config.jwks and self.config.jwks.token: if self.config.jwks and self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}" headers["Authorization"] = f"Bearer {self.config.jwks.token}"
self._jwks_client = jwt.PyJWKClient( # Ensure uri is not None for PyJWKClient
self.config.jwks.uri if self.config.jwks else None, if not self.config.jwks or not self.config.jwks.uri:
cache_keys=True, raise ValueError("JWKS configuration requires a valid URI")
max_cached_keys=10,
lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None, # Build kwargs conditionally to avoid passing None values
headers=headers, jwks_kwargs: dict[str, Any] = {
ssl_context=ssl_context, "cache_keys": True,
) "max_cached_keys": 10,
"headers": headers,
"ssl_context": ssl_context,
}
if self.config.jwks.key_recheck_period is not None:
jwks_kwargs["lifespan"] = self.config.jwks.key_recheck_period
self._jwks_client = jwt.PyJWKClient(self.config.jwks.uri, **jwks_kwargs)
return self._jwks_client return self._jwks_client
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
@ -197,23 +205,31 @@ class OAuth2TokenAuthProvider(AuthProvider):
if self.config.introspection is None: if self.config.introspection is None:
raise ValueError("Introspection is not configured") raise ValueError("Introspection is not configured")
# ssl_ctxt can be None, bool, str, or SSLContext - httpx accepts all
ssl_ctxt: ssl.SSLContext | bool = False # Default to no verification if no cafile
if self.config.tls_cafile:
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
# Build post kwargs conditionally based on auth method
post_kwargs: dict[str, Any] = {
"url": self.config.introspection.url,
"data": form,
"timeout": 10.0,
}
if self.config.introspection.send_secret_in_body: if self.config.introspection.send_secret_in_body:
form["client_id"] = self.config.introspection.client_id form["client_id"] = self.config.introspection.client_id
form["client_secret"] = self.config.introspection.client_secret form["client_secret"] = self.config.introspection.client_secret
auth = None
else: else:
auth = (self.config.introspection.client_id, self.config.introspection.client_secret) # httpx auth parameter expects tuple[str | bytes, str | bytes]
ssl_ctxt = None post_kwargs["auth"] = (
if self.config.tls_cafile: self.config.introspection.client_id,
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix()) self.config.introspection.client_secret,
)
try: try:
async with httpx.AsyncClient(verify=ssl_ctxt) as client: async with httpx.AsyncClient(verify=ssl_ctxt) as client:
response = await client.post( response = await client.post(**post_kwargs)
self.config.introspection.url,
data=form,
auth=auth,
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != httpx.codes.OK: if response.status_code != httpx.codes.OK:
logger.warning(f"Token introspection failed with status code: {response.status_code}") logger.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}") raise ValueError(f"Token introspection failed: {response.status_code}")

View file

@ -358,7 +358,11 @@ class ReferenceBatchesImpl(Batches):
# TODO(SECURITY): do something about large files # TODO(SECURITY): do something about large files
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
file_content = file_content_response.body.decode("utf-8") # Handle both bytes and memoryview types
body = file_content_response.body
if isinstance(body, memoryview):
body = bytes(body)
file_content = body.decode("utf-8")
for line_num, line in enumerate(file_content.strip().split("\n"), 1): for line_num, line in enumerate(file_content.strip().split("\n"), 1):
if line.strip(): # skip empty lines if line.strip(): # skip empty lines
try: try:

View file

@ -30,14 +30,13 @@ class MongoDBKVStoreImpl(KVStore):
async def initialize(self) -> None: async def initialize(self) -> None:
try: try:
conn_creds = { # Pass parameters explicitly to satisfy mypy - AsyncMongoClient doesn't accept **dict
"host": self.config.host, self.conn = AsyncMongoClient(
"port": self.config.port, host=self.config.host if self.config.host is not None else None,
"username": self.config.user, port=self.config.port if self.config.port is not None else None,
"password": self.config.password, username=self.config.user if self.config.user is not None else None,
} password=self.config.password if self.config.password is not None else None,
conn_creds = {k: v for k, v in conn_creds.items() if v is not None} )
self.conn = AsyncMongoClient(**conn_creds)
except Exception as e: except Exception as e:
log.exception("Could not connect to MongoDB database server") log.exception("Could not connect to MongoDB database server")
raise RuntimeError("Could not connect to MongoDB database server") from e raise RuntimeError("Could not connect to MongoDB database server") from e
@ -79,4 +78,8 @@ class MongoDBKVStoreImpl(KVStore):
end_key = self._namespaced_key(end_key) end_key = self._namespaced_key(end_key)
query = {"key": {"$gte": start_key, "$lt": end_key}} query = {"key": {"$gte": start_key, "$lt": end_key}}
cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1) cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
return [doc["key"] for doc in cursor] # AsyncCursor requires async iteration
result = []
async for doc in cursor:
result.append(doc["key"])
return result