From 6ce59b5df8dc81827276a81b5ef78de4c736977c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 28 Oct 2025 10:23:39 -0700 Subject: [PATCH] fix(mypy): resolve type issues in MongoDB, batches, and auth providers (#3933) Fixes mypy type errors in provider utilities: - MongoDB: Fix AsyncMongoClient parameters, use async iteration for cursor - Batches: Handle memoryview|bytes union for file decoding - Auth: Add missing imports, validate JWKS URI, conditionally pass parameters Fixes 11 type errors. No functional changes. --- src/llama_stack/core/server/auth_providers.py | 54 ++++++++++++------- .../inline/batches/reference/batches.py | 6 ++- .../utils/kvstore/mongodb/mongodb.py | 21 ++++---- 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/src/llama_stack/core/server/auth_providers.py b/src/llama_stack/core/server/auth_providers.py index 0fe5f1558..da398bf99 100644 --- a/src/llama_stack/core/server/auth_providers.py +++ b/src/llama_stack/core/server/auth_providers.py @@ -6,6 +6,7 @@ import ssl from abc import ABC, abstractmethod +from typing import Any from urllib.parse import parse_qs, urljoin, urlparse import httpx @@ -143,14 +144,21 @@ class OAuth2TokenAuthProvider(AuthProvider): if self.config.jwks and self.config.jwks.token: headers["Authorization"] = f"Bearer {self.config.jwks.token}" - self._jwks_client = jwt.PyJWKClient( - self.config.jwks.uri if self.config.jwks else None, - cache_keys=True, - max_cached_keys=10, - lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None, - headers=headers, - ssl_context=ssl_context, - ) + # Ensure uri is not None for PyJWKClient + if not self.config.jwks or not self.config.jwks.uri: + raise ValueError("JWKS configuration requires a valid URI") + + # Build kwargs conditionally to avoid passing None values + jwks_kwargs: dict[str, Any] = { + "cache_keys": True, + "max_cached_keys": 10, + "headers": headers, + "ssl_context": ssl_context, + } + if self.config.jwks.key_recheck_period is not None: + jwks_kwargs["lifespan"] = self.config.jwks.key_recheck_period + + self._jwks_client = jwt.PyJWKClient(self.config.jwks.uri, **jwks_kwargs) return self._jwks_client async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User: @@ -197,23 +205,31 @@ class OAuth2TokenAuthProvider(AuthProvider): if self.config.introspection is None: raise ValueError("Introspection is not configured") + # ssl_ctxt can be None, bool, str, or SSLContext - httpx accepts all + ssl_ctxt: ssl.SSLContext | bool = False # Default to no verification if no cafile + if self.config.tls_cafile: + ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix()) + + # Build post kwargs conditionally based on auth method + post_kwargs: dict[str, Any] = { + "url": self.config.introspection.url, + "data": form, + "timeout": 10.0, + } + if self.config.introspection.send_secret_in_body: form["client_id"] = self.config.introspection.client_id form["client_secret"] = self.config.introspection.client_secret - auth = None else: - auth = (self.config.introspection.client_id, self.config.introspection.client_secret) - ssl_ctxt = None - if self.config.tls_cafile: - ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix()) + # httpx auth parameter expects tuple[str | bytes, str | bytes] + post_kwargs["auth"] = ( + self.config.introspection.client_id, + self.config.introspection.client_secret, + ) + try: async with httpx.AsyncClient(verify=ssl_ctxt) as client: - response = await client.post( - self.config.introspection.url, - data=form, - auth=auth, - timeout=10.0, # Add a reasonable timeout - ) + response = await client.post(**post_kwargs) if response.status_code != httpx.codes.OK: logger.warning(f"Token introspection failed with status code: {response.status_code}") raise ValueError(f"Token introspection failed: {response.status_code}") diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py index fa581ae1f..79dc9c84c 100644 --- a/src/llama_stack/providers/inline/batches/reference/batches.py +++ b/src/llama_stack/providers/inline/batches/reference/batches.py @@ -358,7 +358,11 @@ class ReferenceBatchesImpl(Batches): # TODO(SECURITY): do something about large files file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) - file_content = file_content_response.body.decode("utf-8") + # Handle both bytes and memoryview types + body = file_content_response.body + if isinstance(body, memoryview): + body = bytes(body) + file_content = body.decode("utf-8") for line_num, line in enumerate(file_content.strip().split("\n"), 1): if line.strip(): # skip empty lines try: diff --git a/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index 4d60949c1..964c45090 100644 --- a/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -30,14 +30,13 @@ class MongoDBKVStoreImpl(KVStore): async def initialize(self) -> None: try: - conn_creds = { - "host": self.config.host, - "port": self.config.port, - "username": self.config.user, - "password": self.config.password, - } - conn_creds = {k: v for k, v in conn_creds.items() if v is not None} - self.conn = AsyncMongoClient(**conn_creds) + # Pass parameters explicitly to satisfy mypy - AsyncMongoClient doesn't accept **dict + self.conn = AsyncMongoClient( + host=self.config.host if self.config.host is not None else None, + port=self.config.port if self.config.port is not None else None, + username=self.config.user if self.config.user is not None else None, + password=self.config.password if self.config.password is not None else None, + ) except Exception as e: log.exception("Could not connect to MongoDB database server") raise RuntimeError("Could not connect to MongoDB database server") from e @@ -79,4 +78,8 @@ class MongoDBKVStoreImpl(KVStore): end_key = self._namespaced_key(end_key) query = {"key": {"$gte": start_key, "$lt": end_key}} cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1) - return [doc["key"] for doc in cursor] + # AsyncCursor requires async iteration + result = [] + async for doc in cursor: + result.append(doc["key"]) + return result