mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix(mypy): resolve type issues in MongoDB, batches, and auth providers (#3933)
Fixes mypy type errors in provider utilities: - MongoDB: Fix AsyncMongoClient parameters, use async iteration for cursor - Batches: Handle memoryview|bytes union for file decoding - Auth: Add missing imports, validate JWKS URI, conditionally pass parameters Fixes 11 type errors. No functional changes.
This commit is contained in:
parent
4a2ea278c5
commit
6ce59b5df8
3 changed files with 52 additions and 29 deletions
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import ssl
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
|
|
@ -143,14 +144,21 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
if self.config.jwks and self.config.jwks.token:
|
||||
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
|
||||
|
||||
self._jwks_client = jwt.PyJWKClient(
|
||||
self.config.jwks.uri if self.config.jwks else None,
|
||||
cache_keys=True,
|
||||
max_cached_keys=10,
|
||||
lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
|
||||
headers=headers,
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
# Ensure uri is not None for PyJWKClient
|
||||
if not self.config.jwks or not self.config.jwks.uri:
|
||||
raise ValueError("JWKS configuration requires a valid URI")
|
||||
|
||||
# Build kwargs conditionally to avoid passing None values
|
||||
jwks_kwargs: dict[str, Any] = {
|
||||
"cache_keys": True,
|
||||
"max_cached_keys": 10,
|
||||
"headers": headers,
|
||||
"ssl_context": ssl_context,
|
||||
}
|
||||
if self.config.jwks.key_recheck_period is not None:
|
||||
jwks_kwargs["lifespan"] = self.config.jwks.key_recheck_period
|
||||
|
||||
self._jwks_client = jwt.PyJWKClient(self.config.jwks.uri, **jwks_kwargs)
|
||||
return self._jwks_client
|
||||
|
||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
||||
|
|
@ -197,23 +205,31 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
if self.config.introspection is None:
|
||||
raise ValueError("Introspection is not configured")
|
||||
|
||||
# ssl_ctxt can be None, bool, str, or SSLContext - httpx accepts all
|
||||
ssl_ctxt: ssl.SSLContext | bool = False # Default to no verification if no cafile
|
||||
if self.config.tls_cafile:
|
||||
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
|
||||
|
||||
# Build post kwargs conditionally based on auth method
|
||||
post_kwargs: dict[str, Any] = {
|
||||
"url": self.config.introspection.url,
|
||||
"data": form,
|
||||
"timeout": 10.0,
|
||||
}
|
||||
|
||||
if self.config.introspection.send_secret_in_body:
|
||||
form["client_id"] = self.config.introspection.client_id
|
||||
form["client_secret"] = self.config.introspection.client_secret
|
||||
auth = None
|
||||
else:
|
||||
auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
|
||||
ssl_ctxt = None
|
||||
if self.config.tls_cafile:
|
||||
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
|
||||
# httpx auth parameter expects tuple[str | bytes, str | bytes]
|
||||
post_kwargs["auth"] = (
|
||||
self.config.introspection.client_id,
|
||||
self.config.introspection.client_secret,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
|
||||
response = await client.post(
|
||||
self.config.introspection.url,
|
||||
data=form,
|
||||
auth=auth,
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
response = await client.post(**post_kwargs)
|
||||
if response.status_code != httpx.codes.OK:
|
||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue