test: add tests for github copilot

This commit is contained in:
Son H. Nguyen 2025-02-26 23:28:16 +07:00
parent e394d45513
commit 00d2d90535
7 changed files with 282 additions and 62 deletions

3
.gitignore vendored
View file

@ -78,4 +78,5 @@ litellm/proxy/_experimental/out/model_hub.html
.mypy_cache/* .mypy_cache/*
litellm/proxy/application.log litellm/proxy/application.log
**/__pycache__ **/__pycache__/**
**/*.pyc

View file

@ -1,4 +1,11 @@
from .constants import GITHUB_COPILOT_API_BASE, CHAT_COMPLETION_ENDPOINT, GITHUB_COPILOT_MODEL, GetAccessTokenError, GetAPIKeyError, RefreshAPIKeyError from .constants import (
GITHUB_COPILOT_API_BASE,
CHAT_COMPLETION_ENDPOINT,
GITHUB_COPILOT_MODEL,
GetAccessTokenError,
GetAPIKeyError,
RefreshAPIKeyError,
)
__all__ = [ __all__ = [
"GITHUB_COPILOT_API_BASE", "GITHUB_COPILOT_API_BASE",
@ -6,5 +13,5 @@ __all__ = [
"GITHUB_COPILOT_MODEL", "GITHUB_COPILOT_MODEL",
"GetAccessTokenError", "GetAccessTokenError",
"GetAPIKeyError", "GetAPIKeyError",
"RefreshAPIKeyError" "RefreshAPIKeyError",
] ]

View file

@ -23,16 +23,23 @@ from .constants import (
APIKeyExpiredError, APIKeyExpiredError,
) )
class Authenticator: class Authenticator:
def __init__(self) -> None: def __init__(self) -> None:
# Token storage paths # Token storage paths
self.token_dir = os.getenv("GITHUB_COPILOT_TOKEN_DIR", os.path.expanduser("~/.config/litellm/github_copilot")) self.token_dir = os.getenv(
self.access_token_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token")) "GITHUB_COPILOT_TOKEN_DIR",
self.api_key_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json")) os.path.expanduser("~/.config/litellm/github_copilot"),
)
self.access_token_file = os.path.join(
self.token_dir,
os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"),
)
self.api_key_file = os.path.join(
self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json")
)
self._ensure_token_dir() self._ensure_token_dir()
self.get_access_token()
def get_access_token(self) -> str: def get_access_token(self) -> str:
""" """
Login to Copilot with retry 3 times Login to Copilot with retry 3 times
@ -42,7 +49,7 @@ class Authenticator:
""" """
try: try:
with open(self.access_token_file, 'r') as f: with open(self.access_token_file, "r") as f:
access_token = f.read().strip() access_token = f.read().strip()
return access_token return access_token
except IOError: except IOError:
@ -55,7 +62,7 @@ class Authenticator:
continue continue
else: else:
try: try:
with open(self.access_token_file, 'w') as f: with open(self.access_token_file, "w") as f:
f.write(access_token) f.write(access_token)
except IOError: except IOError:
verbose_logger.error("Error saving access token to file") verbose_logger.error("Error saving access token to file")
@ -66,10 +73,10 @@ class Authenticator:
def get_api_key(self) -> str: def get_api_key(self) -> str:
"""Get the API key""" """Get the API key"""
try: try:
with open(self.api_key_file, 'r') as f: with open(self.api_key_file, "r") as f:
api_key_info = json.load(f) api_key_info = json.load(f)
if api_key_info.get('expires_at') > datetime.now().timestamp(): if api_key_info.get("expires_at") > datetime.now().timestamp():
return api_key_info.get('token') return api_key_info.get("token")
else: else:
raise APIKeyExpiredError("API key expired") raise APIKeyExpiredError("API key expired")
except IOError: except IOError:
@ -81,14 +88,14 @@ class Authenticator:
try: try:
api_key_info = self._refresh_api_key() api_key_info = self._refresh_api_key()
with open(self.api_key_file, 'w') as f: with open(self.api_key_file, "w") as f:
json.dump(api_key_info, f) json.dump(api_key_info, f)
except IOError: except IOError:
verbose_logger.error("Error saving API key to file") verbose_logger.error("Error saving API key to file")
except RefreshAPIKeyError: except RefreshAPIKeyError:
raise GetAPIKeyError("Failed to refresh API key") raise GetAPIKeyError("Failed to refresh API key")
return api_key_info.get('token') return api_key_info.get("token")
def _refresh_api_key(self) -> dict: def _refresh_api_key(self) -> dict:
""" """
@ -100,10 +107,10 @@ class Authenticator:
access_token = self.get_access_token() access_token = self.get_access_token()
headers = { headers = {
'authorization': f'token {access_token}', "authorization": f"token {access_token}",
'editor-version': 'vscode/1.85.1', "editor-version": "vscode/1.85.1",
'editor-plugin-version': 'copilot/1.155.0', "editor-plugin-version": "copilot/1.155.0",
'user-agent': 'GithubCopilot/1.155.0' "user-agent": "GithubCopilot/1.155.0",
} }
max_retries = 3 max_retries = 3
@ -111,21 +118,19 @@ class Authenticator:
try: try:
sync_client = _get_httpx_client() sync_client = _get_httpx_client()
response = sync_client.get( response = sync_client.get(
'https://api.github.com/copilot_internal/v2/token', "https://api.github.com/copilot_internal/v2/token", headers=headers
headers=headers
) )
response.raise_for_status() response.raise_for_status()
response_json = response.json() response_json = response.json()
if 'token' in response_json: if "token" in response_json:
return response_json return response_json
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
verbose_logger.error(f"Error refreshing API key: {str(e)}") verbose_logger.error(f"Error refreshing API key: {str(e)}")
raise RefreshAPIKeyError("Failed to refresh API key") raise RefreshAPIKeyError("Failed to refresh API key")
def _ensure_token_dir(self) -> None: def _ensure_token_dir(self) -> None:
"""Ensure the token directory exists""" """Ensure the token directory exists"""
if not os.path.exists(self.token_dir): if not os.path.exists(self.token_dir):
@ -143,23 +148,23 @@ class Authenticator:
sync_client = _get_httpx_client() sync_client = _get_httpx_client()
# Get device code # Get device code
resp = sync_client.post( resp = sync_client.post(
'https://github.com/login/device/code', "https://github.com/login/device/code",
headers={ headers={
'accept': 'application/json', "accept": "application/json",
'editor-version': 'vscode/1.85.1', "editor-version": "vscode/1.85.1",
'editor-plugin-version': 'copilot/1.155.0', "editor-plugin-version": "copilot/1.155.0",
'content-type': 'application/json', "content-type": "application/json",
'user-agent': 'GithubCopilot/1.155.0', "user-agent": "GithubCopilot/1.155.0",
'accept-encoding': 'gzip,deflate,br' "accept-encoding": "gzip,deflate,br",
}, },
json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"} json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"},
) )
resp.raise_for_status() resp.raise_for_status()
resp_json = resp.json() resp_json = resp.json()
device_code = resp_json.get('device_code') device_code = resp_json.get("device_code")
user_code = resp_json.get('user_code') user_code = resp_json.get("user_code")
verification_uri = resp_json.get('verification_uri') verification_uri = resp_json.get("verification_uri")
if not all([device_code, user_code, verification_uri]): if not all([device_code, user_code, verification_uri]):
verbose_logger.error("Response missing required fields") verbose_logger.error("Response missing required fields")
@ -174,7 +179,9 @@ class Authenticator:
verbose_logger.error(f"Error getting device code: {str(e)}") verbose_logger.error(f"Error getting device code: {str(e)}")
raise GetDeviceCodeError("Failed to get device code") raise GetDeviceCodeError("Failed to get device code")
print(f'Please visit {verification_uri} and enter code {user_code} to authenticate.') print(
f"Please visit {verification_uri} and enter code {user_code} to authenticate."
)
while True: while True:
time.sleep(5) time.sleep(5)
@ -182,27 +189,27 @@ class Authenticator:
# Get access token # Get access token
try: try:
resp = sync_client.post( resp = sync_client.post(
'https://github.com/login/oauth/access_token', "https://github.com/login/oauth/access_token",
headers={ headers={
'accept': 'application/json', "accept": "application/json",
'editor-version': 'vscode/1.85.1', "editor-version": "vscode/1.85.1",
'editor-plugin-version': 'copilot/1.155.0', "editor-plugin-version": "copilot/1.155.0",
'content-type': 'application/json', "content-type": "application/json",
'user-agent': 'GithubCopilot/1.155.0', "user-agent": "GithubCopilot/1.155.0",
'accept-encoding': 'gzip,deflate,br' "accept-encoding": "gzip,deflate,br",
}, },
json={ json={
"client_id": "Iv1.b507a08c87ecfe98", "client_id": "Iv1.b507a08c87ecfe98",
"device_code": device_code, "device_code": device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code" "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
} },
) )
resp.raise_for_status() resp.raise_for_status()
resp_json = resp.json() resp_json = resp.json()
if "access_token" in resp_json: if "access_token" in resp_json:
verbose_logger.info("Authentication success!") verbose_logger.info("Authentication success!")
return resp_json['access_token'] return resp_json["access_token"]
else: else:
continue continue
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
@ -211,7 +218,3 @@ class Authenticator:
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
verbose_logger.error(f"Error decoding JSON response: {str(e)}") verbose_logger.error(f"Error decoding JSON response: {str(e)}")
raise GetAccessTokenError("Failed to get access token") raise GetAccessTokenError("Failed to get access token")

