diff --git a/litellm/__init__.py b/litellm/__init__.py index 59c8c78eb9..641c070b54 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1023,6 +1023,7 @@ from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig +from .llms.github_copilot.chat.transformation import GithubCopilotConfig from .main import * # type: ignore from .integrations import * from .exceptions import ( diff --git a/litellm/constants.py b/litellm/constants.py index f48ce97afe..68c28520bf 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -157,6 +157,7 @@ LITELLM_CHAT_PROVIDERS = [ "hosted_vllm", "lm_studio", "galadriel", + "github_copilot", # GitHub Copilot Chat API ] @@ -216,7 +217,7 @@ openai_compatible_endpoints: List = [ "https://api.friendli.ai/serverless/v1", "api.sambanova.ai/v1", "api.x.ai/v1", - "api.galadriel.ai/v1", + "api.galadriel.ai/v1" ] @@ -246,6 +247,7 @@ openai_compatible_providers: List = [ "hosted_vllm", "lm_studio", "galadriel", + "github_copilot", # GitHub Copilot Chat API ] openai_text_completion_compatible_providers: List = ( [ # providers that support `/v1/completions` diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 13103c85a0..9a83084736 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -573,6 +573,14 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915 or "https://api.galadriel.com/v1" ) # type: ignore dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY") + elif custom_llm_provider == "github_copilot": + ( + api_base, + dynamic_api_key, + custom_llm_provider, + ) = litellm.GithubCopilotConfig()._get_openai_compatible_provider_info( + model, api_base, api_key, custom_llm_provider + ) elif custom_llm_provider == "snowflake": api_base = ( api_base diff --git a/litellm/llms/github_copilot/__init__.py b/litellm/llms/github_copilot/__init__.py new file mode 100644 index 0000000000..c642eff211 --- /dev/null +++ b/litellm/llms/github_copilot/__init__.py @@ -0,0 +1,17 @@ +from .constants import ( + GITHUB_COPILOT_API_BASE, + CHAT_COMPLETION_ENDPOINT, + GITHUB_COPILOT_MODEL, + GetAccessTokenError, + GetAPIKeyError, + RefreshAPIKeyError, +) + +__all__ = [ + "GITHUB_COPILOT_API_BASE", + "CHAT_COMPLETION_ENDPOINT", + "GITHUB_COPILOT_MODEL", + "GetAccessTokenError", + "GetAPIKeyError", + "RefreshAPIKeyError", +] diff --git a/litellm/llms/github_copilot/authenticator.py b/litellm/llms/github_copilot/authenticator.py new file mode 100644 index 0000000000..7b76d521c1 --- /dev/null +++ b/litellm/llms/github_copilot/authenticator.py @@ -0,0 +1,305 @@ +import json +import os +import time +from datetime import datetime +from typing import Any, Dict, Optional + +import httpx + +from litellm._logging import verbose_logger +from litellm.llms.custom_httpx.http_handler import _get_httpx_client + +from .constants import ( + APIKeyExpiredError, + GetAccessTokenError, + GetAPIKeyError, + GetDeviceCodeError, + RefreshAPIKeyError, +) + +# Constants +GITHUB_CLIENT_ID = "Iv1.b507a08c87ecfe98" +GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code" +GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token" +GITHUB_API_KEY_URL = "https://api.github.com/copilot_internal/v2/token" + + 
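+# Authentication flow implemented by the class below (GitHub device-code OAuth):
+#   1. _get_device_code() requests a device/user code pair from GitHub.
+#   2. _poll_for_access_token() waits for the user to authorize the device; the resulting
+#      long-lived access token is cached under GITHUB_COPILOT_TOKEN_DIR.
+#   3. _refresh_api_key() exchanges the access token for a short-lived Copilot API key,
+#      which get_api_key() stores on disk and refreshes once it expires.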
+class Authenticator: + def __init__(self) -> None: + """Initialize the GitHub Copilot authenticator with configurable token paths.""" + # Token storage paths + self.token_dir = os.getenv( + "GITHUB_COPILOT_TOKEN_DIR", + os.path.expanduser("~/.config/litellm/github_copilot"), + ) + self.access_token_file = os.path.join( + self.token_dir, + os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"), + ) + self.api_key_file = os.path.join( + self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json") + ) + self._ensure_token_dir() + + def get_access_token(self) -> str: + """ + Login to Copilot with retry 3 times. + + Returns: + str: The GitHub access token. + + Raises: + GetAccessTokenError: If unable to obtain an access token after retries. + """ + try: + with open(self.access_token_file, "r") as f: + access_token = f.read().strip() + if access_token: + return access_token + except IOError: + verbose_logger.warning( + "No existing access token found or error reading file" + ) + + for attempt in range(3): + verbose_logger.debug(f"Access token acquisition attempt {attempt + 1}/3") + try: + access_token = self._login() + try: + with open(self.access_token_file, "w") as f: + f.write(access_token) + except IOError: + verbose_logger.error("Error saving access token to file") + return access_token + except (GetDeviceCodeError, GetAccessTokenError, RefreshAPIKeyError) as e: + verbose_logger.warning(f"Failed attempt {attempt + 1}: {str(e)}") + continue + + raise GetAccessTokenError("Failed to get access token after 3 attempts") + + def get_api_key(self) -> str: + """ + Get the API key, refreshing if necessary. + + Returns: + str: The GitHub Copilot API key. + + Raises: + GetAPIKeyError: If unable to obtain an API key. + """ + try: + with open(self.api_key_file, "r") as f: + api_key_info = json.load(f) + if api_key_info.get("expires_at", 0) > datetime.now().timestamp(): + return api_key_info.get("token") + else: + verbose_logger.warning("API key expired, refreshing") + raise APIKeyExpiredError("API key expired") + except IOError: + verbose_logger.warning("No API key file found or error opening file") + except (json.JSONDecodeError, KeyError) as e: + verbose_logger.warning(f"Error reading API key from file: {str(e)}") + except APIKeyExpiredError: + pass # Already logged in the try block + + try: + api_key_info = self._refresh_api_key() + with open(self.api_key_file, "w") as f: + json.dump(api_key_info, f) + token = api_key_info.get("token") + if token: + return token + else: + raise GetAPIKeyError("API key response missing token") + except IOError as e: + verbose_logger.error(f"Error saving API key to file: {str(e)}") + raise GetAPIKeyError(f"Failed to save API key: {str(e)}") + except RefreshAPIKeyError as e: + raise GetAPIKeyError(f"Failed to refresh API key: {str(e)}") + + def _refresh_api_key(self) -> Dict[str, Any]: + """ + Refresh the API key using the access token. + + Returns: + Dict[str, Any]: The API key information including token and expiration. + + Raises: + RefreshAPIKeyError: If unable to refresh the API key. 
+ """ + access_token = self.get_access_token() + headers = self._get_github_headers(access_token) + + max_retries = 3 + for attempt in range(max_retries): + try: + sync_client = _get_httpx_client() + response = sync_client.get(GITHUB_API_KEY_URL, headers=headers) + response.raise_for_status() + + response_json = response.json() + + if "token" in response_json: + return response_json + else: + verbose_logger.warning( + f"API key response missing token: {response_json}" + ) + except httpx.HTTPStatusError as e: + verbose_logger.error( + f"HTTP error refreshing API key (attempt {attempt+1}/{max_retries}): {str(e)}" + ) + except Exception as e: + verbose_logger.error(f"Unexpected error refreshing API key: {str(e)}") + + raise RefreshAPIKeyError("Failed to refresh API key after maximum retries") + + def _ensure_token_dir(self) -> None: + """Ensure the token directory exists.""" + if not os.path.exists(self.token_dir): + os.makedirs(self.token_dir, exist_ok=True) + + def _get_github_headers(self, access_token: Optional[str] = None) -> Dict[str, str]: + """ + Generate standard GitHub headers for API requests. + + Args: + access_token: Optional access token to include in the headers. + + Returns: + Dict[str, str]: Headers for GitHub API requests. + """ + headers = { + "accept": "application/json", + "editor-version": "vscode/1.85.1", + "editor-plugin-version": "copilot/1.155.0", + "user-agent": "GithubCopilot/1.155.0", + "accept-encoding": "gzip,deflate,br", + } + + if access_token: + headers["authorization"] = f"token {access_token}" + + if "content-type" not in headers: + headers["content-type"] = "application/json" + + return headers + + def _get_device_code(self) -> Dict[str, str]: + """ + Get a device code for GitHub authentication. + + Returns: + Dict[str, str]: Device code information. + + Raises: + GetDeviceCodeError: If unable to get a device code. + """ + try: + sync_client = _get_httpx_client() + resp = sync_client.post( + GITHUB_DEVICE_CODE_URL, + headers=self._get_github_headers(), + json={"client_id": GITHUB_CLIENT_ID, "scope": "read:user"}, + ) + resp.raise_for_status() + resp_json = resp.json() + + required_fields = ["device_code", "user_code", "verification_uri"] + if not all(field in resp_json for field in required_fields): + verbose_logger.error(f"Response missing required fields: {resp_json}") + raise GetDeviceCodeError("Response missing required fields") + + return resp_json + except httpx.HTTPStatusError as e: + verbose_logger.error(f"HTTP error getting device code: {str(e)}") + raise GetDeviceCodeError(f"Failed to get device code: {str(e)}") + except json.JSONDecodeError as e: + verbose_logger.error(f"Error decoding JSON response: {str(e)}") + raise GetDeviceCodeError(f"Failed to decode device code response: {str(e)}") + except Exception as e: + verbose_logger.error(f"Unexpected error getting device code: {str(e)}") + raise GetDeviceCodeError(f"Failed to get device code: {str(e)}") + + def _poll_for_access_token(self, device_code: str) -> str: + """ + Poll for an access token after user authentication. + + Args: + device_code: The device code to use for polling. + + Returns: + str: The access token. + + Raises: + GetAccessTokenError: If unable to get an access token. 
+ """ + sync_client = _get_httpx_client() + max_attempts = 12 # 1 minute (12 * 5 seconds) + + for attempt in range(max_attempts): + try: + resp = sync_client.post( + GITHUB_ACCESS_TOKEN_URL, + headers=self._get_github_headers(), + json={ + "client_id": GITHUB_CLIENT_ID, + "device_code": device_code, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + }, + ) + resp.raise_for_status() + resp_json = resp.json() + + if "access_token" in resp_json: + verbose_logger.info("Authentication successful!") + return resp_json["access_token"] + elif ( + "error" in resp_json + and resp_json.get("error") == "authorization_pending" + ): + verbose_logger.debug( + f"Authorization pending (attempt {attempt+1}/{max_attempts})" + ) + else: + verbose_logger.warning(f"Unexpected response: {resp_json}") + except httpx.HTTPStatusError as e: + verbose_logger.error(f"HTTP error polling for access token: {str(e)}") + raise GetAccessTokenError(f"Failed to get access token: {str(e)}") + except json.JSONDecodeError as e: + verbose_logger.error(f"Error decoding JSON response: {str(e)}") + raise GetAccessTokenError( + f"Failed to decode access token response: {str(e)}" + ) + except Exception as e: + verbose_logger.error( + f"Unexpected error polling for access token: {str(e)}" + ) + raise GetAccessTokenError(f"Failed to get access token: {str(e)}") + + time.sleep(5) + + raise GetAccessTokenError("Timed out waiting for user to authorize the device") + + def _login(self) -> str: + """ + Login to GitHub Copilot using device code flow. + + Returns: + str: The GitHub access token. + + Raises: + GetDeviceCodeError: If unable to get a device code. + GetAccessTokenError: If unable to get an access token. + """ + device_code_info = self._get_device_code() + + device_code = device_code_info["device_code"] + user_code = device_code_info["user_code"] + verification_uri = device_code_info["verification_uri"] + + print( + f"Please visit {verification_uri} and enter code {user_code} to authenticate." 
+        )
+
+        return self._poll_for_access_token(device_code)
diff --git a/litellm/llms/github_copilot/chat/transformation.py b/litellm/llms/github_copilot/chat/transformation.py
new file mode 100644
index 0000000000..5f074debc5
--- /dev/null
+++ b/litellm/llms/github_copilot/chat/transformation.py
@@ -0,0 +1,37 @@
+from typing import Optional, Tuple
+
+
+from litellm.llms.openai.openai import OpenAIConfig
+
+from ..authenticator import Authenticator
+from ..constants import GetAPIKeyError
+from litellm.exceptions import AuthenticationError
+
+
+class GithubCopilotConfig(OpenAIConfig):
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        custom_llm_provider: str = "openai",
+    ) -> None:
+        super().__init__()
+        self.authenticator = Authenticator()
+
+    def _get_openai_compatible_provider_info(
+        self,
+        model: str,
+        api_base: Optional[str],
+        api_key: Optional[str],
+        custom_llm_provider: str,
+    ) -> Tuple[Optional[str], Optional[str], str]:
+        api_base = "https://api.githubcopilot.com"
+        try:
+            dynamic_api_key = self.authenticator.get_api_key()
+        except GetAPIKeyError as e:
+            raise AuthenticationError(
+                model=model,
+                llm_provider=custom_llm_provider,
+                message=str(e),
+            )
+        return api_base, dynamic_api_key, custom_llm_provider
diff --git a/litellm/llms/github_copilot/constants.py b/litellm/llms/github_copilot/constants.py
new file mode 100644
index 0000000000..210511bea3
--- /dev/null
+++ b/litellm/llms/github_copilot/constants.py
@@ -0,0 +1,38 @@
+"""
+Constants for Copilot integration
+"""
+import os
+
+# Copilot API endpoints
+GITHUB_COPILOT_API_BASE = "https://api.github.com/copilot/v1"
+CHAT_COMPLETION_ENDPOINT = "/chat/completions"
+
+# Model names
+GITHUB_COPILOT_MODEL = "gpt-4o"  # The model identifier for Copilot
+
+# Request headers
+DEFAULT_HEADERS = {
+    "Content-Type": "application/json",
+    "Accept": "application/json",
+    "User-Agent": "litellm",
+}
+
+
+class GetDeviceCodeError(Exception):
+    pass
+
+
+class GetAccessTokenError(Exception):
+    pass
+
+
+class APIKeyExpiredError(Exception):
+    pass
+
+
+class RefreshAPIKeyError(Exception):
+    pass
+
+
+class GetAPIKeyError(Exception):
+    pass
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 55052761c7..ae7c594742 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -11544,6 +11544,89 @@
         "litellm_provider": "jina_ai",
         "mode": "rerank"
     },
+    "github_copilot/gpt-4o": {
+        "max_tokens": 128000,
+        "max_input_tokens": 64000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "github_copilot",
+        "mode": "chat",
+        "supports_vision": true,
+        "supports_function_calling": true,
+        "supports_parallel_function_calling": true,
+        "supports_system_messages": true
+    },
+    "github_copilot/o1": {
+        "max_tokens": 200000,
+        "max_input_tokens": 20000,
+        "max_output_tokens": 20000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "github_copilot",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_response_schema": true,
+        "supports_system_messages": true
+    },
+    "github_copilot/o3-mini": {
+        "max_tokens": 200000,
+        "max_input_tokens": 64000,
+        "max_output_tokens": 100000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "github_copilot",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_response_schema": true,
+        "supports_system_messages": true
+    },
+    "github_copilot/claude-3.5-sonnet": {
+        "max_tokens": 90000,
+        "max_input_tokens": 90000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "github_copilot",
+        "mode": "chat",
+        "supports_vision": true,
+        "supports_function_calling": true,
+        "supports_parallel_function_calling": true,
+        "supports_system_messages": true
+    },
+    "github_copilot/claude-3.7-sonnet": {
+        "max_tokens": 200000,
+        "max_input_tokens": 90000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "github_copilot",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_parallel_function_calling": true,
+        "supports_system_messages": true
+    },
+    "github_copilot/claude-3.7-sonnet-thought": {
+        "max_tokens": 200000,
+        "max_input_tokens": 90000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "github_copilot",
+        "mode": "chat",
+        "supports_system_messages": true
+    },
+    "github_copilot/gemini-2.0-flash-001": {
+        "max_tokens": 1000000,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 8192,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "github_copilot",
+        "mode": "chat",
+        "supports_vision": true,
+        "supports_system_messages": true
+    },
     "snowflake/deepseek-r1": {
         "max_tokens": 32768,
         "max_input_tokens": 32768,
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 532162e60f..a09e902275 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -2101,6 +2101,7 @@ class LlmProviders(str, Enum):
     HUMANLOOP = "humanloop"
     TOPAZ = "topaz"
     ASSEMBLYAI = "assemblyai"
+    GITHUB_COPILOT = "github_copilot"
     SNOWFLAKE = "snowflake"
diff --git a/tests/llm_translation/test_github_copilot.py b/tests/llm_translation/test_github_copilot.py
new file mode 100644
index 0000000000..e53045b36d
--- /dev/null
+++ b/tests/llm_translation/test_github_copilot.py
@@ -0,0 +1,232 @@
+import pytest
+import os
+import sys
+import json
+import asyncio
+from datetime import datetime, timedelta
+from typing import AsyncGenerator
+from unittest.mock import AsyncMock, patch, mock_open, MagicMock
+
+sys.path.insert(0, os.path.abspath("../.."))
+
+import httpx
+import pytest
+from respx import MockRouter
+
+import litellm
+from litellm import Choices, Message, ModelResponse, Usage
+from litellm import completion, acompletion
+from litellm.llms.github_copilot.chat.transformation import GithubCopilotConfig
+from litellm.llms.github_copilot.authenticator import Authenticator
+from litellm.llms.github_copilot.constants import (
+    GetAccessTokenError,
+    GetDeviceCodeError,
+    RefreshAPIKeyError,
+    GetAPIKeyError,
+    APIKeyExpiredError,
+)
+from litellm.exceptions import AuthenticationError
+
+# Import at the top to make the patch work correctly
+import litellm.llms.github_copilot.chat.transformation
+
+
+def test_github_copilot_config_get_openai_compatible_provider_info():
+    """Test the GitHub Copilot configuration provider info retrieval."""
+
+    config = GithubCopilotConfig()
+
+    # Mock the authenticator to avoid actual API calls
+    mock_api_key = "gh.test-key-123456789"
+    config.authenticator = MagicMock()
+    config.authenticator.get_api_key.return_value = mock_api_key
+
+    # Test with default values
+    model = "github_copilot/gpt-4"
+    (
+        api_base,
+        dynamic_api_key,
+        custom_llm_provider,
+    ) = config._get_openai_compatible_provider_info(
+        model=model,
+        api_base=None,
+        api_key=None,
+        custom_llm_provider="github_copilot",
+    )
+
+    assert api_base == "https://api.githubcopilot.com"
+    assert dynamic_api_key == mock_api_key
+    assert custom_llm_provider == "github_copilot"
+
+    # Test with authentication failure
+    config.authenticator.get_api_key.side_effect = GetAPIKeyError(
+        "Failed to get API key"
+    )
+
+    with pytest.raises(AuthenticationError) as excinfo:
+        config._get_openai_compatible_provider_info(
+            model=model,
+            api_base=None,
+            api_key=None,
+            custom_llm_provider="github_copilot",
+        )
+
+    assert "Failed to get API key" in str(excinfo.value)
+
+
+@patch("litellm.main.get_llm_provider")
+@patch("litellm.llms.openai.openai.OpenAIChatCompletion.completion")
+def test_completion_github_copilot_mock_response(mock_completion, mock_get_llm_provider):
+    """Test the completion function with GitHub Copilot provider."""
+
+    # Mock completion response
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "Hello, I'm GitHub Copilot!"
+    mock_completion.return_value = mock_response
+
+    # Test non-streaming completion
+    messages = [
+        {"role": "system", "content": "You're GitHub Copilot, an AI assistant."},
+        {"role": "user", "content": "Hello, who are you?"},
+    ]
+
+    # Create a properly formatted headers dictionary
+    headers = {
+        "editor-version": "Neovim/0.9.0",
+        "Copilot-Integration-Id": "vscode-chat",
+    }
+
+    # Patch the get_llm_provider function instead of the config method
+    # Make it return the expected tuple directly
+    mock_get_llm_provider.return_value = (
+        "gpt-4",
+        "github_copilot",
+        "gh.test-key-123456789",
+        "https://api.githubcopilot.com",
+    )
+
+    response = completion(
+        model="github_copilot/gpt-4",
+        messages=messages,
+        extra_headers=headers,
+    )
+
+    assert response is not None
+
+    # Verify the get_llm_provider call was made with the expected params
+    mock_get_llm_provider.assert_called_once()
+    args, kwargs = mock_get_llm_provider.call_args
+    assert kwargs.get("model") == "github_copilot/gpt-4"
+    assert kwargs.get("custom_llm_provider") is None
+    assert kwargs.get("api_key") is None
+    assert kwargs.get("api_base") is None
+
+    # Verify the completion call was made with the expected params
+    mock_completion.assert_called_once()
+    args, kwargs = mock_completion.call_args
+
+    # Check that the proper authorization header is set
+    assert "headers" in kwargs
+    # Check that the model name is correctly formatted
+    assert (
+        kwargs.get("model") == "gpt-4"
+    )  # Model name should be without provider prefix
+    assert kwargs.get("messages") == messages
+
+
+@patch("litellm.llms.github_copilot.authenticator.Authenticator.get_api_key")
+def test_authenticator_get_api_key(mock_get_api_key):
+    """Test the Authenticator's get_api_key method."""
+    from litellm.llms.github_copilot.authenticator import Authenticator
+
+    # Test successful API key retrieval
+    mock_get_api_key.return_value = "gh.test-key-123456789"
+    authenticator = Authenticator()
+    api_key = authenticator.get_api_key()
+
+    assert api_key == "gh.test-key-123456789"
+    mock_get_api_key.assert_called_once()
+
+    # Test API key retrieval failure
+    mock_get_api_key.reset_mock()
+    mock_get_api_key.side_effect = GetAPIKeyError("Failed to get API key")
+    authenticator = Authenticator()
+
+    with pytest.raises(GetAPIKeyError) as excinfo:
+        authenticator.get_api_key()
+
+    assert "Failed to get API key" in str(excinfo.value)
+
+
+# def test_completion_github_copilot(stream=False):
+#     try:
+#         litellm.set_verbose = True
+#         messages = [
+# {"role": "system", "content": "You are an AI programming assistant."}, +# { +# "role": "user", +# "content": "Write a Python function to calculate fibonacci numbers", +# }, +# ] +# extra_headers = { +# "editor-version": "Neovim/0.9.0", +# "Copilot-Integration-Id": "vscode-chat", +# } +# response = completion( +# model="github_copilot/gpt-4", +# messages=messages, +# stream=stream, +# extra_headers=extra_headers, +# ) +# print(response) + +# if stream is True: +# for chunk in response: +# print(chunk) +# assert chunk is not None +# assert isinstance(chunk, litellm.ModelResponseStream) +# assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices) + +# else: +# assert response is not None +# assert isinstance(response, litellm.ModelResponse) +# assert response.choices[0].message.content is not None +# except Exception as e: +# pytest.fail(f"Error occurred: {e}") + +# def test_completion_github_copilot_sonnet_3_7_thought(stream=False): +# try: +# litellm.set_verbose = True +# messages = [ +# {"role": "system", "content": "You are an AI programming assistant."}, +# { +# "role": "user", +# "content": "Write a Python function to calculate fibonacci numbers", +# }, +# ] +# extra_headers = { +# "editor-version": "Neovim/0.9.0", +# "Copilot-Integration-Id": "vscode-chat", +# } +# response = completion( +# model="github_copilot/claude-3.7-sonnet-thought", +# messages=messages, +# stream=stream, +# extra_headers=extra_headers, +# ) +# print(response) + +# if stream is True: +# for chunk in response: +# print(chunk) +# assert chunk is not None +# assert isinstance(chunk, litellm.ModelResponseStream) +# assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices) + +# else: +# assert response is not None +# assert isinstance(response, litellm.ModelResponse) +# assert response.choices[0].message.content is not None +# except Exception as e: +# pytest.fail(f"Error occurred: {e}") diff --git a/tests/llm_translation/test_github_copilot_authenticator.py b/tests/llm_translation/test_github_copilot_authenticator.py new file mode 100644 index 0000000000..247b7bf678 --- /dev/null +++ b/tests/llm_translation/test_github_copilot_authenticator.py @@ -0,0 +1,179 @@ +import os +import json +import time +from datetime import datetime, timedelta +import pytest +from unittest.mock import patch, mock_open, MagicMock + +from litellm.llms.github_copilot.authenticator import Authenticator +from litellm.llms.github_copilot.constants import ( + GetAccessTokenError, + GetDeviceCodeError, + RefreshAPIKeyError, + GetAPIKeyError, + APIKeyExpiredError, +) + + +class TestGitHubCopilotAuthenticator: + @pytest.fixture + def authenticator(self): + with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs: + auth = Authenticator() + mock_makedirs.assert_called_once() + return auth + + @pytest.fixture + def mock_http_client(self): + mock_client = MagicMock() + mock_response = MagicMock() + mock_client.get.return_value = mock_response + mock_client.post.return_value = mock_response + mock_response.raise_for_status.return_value = None + return mock_client, mock_response + + def test_init(self): + """Test the initialization of the authenticator.""" + with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs: + auth = Authenticator() + assert auth.token_dir.endswith("/github_copilot") + assert auth.access_token_file.endswith("/access-token") + assert auth.api_key_file.endswith("/api-key.json") + mock_makedirs.assert_called_once() + + def 
test_ensure_token_dir(self): + """Test that the token directory is created if it doesn't exist.""" + with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs: + auth = Authenticator() + mock_makedirs.assert_called_once_with(auth.token_dir, exist_ok=True) + + def test_get_github_headers(self, authenticator): + """Test that GitHub headers are correctly generated.""" + headers = authenticator._get_github_headers() + assert "accept" in headers + assert "editor-version" in headers + assert "user-agent" in headers + assert "content-type" in headers + + headers_with_token = authenticator._get_github_headers("test-token") + assert headers_with_token["authorization"] == "token test-token" + + def test_get_access_token_from_file(self, authenticator): + """Test retrieving an access token from a file.""" + mock_token = "mock-access-token" + + with patch("builtins.open", mock_open(read_data=mock_token)): + token = authenticator.get_access_token() + assert token == mock_token + + def test_get_access_token_login(self, authenticator): + """Test logging in to get an access token.""" + mock_token = "mock-access-token" + + with patch.object(authenticator, "_login", return_value=mock_token), \ + patch("builtins.open", mock_open()), \ + patch("builtins.open", side_effect=IOError) as mock_read: + token = authenticator.get_access_token() + assert token == mock_token + authenticator._login.assert_called_once() + + def test_get_access_token_failure(self, authenticator): + """Test that an exception is raised after multiple login failures.""" + with patch.object(authenticator, "_login", side_effect=GetDeviceCodeError("Test error")), \ + patch("builtins.open", side_effect=IOError): + with pytest.raises(GetAccessTokenError): + authenticator.get_access_token() + assert authenticator._login.call_count == 3 + + def test_get_api_key_from_file(self, authenticator): + """Test retrieving an API key from a file.""" + future_time = (datetime.now() + timedelta(hours=1)).timestamp() + mock_api_key_data = json.dumps({"token": "mock-api-key", "expires_at": future_time}) + + with patch("builtins.open", mock_open(read_data=mock_api_key_data)): + api_key = authenticator.get_api_key() + assert api_key == "mock-api-key" + + def test_get_api_key_expired(self, authenticator): + """Test refreshing an expired API key.""" + past_time = (datetime.now() - timedelta(hours=1)).timestamp() + mock_expired_data = json.dumps({"token": "expired-api-key", "expires_at": past_time}) + mock_new_data = {"token": "new-api-key", "expires_at": (datetime.now() + timedelta(hours=1)).timestamp()} + + with patch("builtins.open", mock_open(read_data=mock_expired_data)), \ + patch.object(authenticator, "_refresh_api_key", return_value=mock_new_data), \ + patch("json.dump") as mock_json_dump: + api_key = authenticator.get_api_key() + assert api_key == "new-api-key" + authenticator._refresh_api_key.assert_called_once() + + def test_refresh_api_key(self, authenticator, mock_http_client): + """Test refreshing an API key.""" + mock_client, mock_response = mock_http_client + mock_token = "mock-access-token" + mock_api_key_data = {"token": "new-api-key", "expires_at": 12345} + + with patch.object(authenticator, "get_access_token", return_value=mock_token), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(mock_response, "json", return_value=mock_api_key_data): + result = authenticator._refresh_api_key() + assert result == mock_api_key_data + 
mock_client.get.assert_called_once() + authenticator.get_access_token.assert_called_once() + + def test_refresh_api_key_failure(self, authenticator, mock_http_client): + """Test failure to refresh an API key.""" + mock_client, mock_response = mock_http_client + mock_token = "mock-access-token" + + with patch.object(authenticator, "get_access_token", return_value=mock_token), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(mock_response, "json", return_value={}): + with pytest.raises(RefreshAPIKeyError): + authenticator._refresh_api_key() + assert mock_client.get.call_count == 3 + + def test_get_device_code(self, authenticator, mock_http_client): + """Test getting a device code.""" + mock_client, mock_response = mock_http_client + mock_device_code_data = { + "device_code": "mock-device-code", + "user_code": "ABCD-EFGH", + "verification_uri": "https://github.com/login/device" + } + + with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(mock_response, "json", return_value=mock_device_code_data): + result = authenticator._get_device_code() + assert result == mock_device_code_data + mock_client.post.assert_called_once() + + def test_poll_for_access_token(self, authenticator, mock_http_client): + """Test polling for an access token.""" + mock_client, mock_response = mock_http_client + mock_token_data = {"access_token": "mock-access-token"} + + with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(mock_response, "json", return_value=mock_token_data), \ + patch("time.sleep"): + result = authenticator._poll_for_access_token("mock-device-code") + assert result == "mock-access-token" + mock_client.post.assert_called_once() + + def test_login(self, authenticator): + """Test the login process.""" + mock_device_code_data = { + "device_code": "mock-device-code", + "user_code": "ABCD-EFGH", + "verification_uri": "https://github.com/login/device" + } + mock_token = "mock-access-token" + + with patch.object(authenticator, "_get_device_code", return_value=mock_device_code_data), \ + patch.object(authenticator, "_poll_for_access_token", return_value=mock_token), \ + patch("builtins.print") as mock_print: + result = authenticator._login() + assert result == mock_token + authenticator._get_device_code.assert_called_once() + authenticator._poll_for_access_token.assert_called_once_with("mock-device-code") + mock_print.assert_called_once()
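A minimal usage sketch for the provider added above, assuming the patch is applied. The model name comes from the new pricing entries and the extra headers mirror the ones exercised in the tests; the first call triggers the device-code login printed by Authenticator._login(), after which tokens are cached on disk.

import litellm

# "github_copilot/..." models route through GithubCopilotConfig, which resolves the API base
# to https://api.githubcopilot.com and fetches the Copilot API key via the Authenticator
# (device-code login on first use, cached under ~/.config/litellm/github_copilot by default).
response = litellm.completion(
    model="github_copilot/gpt-4o",
    messages=[
        {"role": "system", "content": "You are an AI programming assistant."},
        {"role": "user", "content": "Write a Python function to calculate fibonacci numbers"},
    ],
    extra_headers={
        "editor-version": "Neovim/0.9.0",
        "Copilot-Integration-Id": "vscode-chat",
    },
)
print(response.choices[0].message.content)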