This commit is contained in:
Krish Dholakia 2025-04-24 09:53:52 +02:00 committed by GitHub
commit db4043d762
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 904 additions and 1 deletions

View file

@ -1023,6 +1023,7 @@ from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
from .llms.github_copilot.chat.transformation import GithubCopilotConfig
from .main import * # type: ignore
from .integrations import *
from .exceptions import (

View file

@ -157,6 +157,7 @@ LITELLM_CHAT_PROVIDERS = [
"hosted_vllm",
"lm_studio",
"galadriel",
"github_copilot", # GitHub Copilot Chat API
]
@ -216,7 +217,7 @@ openai_compatible_endpoints: List = [
"https://api.friendli.ai/serverless/v1",
"api.sambanova.ai/v1",
"api.x.ai/v1",
"api.galadriel.ai/v1",
"api.galadriel.ai/v1"
]
@ -246,6 +247,7 @@ openai_compatible_providers: List = [
"hosted_vllm",
"lm_studio",
"galadriel",
"github_copilot", # GitHub Copilot Chat API
]
openai_text_completion_compatible_providers: List = (
[ # providers that support `/v1/completions`

View file

@ -573,6 +573,14 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
or "https://api.galadriel.com/v1"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
elif custom_llm_provider == "github_copilot":
(
api_base,
dynamic_api_key,
custom_llm_provider,
) = litellm.GithubCopilotConfig()._get_openai_compatible_provider_info(
model, api_base, api_key, custom_llm_provider
)
elif custom_llm_provider == "snowflake":
api_base = (
api_base

View file

@ -0,0 +1,17 @@
from .constants import (
GITHUB_COPILOT_API_BASE,
CHAT_COMPLETION_ENDPOINT,
GITHUB_COPILOT_MODEL,
GetAccessTokenError,
GetAPIKeyError,
RefreshAPIKeyError,
)
__all__ = [
"GITHUB_COPILOT_API_BASE",
"CHAT_COMPLETION_ENDPOINT",
"GITHUB_COPILOT_MODEL",
"GetAccessTokenError",
"GetAPIKeyError",
"RefreshAPIKeyError",
]

View file

@ -0,0 +1,305 @@
import json
import os
import time
from datetime import datetime
from typing import Any, Dict, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
from .constants import (
APIKeyExpiredError,
GetAccessTokenError,
GetAPIKeyError,
GetDeviceCodeError,
RefreshAPIKeyError,
)
# Constants
GITHUB_CLIENT_ID = "Iv1.b507a08c87ecfe98"
GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code"
GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
GITHUB_API_KEY_URL = "https://api.github.com/copilot_internal/v2/token"
class Authenticator:
def __init__(self) -> None:
"""Initialize the GitHub Copilot authenticator with configurable token paths."""
# Token storage paths
self.token_dir = os.getenv(
"GITHUB_COPILOT_TOKEN_DIR",
os.path.expanduser("~/.config/litellm/github_copilot"),
)
self.access_token_file = os.path.join(
self.token_dir,
os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"),
)
self.api_key_file = os.path.join(
self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json")
)
self._ensure_token_dir()
def get_access_token(self) -> str:
"""
Login to Copilot with retry 3 times.
Returns:
str: The GitHub access token.
Raises:
GetAccessTokenError: If unable to obtain an access token after retries.
"""
try:
with open(self.access_token_file, "r") as f:
access_token = f.read().strip()
if access_token:
return access_token
except IOError:
verbose_logger.warning(
"No existing access token found or error reading file"
)
for attempt in range(3):
verbose_logger.debug(f"Access token acquisition attempt {attempt + 1}/3")
try:
access_token = self._login()
try:
with open(self.access_token_file, "w") as f:
f.write(access_token)
except IOError:
verbose_logger.error("Error saving access token to file")
return access_token
except (GetDeviceCodeError, GetAccessTokenError, RefreshAPIKeyError) as e:
verbose_logger.warning(f"Failed attempt {attempt + 1}: {str(e)}")
continue
raise GetAccessTokenError("Failed to get access token after 3 attempts")
def get_api_key(self) -> str:
"""
Get the API key, refreshing if necessary.
Returns:
str: The GitHub Copilot API key.
Raises:
GetAPIKeyError: If unable to obtain an API key.
"""
try:
with open(self.api_key_file, "r") as f:
api_key_info = json.load(f)
if api_key_info.get("expires_at", 0) > datetime.now().timestamp():
return api_key_info.get("token")
else:
verbose_logger.warning("API key expired, refreshing")
raise APIKeyExpiredError("API key expired")
except IOError:
verbose_logger.warning("No API key file found or error opening file")
except (json.JSONDecodeError, KeyError) as e:
verbose_logger.warning(f"Error reading API key from file: {str(e)}")
except APIKeyExpiredError:
pass # Already logged in the try block
try:
api_key_info = self._refresh_api_key()
with open(self.api_key_file, "w") as f:
json.dump(api_key_info, f)
token = api_key_info.get("token")
if token:
return token
else:
raise GetAPIKeyError("API key response missing token")
except IOError as e:
verbose_logger.error(f"Error saving API key to file: {str(e)}")
raise GetAPIKeyError(f"Failed to save API key: {str(e)}")
except RefreshAPIKeyError as e:
raise GetAPIKeyError(f"Failed to refresh API key: {str(e)}")
def _refresh_api_key(self) -> Dict[str, Any]:
"""
Refresh the API key using the access token.
Returns:
Dict[str, Any]: The API key information including token and expiration.
Raises:
RefreshAPIKeyError: If unable to refresh the API key.
"""
access_token = self.get_access_token()
headers = self._get_github_headers(access_token)
max_retries = 3
for attempt in range(max_retries):
try:
sync_client = _get_httpx_client()
response = sync_client.get(GITHUB_API_KEY_URL, headers=headers)
response.raise_for_status()
response_json = response.json()
if "token" in response_json:
return response_json
else:
verbose_logger.warning(
f"API key response missing token: {response_json}"
)
except httpx.HTTPStatusError as e:
verbose_logger.error(
f"HTTP error refreshing API key (attempt {attempt+1}/{max_retries}): {str(e)}"
)
except Exception as e:
verbose_logger.error(f"Unexpected error refreshing API key: {str(e)}")
raise RefreshAPIKeyError("Failed to refresh API key after maximum retries")
def _ensure_token_dir(self) -> None:
"""Ensure the token directory exists."""
if not os.path.exists(self.token_dir):
os.makedirs(self.token_dir, exist_ok=True)
def _get_github_headers(self, access_token: Optional[str] = None) -> Dict[str, str]:
"""
Generate standard GitHub headers for API requests.
Args:
access_token: Optional access token to include in the headers.
Returns:
Dict[str, str]: Headers for GitHub API requests.
"""
headers = {
"accept": "application/json",
"editor-version": "vscode/1.85.1",
"editor-plugin-version": "copilot/1.155.0",
"user-agent": "GithubCopilot/1.155.0",
"accept-encoding": "gzip,deflate,br",
}
if access_token:
headers["authorization"] = f"token {access_token}"
if "content-type" not in headers:
headers["content-type"] = "application/json"
return headers
def _get_device_code(self) -> Dict[str, str]:
"""
Get a device code for GitHub authentication.
Returns:
Dict[str, str]: Device code information.
Raises:
GetDeviceCodeError: If unable to get a device code.
"""
try:
sync_client = _get_httpx_client()
resp = sync_client.post(
GITHUB_DEVICE_CODE_URL,
headers=self._get_github_headers(),
json={"client_id": GITHUB_CLIENT_ID, "scope": "read:user"},
)
resp.raise_for_status()
resp_json = resp.json()
required_fields = ["device_code", "user_code", "verification_uri"]
if not all(field in resp_json for field in required_fields):
verbose_logger.error(f"Response missing required fields: {resp_json}")
raise GetDeviceCodeError("Response missing required fields")
return resp_json
except httpx.HTTPStatusError as e:
verbose_logger.error(f"HTTP error getting device code: {str(e)}")
raise GetDeviceCodeError(f"Failed to get device code: {str(e)}")
except json.JSONDecodeError as e:
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
raise GetDeviceCodeError(f"Failed to decode device code response: {str(e)}")
except Exception as e:
verbose_logger.error(f"Unexpected error getting device code: {str(e)}")
raise GetDeviceCodeError(f"Failed to get device code: {str(e)}")
def _poll_for_access_token(self, device_code: str) -> str:
"""
Poll for an access token after user authentication.
Args:
device_code: The device code to use for polling.
Returns:
str: The access token.
Raises:
GetAccessTokenError: If unable to get an access token.
"""
sync_client = _get_httpx_client()
max_attempts = 12 # 1 minute (12 * 5 seconds)
for attempt in range(max_attempts):
try:
resp = sync_client.post(
GITHUB_ACCESS_TOKEN_URL,
headers=self._get_github_headers(),
json={
"client_id": GITHUB_CLIENT_ID,
"device_code": device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
},
)
resp.raise_for_status()
resp_json = resp.json()
if "access_token" in resp_json:
verbose_logger.info("Authentication successful!")
return resp_json["access_token"]
elif (
"error" in resp_json
and resp_json.get("error") == "authorization_pending"
):
verbose_logger.debug(
f"Authorization pending (attempt {attempt+1}/{max_attempts})"
)
else:
verbose_logger.warning(f"Unexpected response: {resp_json}")
except httpx.HTTPStatusError as e:
verbose_logger.error(f"HTTP error polling for access token: {str(e)}")
raise GetAccessTokenError(f"Failed to get access token: {str(e)}")
except json.JSONDecodeError as e:
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
raise GetAccessTokenError(
f"Failed to decode access token response: {str(e)}"
)
except Exception as e:
verbose_logger.error(
f"Unexpected error polling for access token: {str(e)}"
)
raise GetAccessTokenError(f"Failed to get access token: {str(e)}")
time.sleep(5)
raise GetAccessTokenError("Timed out waiting for user to authorize the device")
def _login(self) -> str:
"""
Login to GitHub Copilot using device code flow.
Returns:
str: The GitHub access token.
Raises:
GetDeviceCodeError: If unable to get a device code.
GetAccessTokenError: If unable to get an access token.
"""
device_code_info = self._get_device_code()
device_code = device_code_info["device_code"]
user_code = device_code_info["user_code"]
verification_uri = device_code_info["verification_uri"]
print(
f"Please visit {verification_uri} and enter code {user_code} to authenticate."
)
return self._poll_for_access_token(device_code)

View file

@ -0,0 +1,37 @@
from typing import Optional, Tuple
from litellm.llms.openai.openai import OpenAIConfig
from ..authenticator import Authenticator
from ..constants import GetAPIKeyError
from litellm.exceptions import AuthenticationError
class GithubCopilotConfig(OpenAIConfig):
def __init__(
self,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
custom_llm_provider: str = "openai",
) -> None:
super().__init__()
self.authenticator = Authenticator()
def _get_openai_compatible_provider_info(
self,
model: str,
api_base: Optional[str],
api_key: Optional[str],
custom_llm_provider: str,
) -> Tuple[Optional[str], Optional[str], str]:
api_base = "https://api.githubcopilot.com"
try:
dynamic_api_key = self.authenticator.get_api_key()
except GetAPIKeyError as e:
raise AuthenticationError(
model=model,
llm_provider=custom_llm_provider,
message=str(e),
)
return api_base, dynamic_api_key, custom_llm_provider

View file

@ -0,0 +1,38 @@
"""
Constants for Copilot integration
"""
import os
# Copilot API endpoints
GITHUB_COPILOT_API_BASE = "https://api.github.com/copilot/v1"
CHAT_COMPLETION_ENDPOINT = "/chat/completions"
# Model names
GITHUB_COPILOT_MODEL = "gpt-4o" # The model identifier for Copilot
# Request headers
DEFAULT_HEADERS = {
"Content-Type": "application/json",
"Accept": "application/json",
"User-Agent": "litellm",
}
class GetDeviceCodeError(Exception):
pass
class GetAccessTokenError(Exception):
pass
class APIKeyExpiredError(Exception):
pass
class RefreshAPIKeyError(Exception):
pass
class GetAPIKeyError(Exception):
pass

View file

@ -11544,6 +11544,89 @@
"litellm_provider": "jina_ai",
"mode": "rerank"
},
"github_copilot/gpt-4o": {
"max_tokens": 128000,
"max_input_tokens": 64000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "github_copilot",
"mode": "chat",
"supports_vision": true,
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_system_messages": true
},
"github_copilot/o1": {
"max_tokens": 200000,
"max_input_tokens": 20000,
"max_output_tokens": 20000,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "github_copilot",
"mode": "chat",
"supports_function_calling": true,
"supports_response_schema": true,
"supports_system_messages": true
},
"github_copilot/o3-mini": {
"max_tokens": 200000,
"max_input_tokens": 64000,
"max_output_tokens": 100000,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "github_copilot",
"mode": "chat",
"supports_function_calling": true,
"supports_response_schema": true,
"supports_system_messages": true
},
"github_copilot/claude-3.5-sonnet": {
"max_tokens": 90000,
"max_input_tokens": 90000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "github_copilot",
"mode": "chat",
"supports_vision": true,
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_system_messages": true
},
"github_copilot/claude-3.7-sonnet": {
"max_tokens": 200000,
"max_input_tokens": 90000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "github_copilot",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_system_messages": true
},
"github_copilot/claude-3.7-sonnet-thought": {
"max_tokens": 200000,
"max_input_tokens": 90000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "github_copilot",
"mode": "chat",
"supports_system_messages": true
},
"github_copilot/gemini-2.0-flash-001": {
"max_tokens": 1000000,
"max_input_tokens": 128000,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "github_copilot",
"mode": "chat",
"supports_vision": true,
"supports_system_messages": true
"snowflake/deepseek-r1": {
"max_tokens": 32768,
"max_input_tokens": 32768,

View file

@ -2101,6 +2101,7 @@ class LlmProviders(str, Enum):
HUMANLOOP = "humanloop"
TOPAZ = "topaz"
ASSEMBLYAI = "assemblyai"
GITHUB_COPILOT = "github_copilot"
SNOWFLAKE = "snowflake"

View file

@ -0,0 +1,232 @@
import pytest
import os
import sys
import json
import asyncio
from datetime import datetime, timedelta
from typing import AsyncGenerator
from unittest.mock import AsyncMock, patch, mock_open, MagicMock
sys.path.insert(0, os.path.abspath("../.."))
import httpx
import pytest
from respx import MockRouter
import litellm
from litellm import Choices, Message, ModelResponse, ModelResponse, Usage
from litellm import completion, acompletion
from litellm.llms.github_copilot.chat.transformation import GithubCopilotConfig
from litellm.llms.github_copilot.authenticator import Authenticator
from litellm.llms.github_copilot.constants import (
GetAccessTokenError,
GetDeviceCodeError,
RefreshAPIKeyError,
GetAPIKeyError,
APIKeyExpiredError,
)
from litellm.exceptions import AuthenticationError
# Import at the top to make the patch work correctly
import litellm.llms.github_copilot.chat.transformation
def test_github_copilot_config_get_openai_compatible_provider_info():
"""Test the GitHub Copilot configuration provider info retrieval."""
config = GithubCopilotConfig()
# Mock the authenticator to avoid actual API calls
mock_api_key = "gh.test-key-123456789"
config.authenticator = MagicMock()
config.authenticator.get_api_key.return_value = mock_api_key
# Test with default values
model = "github_copilot/gpt-4"
(
api_base,
dynamic_api_key,
custom_llm_provider,
) = config._get_openai_compatible_provider_info(
model=model,
api_base=None,
api_key=None,
custom_llm_provider="github_copilot",
)
assert api_base == "https://api.githubcopilot.com"
assert dynamic_api_key == mock_api_key
assert custom_llm_provider == "github_copilot"
# Test with authentication failure
config.authenticator.get_api_key.side_effect = GetAPIKeyError(
"Failed to get API key"
)
with pytest.raises(AuthenticationError) as excinfo:
config._get_openai_compatible_provider_info(
model=model,
api_base=None,
api_key=None,
custom_llm_provider="github_copilot",
)
assert "Failed to get API key" in str(excinfo.value)
@patch("litellm.main.get_llm_provider")
@patch("litellm.llms.openai.openai.OpenAIChatCompletion.completion")
def test_completion_github_copilot_mock_response(mock_completion, mock_get_llm_provider):
"""Test the completion function with GitHub Copilot provider."""
# Mock completion response
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Hello, I'm GitHub Copilot!"
mock_completion.return_value = mock_response
# Test non-streaming completion
messages = [
{"role": "system", "content": "You're GitHub Copilot, an AI assistant."},
{"role": "user", "content": "Hello, who are you?"},
]
# Create a properly formatted headers dictionary
headers = {
"editor-version": "Neovim/0.9.0",
"Copilot-Integration-Id": "vscode-chat",
}
# Patch the get_llm_provider function instead of the config method
# Make it return the expected tuple directly
mock_get_llm_provider.return_value = (
"gpt-4",
"github_copilot",
"gh.test-key-123456789",
"https://api.githubcopilot.com",
)
response = completion(
model="github_copilot/gpt-4",
messages=messages,
extra_headers=headers,
)
assert response is not None
# Verify the get_llm_provider call was made with the expected params
mock_get_llm_provider.assert_called_once()
args, kwargs = mock_get_llm_provider.call_args
assert kwargs.get("model") is "github_copilot/gpt-4"
assert kwargs.get("custom_llm_provider") is None
assert kwargs.get("api_key") is None
assert kwargs.get("api_base") is None
# Verify the completion call was made with the expected params
mock_completion.assert_called_once()
args, kwargs = mock_completion.call_args
# Check that the proper authorization header is set
assert "headers" in kwargs
# Check that the model name is correctly formatted
assert (
kwargs.get("model") == "gpt-4"
) # Model name should be without provider prefix
assert kwargs.get("messages") == messages
@patch("litellm.llms.github_copilot.authenticator.Authenticator.get_api_key")
def test_authenticator_get_api_key(mock_get_api_key):
"""Test the Authenticator's get_api_key method."""
from litellm.llms.github_copilot.authenticator import Authenticator
# Test successful API key retrieval
mock_get_api_key.return_value = "gh.test-key-123456789"
authenticator = Authenticator()
api_key = authenticator.get_api_key()
assert api_key == "gh.test-key-123456789"
mock_get_api_key.assert_called_once()
# Test API key retrieval failure
mock_get_api_key.reset_mock()
mock_get_api_key.side_effect = GetAPIKeyError("Failed to get API key")
authenticator = Authenticator()
with pytest.raises(GetAPIKeyError) as excinfo:
authenticator.get_api_key()
assert "Failed to get API key" in str(excinfo.value)
# def test_completion_github_copilot(stream=False):
# try:
# litellm.set_verbose = True
# messages = [
# {"role": "system", "content": "You are an AI programming assistant."},
# {
# "role": "user",
# "content": "Write a Python function to calculate fibonacci numbers",
# },
# ]
# extra_headers = {
# "editor-version": "Neovim/0.9.0",
# "Copilot-Integration-Id": "vscode-chat",
# }
# response = completion(
# model="github_copilot/gpt-4",
# messages=messages,
# stream=stream,
# extra_headers=extra_headers,
# )
# print(response)
# if stream is True:
# for chunk in response:
# print(chunk)
# assert chunk is not None
# assert isinstance(chunk, litellm.ModelResponseStream)
# assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
# else:
# assert response is not None
# assert isinstance(response, litellm.ModelResponse)
# assert response.choices[0].message.content is not None
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# def test_completion_github_copilot_sonnet_3_7_thought(stream=False):
# try:
# litellm.set_verbose = True
# messages = [
# {"role": "system", "content": "You are an AI programming assistant."},
# {
# "role": "user",
# "content": "Write a Python function to calculate fibonacci numbers",
# },
# ]
# extra_headers = {
# "editor-version": "Neovim/0.9.0",
# "Copilot-Integration-Id": "vscode-chat",
# }
# response = completion(
# model="github_copilot/claude-3.7-sonnet-thought",
# messages=messages,
# stream=stream,
# extra_headers=extra_headers,
# )
# print(response)
# if stream is True:
# for chunk in response:
# print(chunk)
# assert chunk is not None
# assert isinstance(chunk, litellm.ModelResponseStream)
# assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
# else:
# assert response is not None
# assert isinstance(response, litellm.ModelResponse)
# assert response.choices[0].message.content is not None
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")

View file

@ -0,0 +1,179 @@
import os
import json
import time
from datetime import datetime, timedelta
import pytest
from unittest.mock import patch, mock_open, MagicMock
from litellm.llms.github_copilot.authenticator import Authenticator
from litellm.llms.github_copilot.constants import (
GetAccessTokenError,
GetDeviceCodeError,
RefreshAPIKeyError,
GetAPIKeyError,
APIKeyExpiredError,
)
class TestGitHubCopilotAuthenticator:
@pytest.fixture
def authenticator(self):
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
auth = Authenticator()
mock_makedirs.assert_called_once()
return auth
@pytest.fixture
def mock_http_client(self):
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.get.return_value = mock_response
mock_client.post.return_value = mock_response
mock_response.raise_for_status.return_value = None
return mock_client, mock_response
def test_init(self):
"""Test the initialization of the authenticator."""
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
auth = Authenticator()
assert auth.token_dir.endswith("/github_copilot")
assert auth.access_token_file.endswith("/access-token")
assert auth.api_key_file.endswith("/api-key.json")
mock_makedirs.assert_called_once()
def test_ensure_token_dir(self):
"""Test that the token directory is created if it doesn't exist."""
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
auth = Authenticator()
mock_makedirs.assert_called_once_with(auth.token_dir, exist_ok=True)
def test_get_github_headers(self, authenticator):
"""Test that GitHub headers are correctly generated."""
headers = authenticator._get_github_headers()
assert "accept" in headers
assert "editor-version" in headers
assert "user-agent" in headers
assert "content-type" in headers
headers_with_token = authenticator._get_github_headers("test-token")
assert headers_with_token["authorization"] == "token test-token"
def test_get_access_token_from_file(self, authenticator):
"""Test retrieving an access token from a file."""
mock_token = "mock-access-token"
with patch("builtins.open", mock_open(read_data=mock_token)):
token = authenticator.get_access_token()
assert token == mock_token
def test_get_access_token_login(self, authenticator):
"""Test logging in to get an access token."""
mock_token = "mock-access-token"
with patch.object(authenticator, "_login", return_value=mock_token), \
patch("builtins.open", mock_open()), \
patch("builtins.open", side_effect=IOError) as mock_read:
token = authenticator.get_access_token()
assert token == mock_token
authenticator._login.assert_called_once()
def test_get_access_token_failure(self, authenticator):
"""Test that an exception is raised after multiple login failures."""
with patch.object(authenticator, "_login", side_effect=GetDeviceCodeError("Test error")), \
patch("builtins.open", side_effect=IOError):
with pytest.raises(GetAccessTokenError):
authenticator.get_access_token()
assert authenticator._login.call_count == 3
def test_get_api_key_from_file(self, authenticator):
"""Test retrieving an API key from a file."""
future_time = (datetime.now() + timedelta(hours=1)).timestamp()
mock_api_key_data = json.dumps({"token": "mock-api-key", "expires_at": future_time})
with patch("builtins.open", mock_open(read_data=mock_api_key_data)):
api_key = authenticator.get_api_key()
assert api_key == "mock-api-key"
def test_get_api_key_expired(self, authenticator):
"""Test refreshing an expired API key."""
past_time = (datetime.now() - timedelta(hours=1)).timestamp()
mock_expired_data = json.dumps({"token": "expired-api-key", "expires_at": past_time})
mock_new_data = {"token": "new-api-key", "expires_at": (datetime.now() + timedelta(hours=1)).timestamp()}
with patch("builtins.open", mock_open(read_data=mock_expired_data)), \
patch.object(authenticator, "_refresh_api_key", return_value=mock_new_data), \
patch("json.dump") as mock_json_dump:
api_key = authenticator.get_api_key()
assert api_key == "new-api-key"
authenticator._refresh_api_key.assert_called_once()
def test_refresh_api_key(self, authenticator, mock_http_client):
"""Test refreshing an API key."""
mock_client, mock_response = mock_http_client
mock_token = "mock-access-token"
mock_api_key_data = {"token": "new-api-key", "expires_at": 12345}
with patch.object(authenticator, "get_access_token", return_value=mock_token), \
patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
patch.object(mock_response, "json", return_value=mock_api_key_data):
result = authenticator._refresh_api_key()
assert result == mock_api_key_data
mock_client.get.assert_called_once()
authenticator.get_access_token.assert_called_once()
def test_refresh_api_key_failure(self, authenticator, mock_http_client):
"""Test failure to refresh an API key."""
mock_client, mock_response = mock_http_client
mock_token = "mock-access-token"
with patch.object(authenticator, "get_access_token", return_value=mock_token), \
patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
patch.object(mock_response, "json", return_value={}):
with pytest.raises(RefreshAPIKeyError):
authenticator._refresh_api_key()
assert mock_client.get.call_count == 3
def test_get_device_code(self, authenticator, mock_http_client):
"""Test getting a device code."""
mock_client, mock_response = mock_http_client
mock_device_code_data = {
"device_code": "mock-device-code",
"user_code": "ABCD-EFGH",
"verification_uri": "https://github.com/login/device"
}
with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
patch.object(mock_response, "json", return_value=mock_device_code_data):
result = authenticator._get_device_code()
assert result == mock_device_code_data
mock_client.post.assert_called_once()
def test_poll_for_access_token(self, authenticator, mock_http_client):
"""Test polling for an access token."""
mock_client, mock_response = mock_http_client
mock_token_data = {"access_token": "mock-access-token"}
with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
patch.object(mock_response, "json", return_value=mock_token_data), \
patch("time.sleep"):
result = authenticator._poll_for_access_token("mock-device-code")
assert result == "mock-access-token"
mock_client.post.assert_called_once()
def test_login(self, authenticator):
"""Test the login process."""
mock_device_code_data = {
"device_code": "mock-device-code",
"user_code": "ABCD-EFGH",
"verification_uri": "https://github.com/login/device"
}
mock_token = "mock-access-token"
with patch.object(authenticator, "_get_device_code", return_value=mock_device_code_data), \
patch.object(authenticator, "_poll_for_access_token", return_value=mock_token), \
patch("builtins.print") as mock_print:
result = authenticator._login()
assert result == mock_token
authenticator._get_device_code.assert_called_once()
authenticator._poll_for_access_token.assert_called_once_with("mock-device-code")
mock_print.assert_called_once()