mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix(mypy): resolve type issues in MongoDB, batches, and auth providers (#3933)
Fixes mypy type errors in provider utilities: - MongoDB: Fix AsyncMongoClient parameters, use async iteration for cursor - Batches: Handle memoryview|bytes union for file decoding - Auth: Add missing imports, validate JWKS URI, conditionally pass parameters Fixes 11 type errors. No functional changes.
This commit is contained in:
parent
4a2ea278c5
commit
6ce59b5df8
3 changed files with 52 additions and 29 deletions
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import ssl
|
import ssl
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
from urllib.parse import parse_qs, urljoin, urlparse
|
from urllib.parse import parse_qs, urljoin, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -143,14 +144,21 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
if self.config.jwks and self.config.jwks.token:
|
if self.config.jwks and self.config.jwks.token:
|
||||||
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
|
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
|
||||||
|
|
||||||
self._jwks_client = jwt.PyJWKClient(
|
# Ensure uri is not None for PyJWKClient
|
||||||
self.config.jwks.uri if self.config.jwks else None,
|
if not self.config.jwks or not self.config.jwks.uri:
|
||||||
cache_keys=True,
|
raise ValueError("JWKS configuration requires a valid URI")
|
||||||
max_cached_keys=10,
|
|
||||||
lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
|
# Build kwargs conditionally to avoid passing None values
|
||||||
headers=headers,
|
jwks_kwargs: dict[str, Any] = {
|
||||||
ssl_context=ssl_context,
|
"cache_keys": True,
|
||||||
)
|
"max_cached_keys": 10,
|
||||||
|
"headers": headers,
|
||||||
|
"ssl_context": ssl_context,
|
||||||
|
}
|
||||||
|
if self.config.jwks.key_recheck_period is not None:
|
||||||
|
jwks_kwargs["lifespan"] = self.config.jwks.key_recheck_period
|
||||||
|
|
||||||
|
self._jwks_client = jwt.PyJWKClient(self.config.jwks.uri, **jwks_kwargs)
|
||||||
return self._jwks_client
|
return self._jwks_client
|
||||||
|
|
||||||
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
|
|
@ -197,23 +205,31 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
if self.config.introspection is None:
|
if self.config.introspection is None:
|
||||||
raise ValueError("Introspection is not configured")
|
raise ValueError("Introspection is not configured")
|
||||||
|
|
||||||
|
# ssl_ctxt can be None, bool, str, or SSLContext - httpx accepts all
|
||||||
|
ssl_ctxt: ssl.SSLContext | bool = False # Default to no verification if no cafile
|
||||||
|
if self.config.tls_cafile:
|
||||||
|
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
|
||||||
|
|
||||||
|
# Build post kwargs conditionally based on auth method
|
||||||
|
post_kwargs: dict[str, Any] = {
|
||||||
|
"url": self.config.introspection.url,
|
||||||
|
"data": form,
|
||||||
|
"timeout": 10.0,
|
||||||
|
}
|
||||||
|
|
||||||
if self.config.introspection.send_secret_in_body:
|
if self.config.introspection.send_secret_in_body:
|
||||||
form["client_id"] = self.config.introspection.client_id
|
form["client_id"] = self.config.introspection.client_id
|
||||||
form["client_secret"] = self.config.introspection.client_secret
|
form["client_secret"] = self.config.introspection.client_secret
|
||||||
auth = None
|
|
||||||
else:
|
else:
|
||||||
auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
|
# httpx auth parameter expects tuple[str | bytes, str | bytes]
|
||||||
ssl_ctxt = None
|
post_kwargs["auth"] = (
|
||||||
if self.config.tls_cafile:
|
self.config.introspection.client_id,
|
||||||
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
|
self.config.introspection.client_secret,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
|
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
|
||||||
response = await client.post(
|
response = await client.post(**post_kwargs)
|
||||||
self.config.introspection.url,
|
|
||||||
data=form,
|
|
||||||
auth=auth,
|
|
||||||
timeout=10.0, # Add a reasonable timeout
|
|
||||||
)
|
|
||||||
if response.status_code != httpx.codes.OK:
|
if response.status_code != httpx.codes.OK:
|
||||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||||
|
|
|
||||||
|
|
@ -358,7 +358,11 @@ class ReferenceBatchesImpl(Batches):
|
||||||
|
|
||||||
# TODO(SECURITY): do something about large files
|
# TODO(SECURITY): do something about large files
|
||||||
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
|
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
|
||||||
file_content = file_content_response.body.decode("utf-8")
|
# Handle both bytes and memoryview types
|
||||||
|
body = file_content_response.body
|
||||||
|
if isinstance(body, memoryview):
|
||||||
|
body = bytes(body)
|
||||||
|
file_content = body.decode("utf-8")
|
||||||
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
|
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
|
||||||
if line.strip(): # skip empty lines
|
if line.strip(): # skip empty lines
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -30,14 +30,13 @@ class MongoDBKVStoreImpl(KVStore):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
try:
|
try:
|
||||||
conn_creds = {
|
# Pass parameters explicitly to satisfy mypy - AsyncMongoClient doesn't accept **dict
|
||||||
"host": self.config.host,
|
self.conn = AsyncMongoClient(
|
||||||
"port": self.config.port,
|
host=self.config.host if self.config.host is not None else None,
|
||||||
"username": self.config.user,
|
port=self.config.port if self.config.port is not None else None,
|
||||||
"password": self.config.password,
|
username=self.config.user if self.config.user is not None else None,
|
||||||
}
|
password=self.config.password if self.config.password is not None else None,
|
||||||
conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
|
)
|
||||||
self.conn = AsyncMongoClient(**conn_creds)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception("Could not connect to MongoDB database server")
|
log.exception("Could not connect to MongoDB database server")
|
||||||
raise RuntimeError("Could not connect to MongoDB database server") from e
|
raise RuntimeError("Could not connect to MongoDB database server") from e
|
||||||
|
|
@ -79,4 +78,8 @@ class MongoDBKVStoreImpl(KVStore):
|
||||||
end_key = self._namespaced_key(end_key)
|
end_key = self._namespaced_key(end_key)
|
||||||
query = {"key": {"$gte": start_key, "$lt": end_key}}
|
query = {"key": {"$gte": start_key, "$lt": end_key}}
|
||||||
cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
|
cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
|
||||||
return [doc["key"] for doc in cursor]
|
# AsyncCursor requires async iteration
|
||||||
|
result = []
|
||||||
|
async for doc in cursor:
|
||||||
|
result.append(doc["key"])
|
||||||
|
return result
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue