mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
test: add tests for github copilot
This commit is contained in:
parent
e394d45513
commit
00d2d90535
7 changed files with 282 additions and 62 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -78,4 +78,5 @@ litellm/proxy/_experimental/out/model_hub.html
|
||||||
.mypy_cache/*
|
.mypy_cache/*
|
||||||
litellm/proxy/application.log
|
litellm/proxy/application.log
|
||||||
|
|
||||||
**/__pycache__
|
**/__pycache__/**
|
||||||
|
**/*.pyc
|
||||||
|
|
|
@ -578,7 +578,7 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
|
||||||
custom_llm_provider,
|
custom_llm_provider,
|
||||||
) = litellm.GithubCopilotConfig()._get_openai_compatible_provider_info(
|
) = litellm.GithubCopilotConfig()._get_openai_compatible_provider_info(
|
||||||
model, api_base, api_key, custom_llm_provider
|
model, api_base, api_key, custom_llm_provider
|
||||||
)
|
)
|
||||||
if api_base is not None and not isinstance(api_base, str):
|
if api_base is not None and not isinstance(api_base, str):
|
||||||
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
||||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||||
|
|
|
@ -1,4 +1,11 @@
|
||||||
from .constants import GITHUB_COPILOT_API_BASE, CHAT_COMPLETION_ENDPOINT, GITHUB_COPILOT_MODEL, GetAccessTokenError, GetAPIKeyError, RefreshAPIKeyError
|
from .constants import (
|
||||||
|
GITHUB_COPILOT_API_BASE,
|
||||||
|
CHAT_COMPLETION_ENDPOINT,
|
||||||
|
GITHUB_COPILOT_MODEL,
|
||||||
|
GetAccessTokenError,
|
||||||
|
GetAPIKeyError,
|
||||||
|
RefreshAPIKeyError,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"GITHUB_COPILOT_API_BASE",
|
"GITHUB_COPILOT_API_BASE",
|
||||||
|
@ -6,5 +13,5 @@ __all__ = [
|
||||||
"GITHUB_COPILOT_MODEL",
|
"GITHUB_COPILOT_MODEL",
|
||||||
"GetAccessTokenError",
|
"GetAccessTokenError",
|
||||||
"GetAPIKeyError",
|
"GetAPIKeyError",
|
||||||
"RefreshAPIKeyError"
|
"RefreshAPIKeyError",
|
||||||
]
|
]
|
||||||
|
|
|
@ -23,16 +23,23 @@ from .constants import (
|
||||||
APIKeyExpiredError,
|
APIKeyExpiredError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Authenticator:
|
class Authenticator:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
# Token storage paths
|
# Token storage paths
|
||||||
self.token_dir = os.getenv("GITHUB_COPILOT_TOKEN_DIR", os.path.expanduser("~/.config/litellm/github_copilot"))
|
self.token_dir = os.getenv(
|
||||||
self.access_token_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"))
|
"GITHUB_COPILOT_TOKEN_DIR",
|
||||||
self.api_key_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json"))
|
os.path.expanduser("~/.config/litellm/github_copilot"),
|
||||||
|
)
|
||||||
|
self.access_token_file = os.path.join(
|
||||||
|
self.token_dir,
|
||||||
|
os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"),
|
||||||
|
)
|
||||||
|
self.api_key_file = os.path.join(
|
||||||
|
self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json")
|
||||||
|
)
|
||||||
self._ensure_token_dir()
|
self._ensure_token_dir()
|
||||||
|
|
||||||
self.get_access_token()
|
|
||||||
|
|
||||||
def get_access_token(self) -> str:
|
def get_access_token(self) -> str:
|
||||||
"""
|
"""
|
||||||
Login to Copilot with retry 3 times
|
Login to Copilot with retry 3 times
|
||||||
|
@ -42,12 +49,12 @@ class Authenticator:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with open(self.access_token_file, 'r') as f:
|
with open(self.access_token_file, "r") as f:
|
||||||
access_token = f.read().strip()
|
access_token = f.read().strip()
|
||||||
return access_token
|
return access_token
|
||||||
except IOError:
|
except IOError:
|
||||||
verbose_logger.warning("Error loading access token from file")
|
verbose_logger.warning("Error loading access token from file")
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
try:
|
try:
|
||||||
access_token = self._login()
|
access_token = self._login()
|
||||||
|
@ -55,7 +62,7 @@ class Authenticator:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
with open(self.access_token_file, 'w') as f:
|
with open(self.access_token_file, "w") as f:
|
||||||
f.write(access_token)
|
f.write(access_token)
|
||||||
except IOError:
|
except IOError:
|
||||||
verbose_logger.error("Error saving access token to file")
|
verbose_logger.error("Error saving access token to file")
|
||||||
|
@ -66,10 +73,10 @@ class Authenticator:
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> str:
|
||||||
"""Get the API key"""
|
"""Get the API key"""
|
||||||
try:
|
try:
|
||||||
with open(self.api_key_file, 'r') as f:
|
with open(self.api_key_file, "r") as f:
|
||||||
api_key_info = json.load(f)
|
api_key_info = json.load(f)
|
||||||
if api_key_info.get('expires_at') > datetime.now().timestamp():
|
if api_key_info.get("expires_at") > datetime.now().timestamp():
|
||||||
return api_key_info.get('token')
|
return api_key_info.get("token")
|
||||||
else:
|
else:
|
||||||
raise APIKeyExpiredError("API key expired")
|
raise APIKeyExpiredError("API key expired")
|
||||||
except IOError:
|
except IOError:
|
||||||
|
@ -78,17 +85,17 @@ class Authenticator:
|
||||||
verbose_logger.warning("Error reading API key from file")
|
verbose_logger.warning("Error reading API key from file")
|
||||||
except APIKeyExpiredError:
|
except APIKeyExpiredError:
|
||||||
verbose_logger.warning("API key expired")
|
verbose_logger.warning("API key expired")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_key_info = self._refresh_api_key()
|
api_key_info = self._refresh_api_key()
|
||||||
with open(self.api_key_file, 'w') as f:
|
with open(self.api_key_file, "w") as f:
|
||||||
json.dump(api_key_info, f)
|
json.dump(api_key_info, f)
|
||||||
except IOError:
|
except IOError:
|
||||||
verbose_logger.error("Error saving API key to file")
|
verbose_logger.error("Error saving API key to file")
|
||||||
except RefreshAPIKeyError:
|
except RefreshAPIKeyError:
|
||||||
raise GetAPIKeyError("Failed to refresh API key")
|
raise GetAPIKeyError("Failed to refresh API key")
|
||||||
|
|
||||||
return api_key_info.get('token')
|
return api_key_info.get("token")
|
||||||
|
|
||||||
def _refresh_api_key(self) -> dict:
|
def _refresh_api_key(self) -> dict:
|
||||||
"""
|
"""
|
||||||
|
@ -100,10 +107,10 @@ class Authenticator:
|
||||||
|
|
||||||
access_token = self.get_access_token()
|
access_token = self.get_access_token()
|
||||||
headers = {
|
headers = {
|
||||||
'authorization': f'token {access_token}',
|
"authorization": f"token {access_token}",
|
||||||
'editor-version': 'vscode/1.85.1',
|
"editor-version": "vscode/1.85.1",
|
||||||
'editor-plugin-version': 'copilot/1.155.0',
|
"editor-plugin-version": "copilot/1.155.0",
|
||||||
'user-agent': 'GithubCopilot/1.155.0'
|
"user-agent": "GithubCopilot/1.155.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
|
@ -111,20 +118,18 @@ class Authenticator:
|
||||||
try:
|
try:
|
||||||
sync_client = _get_httpx_client()
|
sync_client = _get_httpx_client()
|
||||||
response = sync_client.get(
|
response = sync_client.get(
|
||||||
'https://api.github.com/copilot_internal/v2/token',
|
"https://api.github.com/copilot_internal/v2/token", headers=headers
|
||||||
headers=headers
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|
||||||
if 'token' in response_json:
|
if "token" in response_json:
|
||||||
return response_json
|
return response_json
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
verbose_logger.error(f"Error refreshing API key: {str(e)}")
|
verbose_logger.error(f"Error refreshing API key: {str(e)}")
|
||||||
|
|
||||||
raise RefreshAPIKeyError("Failed to refresh API key")
|
raise RefreshAPIKeyError("Failed to refresh API key")
|
||||||
|
|
||||||
|
|
||||||
def _ensure_token_dir(self) -> None:
|
def _ensure_token_dir(self) -> None:
|
||||||
"""Ensure the token directory exists"""
|
"""Ensure the token directory exists"""
|
||||||
|
@ -143,23 +148,23 @@ class Authenticator:
|
||||||
sync_client = _get_httpx_client()
|
sync_client = _get_httpx_client()
|
||||||
# Get device code
|
# Get device code
|
||||||
resp = sync_client.post(
|
resp = sync_client.post(
|
||||||
'https://github.com/login/device/code',
|
"https://github.com/login/device/code",
|
||||||
headers={
|
headers={
|
||||||
'accept': 'application/json',
|
"accept": "application/json",
|
||||||
'editor-version': 'vscode/1.85.1',
|
"editor-version": "vscode/1.85.1",
|
||||||
'editor-plugin-version': 'copilot/1.155.0',
|
"editor-plugin-version": "copilot/1.155.0",
|
||||||
'content-type': 'application/json',
|
"content-type": "application/json",
|
||||||
'user-agent': 'GithubCopilot/1.155.0',
|
"user-agent": "GithubCopilot/1.155.0",
|
||||||
'accept-encoding': 'gzip,deflate,br'
|
"accept-encoding": "gzip,deflate,br",
|
||||||
},
|
},
|
||||||
json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"}
|
json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
resp_json = resp.json()
|
resp_json = resp.json()
|
||||||
|
|
||||||
device_code = resp_json.get('device_code')
|
device_code = resp_json.get("device_code")
|
||||||
user_code = resp_json.get('user_code')
|
user_code = resp_json.get("user_code")
|
||||||
verification_uri = resp_json.get('verification_uri')
|
verification_uri = resp_json.get("verification_uri")
|
||||||
|
|
||||||
if not all([device_code, user_code, verification_uri]):
|
if not all([device_code, user_code, verification_uri]):
|
||||||
verbose_logger.error("Response missing required fields")
|
verbose_logger.error("Response missing required fields")
|
||||||
|
@ -173,8 +178,10 @@ class Authenticator:
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
verbose_logger.error(f"Error getting device code: {str(e)}")
|
verbose_logger.error(f"Error getting device code: {str(e)}")
|
||||||
raise GetDeviceCodeError("Failed to get device code")
|
raise GetDeviceCodeError("Failed to get device code")
|
||||||
|
|
||||||
print(f'Please visit {verification_uri} and enter code {user_code} to authenticate.')
|
print(
|
||||||
|
f"Please visit {verification_uri} and enter code {user_code} to authenticate."
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
@ -182,27 +189,27 @@ class Authenticator:
|
||||||
# Get access token
|
# Get access token
|
||||||
try:
|
try:
|
||||||
resp = sync_client.post(
|
resp = sync_client.post(
|
||||||
'https://github.com/login/oauth/access_token',
|
"https://github.com/login/oauth/access_token",
|
||||||
headers={
|
headers={
|
||||||
'accept': 'application/json',
|
"accept": "application/json",
|
||||||
'editor-version': 'vscode/1.85.1',
|
"editor-version": "vscode/1.85.1",
|
||||||
'editor-plugin-version': 'copilot/1.155.0',
|
"editor-plugin-version": "copilot/1.155.0",
|
||||||
'content-type': 'application/json',
|
"content-type": "application/json",
|
||||||
'user-agent': 'GithubCopilot/1.155.0',
|
"user-agent": "GithubCopilot/1.155.0",
|
||||||
'accept-encoding': 'gzip,deflate,br'
|
"accept-encoding": "gzip,deflate,br",
|
||||||
},
|
},
|
||||||
json={
|
json={
|
||||||
"client_id": "Iv1.b507a08c87ecfe98",
|
"client_id": "Iv1.b507a08c87ecfe98",
|
||||||
"device_code": device_code,
|
"device_code": device_code,
|
||||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
|
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
resp_json = resp.json()
|
resp_json = resp.json()
|
||||||
|
|
||||||
if "access_token" in resp_json:
|
if "access_token" in resp_json:
|
||||||
verbose_logger.info("Authentication success!")
|
verbose_logger.info("Authentication success!")
|
||||||
return resp_json['access_token']
|
return resp_json["access_token"]
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
|
@ -211,7 +218,3 @@ class Authenticator:
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
|
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
|
||||||
raise GetAccessTokenError("Failed to get access token")
|
raise GetAccessTokenError("Failed to get access token")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,9 @@ from typing import Optional, Tuple
|
||||||
from litellm.llms.openai.openai import OpenAIConfig
|
from litellm.llms.openai.openai import OpenAIConfig
|
||||||
|
|
||||||
from ..authenticator import Authenticator
|
from ..authenticator import Authenticator
|
||||||
|
from ..constants import GetAPIKeyError
|
||||||
|
from litellm.exceptions import AuthenticationError
|
||||||
|
|
||||||
|
|
||||||
class GithubCopilotConfig(OpenAIConfig):
|
class GithubCopilotConfig(OpenAIConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -13,7 +16,6 @@ class GithubCopilotConfig(OpenAIConfig):
|
||||||
custom_llm_provider: str = "openai",
|
custom_llm_provider: str = "openai",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.authenticator = Authenticator()
|
self.authenticator = Authenticator()
|
||||||
|
|
||||||
def _get_openai_compatible_provider_info(
|
def _get_openai_compatible_provider_info(
|
||||||
|
@ -24,6 +26,12 @@ class GithubCopilotConfig(OpenAIConfig):
|
||||||
custom_llm_provider: str,
|
custom_llm_provider: str,
|
||||||
) -> Tuple[Optional[str], Optional[str], str]:
|
) -> Tuple[Optional[str], Optional[str], str]:
|
||||||
api_base = "https://api.githubcopilot.com"
|
api_base = "https://api.githubcopilot.com"
|
||||||
dynamic_api_key = self.authenticator.get_api_key()
|
try:
|
||||||
|
dynamic_api_key = self.authenticator.get_api_key()
|
||||||
|
except GetAPIKeyError as e:
|
||||||
|
raise AuthenticationError(
|
||||||
|
model=model,
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
message=str(e),
|
||||||
|
)
|
||||||
return api_base, dynamic_api_key, custom_llm_provider
|
return api_base, dynamic_api_key, custom_llm_provider
|
||||||
|
|
|
@ -17,9 +17,22 @@ DEFAULT_HEADERS = {
|
||||||
"User-Agent": "litellm",
|
"User-Agent": "litellm",
|
||||||
}
|
}
|
||||||
|
|
||||||
class GetDeviceCodeError(Exception): pass
|
|
||||||
class GetAccessTokenError(Exception): pass
|
|
||||||
class APIKeyExpiredError(Exception): pass
|
|
||||||
class RefreshAPIKeyError(Exception): pass
|
|
||||||
class GetAPIKeyError(Exception): pass
|
|
||||||
|
|
||||||
|
class GetDeviceCodeError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GetAccessTokenError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyExpiredError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshAPIKeyError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GetAPIKeyError(Exception):
|
||||||
|
pass
|
||||||
|
|
188
tests/llm_translation/test_github_copilot.py
Normal file
188
tests/llm_translation/test_github_copilot.py
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from unittest.mock import AsyncMock, patch, mock_open, MagicMock
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath("../.."))
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from respx import MockRouter
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import Choices, Message, ModelResponse, ModelResponse, Usage
|
||||||
|
from litellm import completion, acompletion
|
||||||
|
from litellm.llms.github_copilot.chat.transformation import GithubCopilotConfig
|
||||||
|
from litellm.llms.github_copilot.authenticator import Authenticator
|
||||||
|
from litellm.llms.github_copilot.constants import (
|
||||||
|
GetAccessTokenError,
|
||||||
|
GetDeviceCodeError,
|
||||||
|
RefreshAPIKeyError,
|
||||||
|
GetAPIKeyError,
|
||||||
|
APIKeyExpiredError,
|
||||||
|
)
|
||||||
|
from litellm.exceptions import AuthenticationError
|
||||||
|
|
||||||
|
# Import at the top to make the patch work correctly
|
||||||
|
import litellm.llms.github_copilot.chat.transformation
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_copilot_config_get_openai_compatible_provider_info():
|
||||||
|
"""Test the GitHub Copilot configuration provider info retrieval."""
|
||||||
|
|
||||||
|
config = GithubCopilotConfig()
|
||||||
|
|
||||||
|
# Mock the authenticator to avoid actual API calls
|
||||||
|
mock_api_key = "gh.test-key-123456789"
|
||||||
|
config.authenticator = MagicMock()
|
||||||
|
config.authenticator.get_api_key.return_value = mock_api_key
|
||||||
|
|
||||||
|
# Test with default values
|
||||||
|
model = "github_copilot/gpt-4"
|
||||||
|
(
|
||||||
|
api_base,
|
||||||
|
dynamic_api_key,
|
||||||
|
custom_llm_provider,
|
||||||
|
) = config._get_openai_compatible_provider_info(
|
||||||
|
model=model,
|
||||||
|
api_base=None,
|
||||||
|
api_key=None,
|
||||||
|
custom_llm_provider="github_copilot",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert api_base == "https://api.githubcopilot.com"
|
||||||
|
assert dynamic_api_key == mock_api_key
|
||||||
|
assert custom_llm_provider == "github_copilot"
|
||||||
|
|
||||||
|
# Test with authentication failure
|
||||||
|
config.authenticator.get_api_key.side_effect = GetAPIKeyError(
|
||||||
|
"Failed to get API key"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AuthenticationError) as excinfo:
|
||||||
|
config._get_openai_compatible_provider_info(
|
||||||
|
model=model,
|
||||||
|
api_base=None,
|
||||||
|
api_key=None,
|
||||||
|
custom_llm_provider="github_copilot",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Failed to get API key" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider")
|
||||||
|
@patch("litellm.llms.openai.openai.OpenAIChatCompletion.completion")
|
||||||
|
def test_completion_github_copilot(mock_completion, mock_get_provider):
|
||||||
|
"""Test the completion function with GitHub Copilot provider."""
|
||||||
|
|
||||||
|
# Mock completion response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "Hello, I'm GitHub Copilot!"
|
||||||
|
mock_completion.return_value = mock_response
|
||||||
|
|
||||||
|
# Test non-streaming completion
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You're GitHub Copilot, an AI assistant."},
|
||||||
|
{"role": "user", "content": "Hello, who are you?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a properly formatted headers dictionary
|
||||||
|
headers = {
|
||||||
|
"editor-version": "Neovim/0.9.0",
|
||||||
|
"Copilot-Integration-Id": "vscode-chat",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patch the get_llm_provider function instead of the config method
|
||||||
|
# Make it return the expected tuple directly
|
||||||
|
mock_get_provider.return_value = (
|
||||||
|
"gpt-4",
|
||||||
|
"github_copilot",
|
||||||
|
"gh.test-key-123456789",
|
||||||
|
"https://api.githubcopilot.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="github_copilot/gpt-4",
|
||||||
|
messages=messages,
|
||||||
|
extra_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
# Verify the completion call was made with the expected params
|
||||||
|
mock_completion.assert_called_once()
|
||||||
|
args, kwargs = mock_completion.call_args
|
||||||
|
|
||||||
|
# Check that the proper authorization header is set
|
||||||
|
assert "headers" in kwargs
|
||||||
|
# Check that the model name is correctly formatted
|
||||||
|
assert (
|
||||||
|
kwargs.get("model") == "gpt-4"
|
||||||
|
) # Model name should be without provider prefix
|
||||||
|
assert kwargs.get("messages") == messages
|
||||||
|
|
||||||
|
|
||||||
|
@patch("litellm.llms.github_copilot.authenticator.Authenticator.get_api_key")
|
||||||
|
def test_authenticator_get_api_key(mock_get_api_key):
|
||||||
|
"""Test the Authenticator's get_api_key method."""
|
||||||
|
from litellm.llms.github_copilot.authenticator import Authenticator
|
||||||
|
|
||||||
|
# Test successful API key retrieval
|
||||||
|
mock_get_api_key.return_value = "gh.test-key-123456789"
|
||||||
|
authenticator = Authenticator()
|
||||||
|
api_key = authenticator.get_api_key()
|
||||||
|
|
||||||
|
assert api_key == "gh.test-key-123456789"
|
||||||
|
mock_get_api_key.assert_called_once()
|
||||||
|
|
||||||
|
# Test API key retrieval failure
|
||||||
|
mock_get_api_key.reset_mock()
|
||||||
|
mock_get_api_key.side_effect = GetAPIKeyError("Failed to get API key")
|
||||||
|
authenticator = Authenticator()
|
||||||
|
|
||||||
|
with pytest.raises(GetAPIKeyError) as excinfo:
|
||||||
|
authenticator.get_api_key()
|
||||||
|
|
||||||
|
assert "Failed to get API key" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_github_copilot(stream=False):
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are an AI programming assistant."},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Write a Python function to calculate fibonacci numbers",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
extra_headers = {
|
||||||
|
"editor-version": "Neovim/0.9.0",
|
||||||
|
"Copilot-Integration-Id": "vscode-chat",
|
||||||
|
}
|
||||||
|
response = completion(
|
||||||
|
model="github_copilot/gpt-4",
|
||||||
|
messages=messages,
|
||||||
|
stream=stream,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
if stream is True:
|
||||||
|
for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
assert chunk is not None
|
||||||
|
assert isinstance(chunk, litellm.ModelResponseStream)
|
||||||
|
assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert response is not None
|
||||||
|
assert isinstance(response, litellm.ModelResponse)
|
||||||
|
assert response.choices[0].message.content is not None
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
Loading…
Add table
Add a link
Reference in a new issue