View file

@ -4,6 +4,9 @@ from typing import Optional, Tuple
from litellm.llms.openai.openai import OpenAIConfig from litellm.llms.openai.openai import OpenAIConfig
from ..authenticator import Authenticator from ..authenticator import Authenticator
from ..constants import GetAPIKeyError
from litellm.exceptions import AuthenticationError
class GithubCopilotConfig(OpenAIConfig): class GithubCopilotConfig(OpenAIConfig):
def __init__( def __init__(
@ -13,7 +16,6 @@ class GithubCopilotConfig(OpenAIConfig):
custom_llm_provider: str = "openai", custom_llm_provider: str = "openai",
) -> None: ) -> None:
super().__init__() super().__init__()
self.authenticator = Authenticator() self.authenticator = Authenticator()
def _get_openai_compatible_provider_info( def _get_openai_compatible_provider_info(
@ -24,6 +26,12 @@ class GithubCopilotConfig(OpenAIConfig):
custom_llm_provider: str, custom_llm_provider: str,
) -> Tuple[Optional[str], Optional[str], str]: ) -> Tuple[Optional[str], Optional[str], str]:
api_base = "https://api.githubcopilot.com" api_base = "https://api.githubcopilot.com"
try:
dynamic_api_key = self.authenticator.get_api_key() dynamic_api_key = self.authenticator.get_api_key()
except GetAPIKeyError as e:
raise AuthenticationError(
model=model,
llm_provider=custom_llm_provider,
message=str(e),
)
return api_base, dynamic_api_key, custom_llm_provider return api_base, dynamic_api_key, custom_llm_provider

