diff --git a/tests/llm_translation/test_github_copilot_authenticator.py b/tests/llm_translation/test_github_copilot_authenticator.py new file mode 100644 index 0000000000..95e3982a2a --- /dev/null +++ b/tests/llm_translation/test_github_copilot_authenticator.py @@ -0,0 +1,179 @@ +import os +import json +import time +from datetime import datetime, timedelta +import pytest +from unittest.mock import patch, mock_open, MagicMock + +from litellm.llms.github_copilot.authenticator import Authenticator +from litellm.llms.github_copilot.constants import ( + GetAccessTokenError, + GetDeviceCodeError, + RefreshAPIKeyError, + GetAPIKeyError, + APIKeyExpiredError, +) + + +class TestGitHubCopilotAuthenticator: + @pytest.fixture + def authenticator(self): + with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs: + auth = Authenticator() + mock_makedirs.assert_called_once() + return auth + + @pytest.fixture + def mock_http_client(self): + mock_client = MagicMock() + mock_response = MagicMock() + mock_client.get.return_value = mock_response + mock_client.post.return_value = mock_response + mock_response.raise_for_status.return_value = None + return mock_client, mock_response + + def test_init(self): + """Test the initialization of the authenticator.""" + with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs: + auth = Authenticator() + assert auth.token_dir.endswith("/github_copilot") + assert auth.access_token_file.endswith("/access-token") + assert auth.api_key_file.endswith("/api-key.json") + mock_makedirs.assert_called_once() + + def test_ensure_token_dir(self): + """Test that the token directory is created if it doesn't exist.""" + with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs: + auth = Authenticator() + mock_makedirs.assert_called_once_with(auth.token_dir, exist_ok=True) + + def test_get_github_headers(self, authenticator): + """Test that GitHub headers are correctly generated.""" + headers = authenticator._get_github_headers() + assert "accept" in headers + assert "editor-version" in headers + assert "user-agent" in headers + assert "content-type" in headers + + headers_with_token = authenticator._get_github_headers("test-token") + assert headers_with_token["authorization"] == "token test-token" + + def test_get_access_token_from_file(self, authenticator): + """Test retrieving an access token from a file.""" + mock_token = "mock-access-token" + + with patch("builtins.open", mock_open(read_data=mock_token)): + token = authenticator.get_access_token() + assert token == mock_token + + def test_get_access_token_login(self, authenticator): + """Test logging in to get an access token.""" + mock_token = "mock-access-token" + + with patch.object(authenticator, "_login", return_value=mock_token), \ + patch("builtins.open", mock_open()), \ + patch("builtins.open", side_effect=IOError) as mock_read: + token = authenticator.get_access_token() + assert token == mock_token + authenticator._login.assert_called_once() + + def test_get_access_token_failure(self, authenticator): + """Test that an exception is raised after multiple login failures.""" + with patch.object(authenticator, "_login", side_effect=GetDeviceCodeError("Test error")), \ + patch("builtins.open", side_effect=IOError): + with pytest.raises(GetAccessTokenError): + authenticator.get_access_token() + assert authenticator._login.call_count == 3 + + def test_get_api_key_from_file(self, authenticator): + """Test retrieving an API key from a file.""" + future_time = (datetime.now() + timedelta(hours=1)).timestamp() + mock_api_key_data = json.dumps({"token": "mock-api-key", "expires_at": future_time}) + + with patch("builtins.open", mock_open(read_data=mock_api_key_data)): + api_key = authenticator.get_api_key() + assert api_key == "mock-api-key" + + def test_get_api_key_expired(self, authenticator): + """Test refreshing an expired API key.""" + past_time = (datetime.now() - timedelta(hours=1)).timestamp() + mock_expired_data = json.dumps({"token": "expired-api-key", "expires_at": past_time}) + mock_new_data = {"token": "new-api-key", "expires_at": (datetime.now() + timedelta(hours=1)).timestamp()} + + with patch("builtins.open", mock_open(read_data=mock_expired_data)), \ + patch.object(authenticator, "_refresh_api_key", return_value=mock_new_data), \ + patch("json.dump") as mock_json_dump: + api_key = authenticator.get_api_key() + assert api_key == "new-api-key" + authenticator._refresh_api_key.assert_called_once() + + def test_refresh_api_key(self, authenticator, mock_http_client): + """Test refreshing an API key.""" + mock_client, mock_response = mock_http_client + mock_token = "mock-access-token" + mock_api_key_data = {"token": "new-api-key", "expires_at": 12345} + + with patch.object(authenticator, "get_access_token", return_value=mock_token), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(mock_response, "json", return_value=mock_api_key_data): + result = authenticator._refresh_api_key() + assert result == mock_api_key_data + mock_client.get.assert_called_once() + authenticator.get_access_token.assert_called_once() + + def test_refresh_api_key_failure(self, authenticator, mock_http_client): + """Test failure to refresh an API key.""" + mock_client, mock_response = mock_http_client + mock_token = "mock-access-token" + + with patch.object(authenticator, "get_access_token", return_value=mock_token), \ + patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(mock_response, "json", return_value={}): + with pytest.raises(RefreshAPIKeyError): + authenticator._refresh_api_key() + assert mock_client.get.call_count == 3 + + def test_get_device_code(self, authenticator, mock_http_client): + """Test getting a device code.""" + mock_client, mock_response = mock_http_client + mock_device_code_data = { + "device_code": "mock-device-code", + "user_code": "ABCD-EFGH", + "verification_uri": "https://github.com/login/device" + } + + with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(mock_response, "json", return_value=mock_device_code_data): + result = authenticator._get_device_code() + assert result == mock_device_code_data + mock_client.post.assert_called_once() + + def test_poll_for_access_token(self, authenticator, mock_http_client): + """Test polling for an access token.""" + mock_client, mock_response = mock_http_client + mock_token_data = {"access_token": "mock-access-token"} + + with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \ + patch.object(mock_response, "json", return_value=mock_token_data), \ + patch("time.sleep"): + result = authenticator._poll_for_access_token("mock-device-code") + assert result == "mock-access-token" + mock_client.post.assert_called_once() + + def test_login(self, authenticator): + """Test the login process.""" + mock_device_code_data = { + "device_code": "mock-device-code", + "user_code": "ABCD-EFGH", + "verification_uri": "https://github.com/login/device" + } + mock_token = "mock-access-token" + + with patch.object(authenticator, "_get_device_code", return_value=mock_device_code_data), \ + patch.object(authenticator, "_poll_for_access_token", return_value=mock_token), \ + patch("builtins.print") as mock_print: + result = authenticator._login() + assert result == mock_token + authenticator._get_device_code.assert_called_once() + authenticator._poll_for_access_token.assert_called_once_with("mock-device-code") + mock_print.assert_called_once() \ No newline at end of file