fix: fix linting errors

This commit is contained in:
Krrish Dholakia 2025-03-09 18:59:03 -07:00
parent 7ccccff39a
commit 3df4e28ae6

View file

@ -1,24 +1,20 @@
import os
import httpx
import json
import os
import time
from datetime import datetime
from typing import Optional, Dict, Any
from typing import Any, Dict, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from .constants import (
APIKeyExpiredError,
GetAccessTokenError,
GetAPIKeyError,
GetDeviceCodeError,
RefreshAPIKeyError,
GetAPIKeyError,
APIKeyExpiredError,
)
# Constants
@ -61,7 +57,9 @@ class Authenticator:
if access_token:
return access_token
except IOError:
verbose_logger.warning("No existing access token found or error reading file")
verbose_logger.warning(
"No existing access token found or error reading file"
)
for attempt in range(3):
verbose_logger.debug(f"Access token acquisition attempt {attempt + 1}/3")
@ -108,7 +106,11 @@ class Authenticator:
api_key_info = self._refresh_api_key()
with open(self.api_key_file, "w") as f:
json.dump(api_key_info, f)
return api_key_info.get("token")
token = api_key_info.get("token")
if token:
return token
else:
raise GetAPIKeyError("API key response missing token")
except IOError as e:
verbose_logger.error(f"Error saving API key to file: {str(e)}")
raise GetAPIKeyError(f"Failed to save API key: {str(e)}")
@ -132,9 +134,7 @@ class Authenticator:
for attempt in range(max_retries):
try:
sync_client = _get_httpx_client()
response = sync_client.get(
GITHUB_API_KEY_URL, headers=headers
)
response = sync_client.get(GITHUB_API_KEY_URL, headers=headers)
response.raise_for_status()
response_json = response.json()
@ -142,9 +142,13 @@ class Authenticator:
if "token" in response_json:
return response_json
else:
verbose_logger.warning(f"API key response missing token: {response_json}")
verbose_logger.warning(
f"API key response missing token: {response_json}"
)
except httpx.HTTPStatusError as e:
verbose_logger.error(f"HTTP error refreshing API key (attempt {attempt+1}/{max_retries}): {str(e)}")
verbose_logger.error(
f"HTTP error refreshing API key (attempt {attempt+1}/{max_retries}): {str(e)}"
)
except Exception as e:
verbose_logger.error(f"Unexpected error refreshing API key: {str(e)}")
@ -158,36 +162,36 @@ class Authenticator:
def _get_github_headers(self, access_token: Optional[str] = None) -> Dict[str, str]:
"""
Generate standard GitHub headers for API requests.
Args:
access_token: Optional access token to include in the headers.
Returns:
Dict[str, str]: Headers for GitHub API requests.
"""
headers = {
"accept": "application/json",
"editor-version": "vscode/1.85.1",
"editor-plugin-version": "copilot/1.155.0",
"editor-plugin-version": "copilot/1.155.0",
"user-agent": "GithubCopilot/1.155.0",
"accept-encoding": "gzip,deflate,br",
}
if access_token:
headers["authorization"] = f"token {access_token}"
if "content-type" not in headers:
headers["content-type"] = "application/json"
return headers
def _get_device_code(self) -> Dict[str, str]:
"""
Get a device code for GitHub authentication.
Returns:
Dict[str, str]: Device code information.
Raises:
GetDeviceCodeError: If unable to get a device code.
"""
@ -205,9 +209,9 @@ class Authenticator:
if not all(field in resp_json for field in required_fields):
verbose_logger.error(f"Response missing required fields: {resp_json}")
raise GetDeviceCodeError("Response missing required fields")
return resp_json
except httpx.HttpError as e:
except httpx.HTTPStatusError as e:
verbose_logger.error(f"HTTP error getting device code: {str(e)}")
raise GetDeviceCodeError(f"Failed to get device code: {str(e)}")
except json.JSONDecodeError as e:
@ -220,19 +224,19 @@ class Authenticator:
def _poll_for_access_token(self, device_code: str) -> str:
"""
Poll for an access token after user authentication.
Args:
device_code: The device code to use for polling.
Returns:
str: The access token.
Raises:
GetAccessTokenError: If unable to get an access token.
"""
sync_client = _get_httpx_client()
max_attempts = 12 # 1 minute (12 * 5 seconds)
for attempt in range(max_attempts):
try:
resp = sync_client.post(
@ -250,8 +254,13 @@ class Authenticator:
if "access_token" in resp_json:
verbose_logger.info("Authentication successful!")
return resp_json["access_token"]
elif "error" in resp_json and resp_json.get("error") == "authorization_pending":
verbose_logger.debug(f"Authorization pending (attempt {attempt+1}/{max_attempts})")
elif (
"error" in resp_json
and resp_json.get("error") == "authorization_pending"
):
verbose_logger.debug(
f"Authorization pending (attempt {attempt+1}/{max_attempts})"
)
else:
verbose_logger.warning(f"Unexpected response: {resp_json}")
except httpx.HTTPStatusError as e:
@ -259,13 +268,17 @@ class Authenticator:
raise GetAccessTokenError(f"Failed to get access token: {str(e)}")
except json.JSONDecodeError as e:
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
raise GetAccessTokenError(f"Failed to decode access token response: {str(e)}")
raise GetAccessTokenError(
f"Failed to decode access token response: {str(e)}"
)
except Exception as e:
verbose_logger.error(f"Unexpected error polling for access token: {str(e)}")
verbose_logger.error(
f"Unexpected error polling for access token: {str(e)}"
)
raise GetAccessTokenError(f"Failed to get access token: {str(e)}")
time.sleep(5)
raise GetAccessTokenError("Timed out waiting for user to authorize the device")
def _login(self) -> str:
@ -274,19 +287,19 @@ class Authenticator:
Returns:
str: The GitHub access token.
Raises:
GetDeviceCodeError: If unable to get a device code.
GetAccessTokenError: If unable to get an access token.
"""
device_code_info = self._get_device_code()
device_code = device_code_info["device_code"]
user_code = device_code_info["user_code"]
verification_uri = device_code_info["verification_uri"]
print(
verbose_logger.info(
f"Please visit {verification_uri} and enter code {user_code} to authenticate."
)
return self._poll_for_access_token(device_code)