View file

@ -17,9 +17,22 @@ DEFAULT_HEADERS = {
"User-Agent": "litellm", "User-Agent": "litellm",
} }
class GetDeviceCodeError(Exception): pass
class GetAccessTokenError(Exception): pass
class APIKeyExpiredError(Exception): pass
class RefreshAPIKeyError(Exception): pass
class GetAPIKeyError(Exception): pass
class GetDeviceCodeError(Exception):
pass
class GetAccessTokenError(Exception):
pass
class APIKeyExpiredError(Exception):
pass
class RefreshAPIKeyError(Exception):
pass
class GetAPIKeyError(Exception):
pass

View file

@ -0,0 +1,188 @@
import pytest
import os
import sys
import json
import asyncio
from datetime import datetime, timedelta
from typing import AsyncGenerator
from unittest.mock import AsyncMock, patch, mock_open, MagicMock
sys.path.insert(0, os.path.abspath("../.."))
import httpx
import pytest
from respx import MockRouter
import litellm
from litellm import Choices, Message, ModelResponse, ModelResponse, Usage
from litellm import completion, acompletion
from litellm.llms.github_copilot.chat.transformation import GithubCopilotConfig
from litellm.llms.github_copilot.authenticator import Authenticator
from litellm.llms.github_copilot.constants import (
GetAccessTokenError,
GetDeviceCodeError,
RefreshAPIKeyError,
GetAPIKeyError,
APIKeyExpiredError,
)
from litellm.exceptions import AuthenticationError
# Import at the top to make the patch work correctly
import litellm.llms.github_copilot.chat.transformation
def test_github_copilot_config_get_openai_compatible_provider_info():
"""Test the GitHub Copilot configuration provider info retrieval."""
config = GithubCopilotConfig()
# Mock the authenticator to avoid actual API calls
mock_api_key = "gh.test-key-123456789"
config.authenticator = MagicMock()
config.authenticator.get_api_key.return_value = mock_api_key
# Test with default values
model = "github_copilot/gpt-4"
(
api_base,
dynamic_api_key,
custom_llm_provider,
) = config._get_openai_compatible_provider_info(
model=model,
api_base=None,
api_key=None,
custom_llm_provider="github_copilot",
)
assert api_base == "https://api.githubcopilot.com"
assert dynamic_api_key == mock_api_key
assert custom_llm_provider == "github_copilot"
# Test with authentication failure
config.authenticator.get_api_key.side_effect = GetAPIKeyError(
"Failed to get API key"
)
with pytest.raises(AuthenticationError) as excinfo:
config._get_openai_compatible_provider_info(
model=model,
api_base=None,
api_key=None,
custom_llm_provider="github_copilot",
)
assert "Failed to get API key" in str(excinfo.value)
@patch("litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider")
@patch("litellm.llms.openai.openai.OpenAIChatCompletion.completion")
def test_completion_github_copilot(mock_completion, mock_get_provider):
"""Test the completion function with GitHub Copilot provider."""
# Mock completion response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Hello, I'm GitHub Copilot!"
mock_completion.return_value = mock_response
# Test non-streaming completion
messages = [
{"role": "system", "content": "You're GitHub Copilot, an AI assistant."},
{"role": "user", "content": "Hello, who are you?"},
]
# Create a properly formatted headers dictionary
headers = {
"editor-version": "Neovim/0.9.0",
"Copilot-Integration-Id": "vscode-chat",
}
# Patch the get_llm_provider function instead of the config method
# Make it return the expected tuple directly
mock_get_provider.return_value = (
"gpt-4",
"github_copilot",
"gh.test-key-123456789",
"https://api.githubcopilot.com",
)
response = completion(
model="github_copilot/gpt-4",
messages=messages,
extra_headers=headers,
)
assert response is not None
# Verify the completion call was made with the expected params
mock_completion.assert_called_once()
args, kwargs = mock_completion.call_args
# Check that the proper authorization header is set
assert "headers" in kwargs
# Check that the model name is correctly formatted
assert (
kwargs.get("model") == "gpt-4"
) # Model name should be without provider prefix
assert kwargs.get("messages") == messages
@patch("litellm.llms.github_copilot.authenticator.Authenticator.get_api_key")
def test_authenticator_get_api_key(mock_get_api_key):
"""Test the Authenticator's get_api_key method."""
from litellm.llms.github_copilot.authenticator import Authenticator
# Test successful API key retrieval
mock_get_api_key.return_value = "gh.test-key-123456789"
authenticator = Authenticator()
api_key = authenticator.get_api_key()
assert api_key == "gh.test-key-123456789"
mock_get_api_key.assert_called_once()
# Test API key retrieval failure
mock_get_api_key.reset_mock()
mock_get_api_key.side_effect = GetAPIKeyError("Failed to get API key")
authenticator = Authenticator()
with pytest.raises(GetAPIKeyError) as excinfo:
authenticator.get_api_key()
assert "Failed to get API key" in str(excinfo.value)
def test_completion_github_copilot(stream=False):
try:
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You are an AI programming assistant."},
{
"role": "user",
"content": "Write a Python function to calculate fibonacci numbers",
},
]
extra_headers = {
"editor-version": "Neovim/0.9.0",
"Copilot-Integration-Id": "vscode-chat",
}
response = completion(
model="github_copilot/gpt-4",
messages=messages,
stream=stream,
extra_headers=extra_headers,
)
print(response)
if stream is True:
for chunk in response:
print(chunk)
assert chunk is not None
assert isinstance(chunk, litellm.ModelResponseStream)
assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
else:
assert response is not None
assert isinstance(response, litellm.ModelResponse)
assert response.choices[0].message.content is not None
except Exception as e:
pytest.fail(f"Error occurred: {e}")