diff --git a/.gitignore b/.gitignore index 1b740501f3..51103df76f 100644 --- a/.gitignore +++ b/.gitignore @@ -78,4 +78,5 @@ litellm/proxy/_experimental/out/model_hub.html .mypy_cache/* litellm/proxy/application.log -**/__pycache__ \ No newline at end of file +**/__pycache__/** +**/*.pyc diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 45fdbfaa35..c8c7f30803 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -578,7 +578,7 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915 custom_llm_provider, ) = litellm.GithubCopilotConfig()._get_openai_compatible_provider_info( model, api_base, api_key, custom_llm_provider - ) + ) if api_base is not None and not isinstance(api_base, str): raise Exception("api base needs to be a string. api_base={}".format(api_base)) if dynamic_api_key is not None and not isinstance(dynamic_api_key, str): diff --git a/litellm/llms/github_copilot/__init__.py b/litellm/llms/github_copilot/__init__.py index 3d6c43e093..c642eff211 100644 --- a/litellm/llms/github_copilot/__init__.py +++ b/litellm/llms/github_copilot/__init__.py @@ -1,4 +1,11 @@ -from .constants import GITHUB_COPILOT_API_BASE, CHAT_COMPLETION_ENDPOINT, GITHUB_COPILOT_MODEL, GetAccessTokenError, GetAPIKeyError, RefreshAPIKeyError +from .constants import ( + GITHUB_COPILOT_API_BASE, + CHAT_COMPLETION_ENDPOINT, + GITHUB_COPILOT_MODEL, + GetAccessTokenError, + GetAPIKeyError, + RefreshAPIKeyError, +) __all__ = [ "GITHUB_COPILOT_API_BASE", @@ -6,5 +13,5 @@ __all__ = [ "GITHUB_COPILOT_MODEL", "GetAccessTokenError", "GetAPIKeyError", - "RefreshAPIKeyError" + "RefreshAPIKeyError", ] diff --git a/litellm/llms/github_copilot/authenticator.py b/litellm/llms/github_copilot/authenticator.py index e712ac1d9a..46a85a7efc 100644 --- a/litellm/llms/github_copilot/authenticator.py +++ b/litellm/llms/github_copilot/authenticator.py @@ -23,16 +23,23 @@ from .constants import ( APIKeyExpiredError, ) + class Authenticator: def __init__(self) -> None: # Token storage paths - self.token_dir = os.getenv("GITHUB_COPILOT_TOKEN_DIR", os.path.expanduser("~/.config/litellm/github_copilot")) - self.access_token_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token")) - self.api_key_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json")) + self.token_dir = os.getenv( + "GITHUB_COPILOT_TOKEN_DIR", + os.path.expanduser("~/.config/litellm/github_copilot"), + ) + self.access_token_file = os.path.join( + self.token_dir, + os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"), + ) + self.api_key_file = os.path.join( + self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json") + ) self._ensure_token_dir() - self.get_access_token() - def get_access_token(self) -> str: """ Login to Copilot with retry 3 times @@ -42,12 +49,12 @@ class Authenticator: """ try: - with open(self.access_token_file, 'r') as f: + with open(self.access_token_file, "r") as f: access_token = f.read().strip() return access_token except IOError: verbose_logger.warning("Error loading access token from file") - + for _ in range(3): try: access_token = self._login() @@ -55,7 +62,7 @@ class Authenticator: continue else: try: - with open(self.access_token_file, 'w') as f: + with open(self.access_token_file, "w") as f: f.write(access_token) except IOError: verbose_logger.error("Error saving access token to file") @@ -66,10 +73,10 @@ class Authenticator: def get_api_key(self) -> str: """Get the API key""" try: - with open(self.api_key_file, 'r') as f: + with open(self.api_key_file, "r") as f: api_key_info = json.load(f) - if api_key_info.get('expires_at') > datetime.now().timestamp(): - return api_key_info.get('token') + if api_key_info.get("expires_at") > datetime.now().timestamp(): + return api_key_info.get("token") else: raise APIKeyExpiredError("API key expired") except IOError: @@ -78,17 +85,17 @@ class Authenticator: verbose_logger.warning("Error reading API key from file") except APIKeyExpiredError: verbose_logger.warning("API key expired") - + try: api_key_info = self._refresh_api_key() - with open(self.api_key_file, 'w') as f: + with open(self.api_key_file, "w") as f: json.dump(api_key_info, f) except IOError: verbose_logger.error("Error saving API key to file") except RefreshAPIKeyError: raise GetAPIKeyError("Failed to refresh API key") - - return api_key_info.get('token') + + return api_key_info.get("token") def _refresh_api_key(self) -> dict: """ @@ -100,10 +107,10 @@ class Authenticator: access_token = self.get_access_token() headers = { - 'authorization': f'token {access_token}', - 'editor-version': 'vscode/1.85.1', - 'editor-plugin-version': 'copilot/1.155.0', - 'user-agent': 'GithubCopilot/1.155.0' + "authorization": f"token {access_token}", + "editor-version": "vscode/1.85.1", + "editor-plugin-version": "copilot/1.155.0", + "user-agent": "GithubCopilot/1.155.0", } max_retries = 3 @@ -111,20 +118,18 @@ class Authenticator: try: sync_client = _get_httpx_client() response = sync_client.get( - 'https://api.github.com/copilot_internal/v2/token', - headers=headers + "https://api.github.com/copilot_internal/v2/token", headers=headers ) response.raise_for_status() response_json = response.json() - if 'token' in response_json: + if "token" in response_json: return response_json except httpx.HTTPStatusError as e: verbose_logger.error(f"Error refreshing API key: {str(e)}") raise RefreshAPIKeyError("Failed to refresh API key") - def _ensure_token_dir(self) -> None: """Ensure the token directory exists""" @@ -143,23 +148,23 @@ class Authenticator: sync_client = _get_httpx_client() # Get device code resp = sync_client.post( - 'https://github.com/login/device/code', + "https://github.com/login/device/code", headers={ - 'accept': 'application/json', - 'editor-version': 'vscode/1.85.1', - 'editor-plugin-version': 'copilot/1.155.0', - 'content-type': 'application/json', - 'user-agent': 'GithubCopilot/1.155.0', - 'accept-encoding': 'gzip,deflate,br' + "accept": "application/json", + "editor-version": "vscode/1.85.1", + "editor-plugin-version": "copilot/1.155.0", + "content-type": "application/json", + "user-agent": "GithubCopilot/1.155.0", + "accept-encoding": "gzip,deflate,br", }, - json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"} + json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"}, ) resp.raise_for_status() resp_json = resp.json() - device_code = resp_json.get('device_code') - user_code = resp_json.get('user_code') - verification_uri = resp_json.get('verification_uri') + device_code = resp_json.get("device_code") + user_code = resp_json.get("user_code") + verification_uri = resp_json.get("verification_uri") if not all([device_code, user_code, verification_uri]): verbose_logger.error("Response missing required fields") @@ -173,8 +178,10 @@ class Authenticator: except RuntimeError as e: verbose_logger.error(f"Error getting device code: {str(e)}") raise GetDeviceCodeError("Failed to get device code") - - print(f'Please visit {verification_uri} and enter code {user_code} to authenticate.') + + print( + f"Please visit {verification_uri} and enter code {user_code} to authenticate." + ) while True: time.sleep(5) @@ -182,27 +189,27 @@ class Authenticator: # Get access token try: resp = sync_client.post( - 'https://github.com/login/oauth/access_token', + "https://github.com/login/oauth/access_token", headers={ - 'accept': 'application/json', - 'editor-version': 'vscode/1.85.1', - 'editor-plugin-version': 'copilot/1.155.0', - 'content-type': 'application/json', - 'user-agent': 'GithubCopilot/1.155.0', - 'accept-encoding': 'gzip,deflate,br' + "accept": "application/json", + "editor-version": "vscode/1.85.1", + "editor-plugin-version": "copilot/1.155.0", + "content-type": "application/json", + "user-agent": "GithubCopilot/1.155.0", + "accept-encoding": "gzip,deflate,br", }, json={ "client_id": "Iv1.b507a08c87ecfe98", "device_code": device_code, - "grant_type": "urn:ietf:params:oauth:grant-type:device_code" - } + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + }, ) resp.raise_for_status() resp_json = resp.json() if "access_token" in resp_json: verbose_logger.info("Authentication success!") - return resp_json['access_token'] + return resp_json["access_token"] else: continue except httpx.HTTPStatusError as e: @@ -211,7 +218,3 @@ class Authenticator: except json.JSONDecodeError as e: verbose_logger.error(f"Error decoding JSON response: {str(e)}") raise GetAccessTokenError("Failed to get access token") - - - - diff --git a/litellm/llms/github_copilot/chat/transformation.py b/litellm/llms/github_copilot/chat/transformation.py index 3ba9bb8250..5f074debc5 100644 --- a/litellm/llms/github_copilot/chat/transformation.py +++ b/litellm/llms/github_copilot/chat/transformation.py @@ -4,6 +4,9 @@ from typing import Optional, Tuple from litellm.llms.openai.openai import OpenAIConfig from ..authenticator import Authenticator +from ..constants import GetAPIKeyError +from litellm.exceptions import AuthenticationError + class GithubCopilotConfig(OpenAIConfig): def __init__( @@ -13,7 +16,6 @@ class GithubCopilotConfig(OpenAIConfig): custom_llm_provider: str = "openai", ) -> None: super().__init__() - self.authenticator = Authenticator() def _get_openai_compatible_provider_info( @@ -24,6 +26,12 @@ class GithubCopilotConfig(OpenAIConfig): custom_llm_provider: str, ) -> Tuple[Optional[str], Optional[str], str]: api_base = "https://api.githubcopilot.com" - dynamic_api_key = self.authenticator.get_api_key() - + try: + dynamic_api_key = self.authenticator.get_api_key() + except GetAPIKeyError as e: + raise AuthenticationError( + model=model, + llm_provider=custom_llm_provider, + message=str(e), + ) return api_base, dynamic_api_key, custom_llm_provider diff --git a/litellm/llms/github_copilot/constants.py b/litellm/llms/github_copilot/constants.py index b1e1b769a3..210511bea3 100644 --- a/litellm/llms/github_copilot/constants.py +++ b/litellm/llms/github_copilot/constants.py @@ -17,9 +17,22 @@ DEFAULT_HEADERS = { "User-Agent": "litellm", } -class GetDeviceCodeError(Exception): pass -class GetAccessTokenError(Exception): pass -class APIKeyExpiredError(Exception): pass -class RefreshAPIKeyError(Exception): pass -class GetAPIKeyError(Exception): pass +class GetDeviceCodeError(Exception): + pass + + +class GetAccessTokenError(Exception): + pass + + +class APIKeyExpiredError(Exception): + pass + + +class RefreshAPIKeyError(Exception): + pass + + +class GetAPIKeyError(Exception): + pass diff --git a/tests/llm_translation/test_github_copilot.py b/tests/llm_translation/test_github_copilot.py new file mode 100644 index 0000000000..0836da5725 --- /dev/null +++ b/tests/llm_translation/test_github_copilot.py @@ -0,0 +1,188 @@ +import pytest +import os +import sys +import json +import asyncio +from datetime import datetime, timedelta +from typing import AsyncGenerator +from unittest.mock import AsyncMock, patch, mock_open, MagicMock + +sys.path.insert(0, os.path.abspath("../..")) + +import httpx +import pytest +from respx import MockRouter + +import litellm +from litellm import Choices, Message, ModelResponse, ModelResponse, Usage +from litellm import completion, acompletion +from litellm.llms.github_copilot.chat.transformation import GithubCopilotConfig +from litellm.llms.github_copilot.authenticator import Authenticator +from litellm.llms.github_copilot.constants import ( + GetAccessTokenError, + GetDeviceCodeError, + RefreshAPIKeyError, + GetAPIKeyError, + APIKeyExpiredError, +) +from litellm.exceptions import AuthenticationError + +# Import at the top to make the patch work correctly +import litellm.llms.github_copilot.chat.transformation + + +def test_github_copilot_config_get_openai_compatible_provider_info(): + """Test the GitHub Copilot configuration provider info retrieval.""" + + config = GithubCopilotConfig() + + # Mock the authenticator to avoid actual API calls + mock_api_key = "gh.test-key-123456789" + config.authenticator = MagicMock() + config.authenticator.get_api_key.return_value = mock_api_key + + # Test with default values + model = "github_copilot/gpt-4" + ( + api_base, + dynamic_api_key, + custom_llm_provider, + ) = config._get_openai_compatible_provider_info( + model=model, + api_base=None, + api_key=None, + custom_llm_provider="github_copilot", + ) + + assert api_base == "https://api.githubcopilot.com" + assert dynamic_api_key == mock_api_key + assert custom_llm_provider == "github_copilot" + + # Test with authentication failure + config.authenticator.get_api_key.side_effect = GetAPIKeyError( + "Failed to get API key" + ) + + with pytest.raises(AuthenticationError) as excinfo: + config._get_openai_compatible_provider_info( + model=model, + api_base=None, + api_key=None, + custom_llm_provider="github_copilot", + ) + + assert "Failed to get API key" in str(excinfo.value) + + +@patch("litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider") +@patch("litellm.llms.openai.openai.OpenAIChatCompletion.completion") +def test_completion_github_copilot(mock_completion, mock_get_provider): + """Test the completion function with GitHub Copilot provider.""" + + # Mock completion response + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Hello, I'm GitHub Copilot!" + mock_completion.return_value = mock_response + + # Test non-streaming completion + messages = [ + {"role": "system", "content": "You're GitHub Copilot, an AI assistant."}, + {"role": "user", "content": "Hello, who are you?"}, + ] + + # Create a properly formatted headers dictionary + headers = { + "editor-version": "Neovim/0.9.0", + "Copilot-Integration-Id": "vscode-chat", + } + + # Patch the get_llm_provider function instead of the config method + # Make it return the expected tuple directly + mock_get_provider.return_value = ( + "gpt-4", + "github_copilot", + "gh.test-key-123456789", + "https://api.githubcopilot.com", + ) + + response = completion( + model="github_copilot/gpt-4", + messages=messages, + extra_headers=headers, + ) + + assert response is not None + + # Verify the completion call was made with the expected params + mock_completion.assert_called_once() + args, kwargs = mock_completion.call_args + + # Check that the proper authorization header is set + assert "headers" in kwargs + # Check that the model name is correctly formatted + assert ( + kwargs.get("model") == "gpt-4" + ) # Model name should be without provider prefix + assert kwargs.get("messages") == messages + + +@patch("litellm.llms.github_copilot.authenticator.Authenticator.get_api_key") +def test_authenticator_get_api_key(mock_get_api_key): + """Test the Authenticator's get_api_key method.""" + from litellm.llms.github_copilot.authenticator import Authenticator + + # Test successful API key retrieval + mock_get_api_key.return_value = "gh.test-key-123456789" + authenticator = Authenticator() + api_key = authenticator.get_api_key() + + assert api_key == "gh.test-key-123456789" + mock_get_api_key.assert_called_once() + + # Test API key retrieval failure + mock_get_api_key.reset_mock() + mock_get_api_key.side_effect = GetAPIKeyError("Failed to get API key") + authenticator = Authenticator() + + with pytest.raises(GetAPIKeyError) as excinfo: + authenticator.get_api_key() + + assert "Failed to get API key" in str(excinfo.value) + + +def test_completion_github_copilot(stream=False): + try: + litellm.set_verbose = True + messages = [ + {"role": "system", "content": "You are an AI programming assistant."}, + { + "role": "user", + "content": "Write a Python function to calculate fibonacci numbers", + }, + ] + extra_headers = { + "editor-version": "Neovim/0.9.0", + "Copilot-Integration-Id": "vscode-chat", + } + response = completion( + model="github_copilot/gpt-4", + messages=messages, + stream=stream, + extra_headers=extra_headers, + ) + print(response) + + if stream is True: + for chunk in response: + print(chunk) + assert chunk is not None + assert isinstance(chunk, litellm.ModelResponseStream) + assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices) + + else: + assert response is not None + assert isinstance(response, litellm.ModelResponse) + assert response.choices[0].message.content is not None + except Exception as e: + pytest.fail(f"Error occurred: {e}")