test: add tests for github copilot

This commit is contained in:
Son H. Nguyen 2025-02-26 23:28:16 +07:00
parent e394d45513
commit 00d2d90535
7 changed files with 282 additions and 62 deletions

View file

@ -23,16 +23,23 @@ from .constants import (
APIKeyExpiredError,
)
class Authenticator:
def __init__(self) -> None:
# Token storage paths
self.token_dir = os.getenv("GITHUB_COPILOT_TOKEN_DIR", os.path.expanduser("~/.config/litellm/github_copilot"))
self.access_token_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"))
self.api_key_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json"))
self.token_dir = os.getenv(
"GITHUB_COPILOT_TOKEN_DIR",
os.path.expanduser("~/.config/litellm/github_copilot"),
)
self.access_token_file = os.path.join(
self.token_dir,
os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"),
)
self.api_key_file = os.path.join(
self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json")
)
self._ensure_token_dir()
self.get_access_token()
def get_access_token(self) -> str:
"""
Login to Copilot with retry 3 times
@ -42,12 +49,12 @@ class Authenticator:
"""
try:
with open(self.access_token_file, 'r') as f:
with open(self.access_token_file, "r") as f:
access_token = f.read().strip()
return access_token
except IOError:
verbose_logger.warning("Error loading access token from file")
for _ in range(3):
try:
access_token = self._login()
@ -55,7 +62,7 @@ class Authenticator:
continue
else:
try:
with open(self.access_token_file, 'w') as f:
with open(self.access_token_file, "w") as f:
f.write(access_token)
except IOError:
verbose_logger.error("Error saving access token to file")
@ -66,10 +73,10 @@ class Authenticator:
def get_api_key(self) -> str:
"""Get the API key"""
try:
with open(self.api_key_file, 'r') as f:
with open(self.api_key_file, "r") as f:
api_key_info = json.load(f)
if api_key_info.get('expires_at') > datetime.now().timestamp():
return api_key_info.get('token')
if api_key_info.get("expires_at") > datetime.now().timestamp():
return api_key_info.get("token")
else:
raise APIKeyExpiredError("API key expired")
except IOError:
@ -78,17 +85,17 @@ class Authenticator:
verbose_logger.warning("Error reading API key from file")
except APIKeyExpiredError:
verbose_logger.warning("API key expired")
try:
api_key_info = self._refresh_api_key()
with open(self.api_key_file, 'w') as f:
with open(self.api_key_file, "w") as f:
json.dump(api_key_info, f)
except IOError:
verbose_logger.error("Error saving API key to file")
except RefreshAPIKeyError:
raise GetAPIKeyError("Failed to refresh API key")
return api_key_info.get('token')
return api_key_info.get("token")
def _refresh_api_key(self) -> dict:
"""
@ -100,10 +107,10 @@ class Authenticator:
access_token = self.get_access_token()
headers = {
'authorization': f'token {access_token}',
'editor-version': 'vscode/1.85.1',
'editor-plugin-version': 'copilot/1.155.0',
'user-agent': 'GithubCopilot/1.155.0'
"authorization": f"token {access_token}",
"editor-version": "vscode/1.85.1",
"editor-plugin-version": "copilot/1.155.0",
"user-agent": "GithubCopilot/1.155.0",
}
max_retries = 3
@ -111,20 +118,18 @@ class Authenticator:
try:
sync_client = _get_httpx_client()
response = sync_client.get(
'https://api.github.com/copilot_internal/v2/token',
headers=headers
"https://api.github.com/copilot_internal/v2/token", headers=headers
)
response.raise_for_status()
response_json = response.json()
if 'token' in response_json:
if "token" in response_json:
return response_json
except httpx.HTTPStatusError as e:
verbose_logger.error(f"Error refreshing API key: {str(e)}")
raise RefreshAPIKeyError("Failed to refresh API key")
def _ensure_token_dir(self) -> None:
"""Ensure the token directory exists"""
@ -143,23 +148,23 @@ class Authenticator:
sync_client = _get_httpx_client()
# Get device code
resp = sync_client.post(
'https://github.com/login/device/code',
"https://github.com/login/device/code",
headers={
'accept': 'application/json',
'editor-version': 'vscode/1.85.1',
'editor-plugin-version': 'copilot/1.155.0',
'content-type': 'application/json',
'user-agent': 'GithubCopilot/1.155.0',
'accept-encoding': 'gzip,deflate,br'
"accept": "application/json",
"editor-version": "vscode/1.85.1",
"editor-plugin-version": "copilot/1.155.0",
"content-type": "application/json",
"user-agent": "GithubCopilot/1.155.0",
"accept-encoding": "gzip,deflate,br",
},
json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"}
json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"},
)
resp.raise_for_status()
resp_json = resp.json()
device_code = resp_json.get('device_code')
user_code = resp_json.get('user_code')
verification_uri = resp_json.get('verification_uri')
device_code = resp_json.get("device_code")
user_code = resp_json.get("user_code")
verification_uri = resp_json.get("verification_uri")
if not all([device_code, user_code, verification_uri]):
verbose_logger.error("Response missing required fields")
@ -173,8 +178,10 @@ class Authenticator:
except RuntimeError as e:
verbose_logger.error(f"Error getting device code: {str(e)}")
raise GetDeviceCodeError("Failed to get device code")
print(f'Please visit {verification_uri} and enter code {user_code} to authenticate.')
print(
f"Please visit {verification_uri} and enter code {user_code} to authenticate."
)
while True:
time.sleep(5)
@ -182,27 +189,27 @@ class Authenticator:
# Get access token
try:
resp = sync_client.post(
'https://github.com/login/oauth/access_token',
"https://github.com/login/oauth/access_token",
headers={
'accept': 'application/json',
'editor-version': 'vscode/1.85.1',
'editor-plugin-version': 'copilot/1.155.0',
'content-type': 'application/json',
'user-agent': 'GithubCopilot/1.155.0',
'accept-encoding': 'gzip,deflate,br'
"accept": "application/json",
"editor-version": "vscode/1.85.1",
"editor-plugin-version": "copilot/1.155.0",
"content-type": "application/json",
"user-agent": "GithubCopilot/1.155.0",
"accept-encoding": "gzip,deflate,br",
},
json={
"client_id": "Iv1.b507a08c87ecfe98",
"device_code": device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
}
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
},
)
resp.raise_for_status()
resp_json = resp.json()
if "access_token" in resp_json:
verbose_logger.info("Authentication success!")
return resp_json['access_token']
return resp_json["access_token"]
else:
continue
except httpx.HTTPStatusError as e:
@ -211,7 +218,3 @@ class Authenticator:
except json.JSONDecodeError as e:
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
raise GetAccessTokenError("Failed to get access token")