mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat: add support for copilot provider (#8577)
* feat: add support for copilot provider * test: add tests for github copilot * chore: clean up github copilot authenticator * test: add test for github copilot authenticator * test: add test for github copilot for sonnet 3.7 thought model --------- Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
parent
312286c588
commit
7ccccff39a
11 changed files with 801 additions and 4 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -77,5 +77,3 @@ litellm/proxy/_experimental/out/404.html
|
||||||
litellm/proxy/_experimental/out/model_hub.html
|
litellm/proxy/_experimental/out/model_hub.html
|
||||||
.mypy_cache/*
|
.mypy_cache/*
|
||||||
litellm/proxy/application.log
|
litellm/proxy/application.log
|
||||||
tests/llm_translation/vertex_test_account.json
|
|
||||||
tests/llm_translation/test_vertex_key.json
|
|
||||||
|
|
|
@ -980,6 +980,7 @@ from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
|
||||||
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
|
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
|
||||||
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
|
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
|
||||||
from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
|
from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
|
||||||
|
from .llms.github_copilot.chat.transformation import GithubCopilotConfig
|
||||||
from .main import * # type: ignore
|
from .main import * # type: ignore
|
||||||
from .integrations import *
|
from .integrations import *
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
|
|
|
@ -79,6 +79,7 @@ LITELLM_CHAT_PROVIDERS = [
|
||||||
"hosted_vllm",
|
"hosted_vllm",
|
||||||
"lm_studio",
|
"lm_studio",
|
||||||
"galadriel",
|
"galadriel",
|
||||||
|
"github_copilot", # GitHub Copilot Chat API
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,7 +139,7 @@ openai_compatible_endpoints: List = [
|
||||||
"https://api.friendli.ai/serverless/v1",
|
"https://api.friendli.ai/serverless/v1",
|
||||||
"api.sambanova.ai/v1",
|
"api.sambanova.ai/v1",
|
||||||
"api.x.ai/v1",
|
"api.x.ai/v1",
|
||||||
"api.galadriel.ai/v1",
|
"api.galadriel.ai/v1"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -168,6 +169,7 @@ openai_compatible_providers: List = [
|
||||||
"hosted_vllm",
|
"hosted_vllm",
|
||||||
"lm_studio",
|
"lm_studio",
|
||||||
"galadriel",
|
"galadriel",
|
||||||
|
"github_copilot", # GitHub Copilot Chat API
|
||||||
]
|
]
|
||||||
openai_text_completion_compatible_providers: List = (
|
openai_text_completion_compatible_providers: List = (
|
||||||
[ # providers that support `/v1/completions`
|
[ # providers that support `/v1/completions`
|
||||||
|
|
|
@ -571,6 +571,14 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
|
||||||
or "https://api.galadriel.com/v1"
|
or "https://api.galadriel.com/v1"
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
|
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
|
||||||
|
elif custom_llm_provider == "github_copilot":
|
||||||
|
(
|
||||||
|
api_base,
|
||||||
|
dynamic_api_key,
|
||||||
|
custom_llm_provider,
|
||||||
|
) = litellm.GithubCopilotConfig()._get_openai_compatible_provider_info(
|
||||||
|
model, api_base, api_key, custom_llm_provider
|
||||||
|
)
|
||||||
if api_base is not None and not isinstance(api_base, str):
|
if api_base is not None and not isinstance(api_base, str):
|
||||||
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
||||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||||
|
|
17
litellm/llms/github_copilot/__init__.py
Normal file
17
litellm/llms/github_copilot/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
from .constants import (
|
||||||
|
GITHUB_COPILOT_API_BASE,
|
||||||
|
CHAT_COMPLETION_ENDPOINT,
|
||||||
|
GITHUB_COPILOT_MODEL,
|
||||||
|
GetAccessTokenError,
|
||||||
|
GetAPIKeyError,
|
||||||
|
RefreshAPIKeyError,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GITHUB_COPILOT_API_BASE",
|
||||||
|
"CHAT_COMPLETION_ENDPOINT",
|
||||||
|
"GITHUB_COPILOT_MODEL",
|
||||||
|
"GetAccessTokenError",
|
||||||
|
"GetAPIKeyError",
|
||||||
|
"RefreshAPIKeyError",
|
||||||
|
]
|
292
litellm/llms/github_copilot/authenticator.py
Normal file
292
litellm/llms/github_copilot/authenticator.py
Normal file
|
@ -0,0 +1,292 @@
|
||||||
|
import os
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.caching import InMemoryCache
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
httpxSpecialProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .constants import (
|
||||||
|
GetAccessTokenError,
|
||||||
|
GetDeviceCodeError,
|
||||||
|
RefreshAPIKeyError,
|
||||||
|
GetAPIKeyError,
|
||||||
|
APIKeyExpiredError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
GITHUB_CLIENT_ID = "Iv1.b507a08c87ecfe98"
|
||||||
|
GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code"
|
||||||
|
GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||||
|
GITHUB_API_KEY_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||||
|
|
||||||
|
|
||||||
|
class Authenticator:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the GitHub Copilot authenticator with configurable token paths."""
|
||||||
|
# Token storage paths
|
||||||
|
self.token_dir = os.getenv(
|
||||||
|
"GITHUB_COPILOT_TOKEN_DIR",
|
||||||
|
os.path.expanduser("~/.config/litellm/github_copilot"),
|
||||||
|
)
|
||||||
|
self.access_token_file = os.path.join(
|
||||||
|
self.token_dir,
|
||||||
|
os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"),
|
||||||
|
)
|
||||||
|
self.api_key_file = os.path.join(
|
||||||
|
self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json")
|
||||||
|
)
|
||||||
|
self._ensure_token_dir()
|
||||||
|
|
||||||
|
def get_access_token(self) -> str:
|
||||||
|
"""
|
||||||
|
Login to Copilot with retry 3 times.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The GitHub access token.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GetAccessTokenError: If unable to obtain an access token after retries.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(self.access_token_file, "r") as f:
|
||||||
|
access_token = f.read().strip()
|
||||||
|
if access_token:
|
||||||
|
return access_token
|
||||||
|
except IOError:
|
||||||
|
verbose_logger.warning("No existing access token found or error reading file")
|
||||||
|
|
||||||
|
for attempt in range(3):
|
||||||
|
verbose_logger.debug(f"Access token acquisition attempt {attempt + 1}/3")
|
||||||
|
try:
|
||||||
|
access_token = self._login()
|
||||||
|
try:
|
||||||
|
with open(self.access_token_file, "w") as f:
|
||||||
|
f.write(access_token)
|
||||||
|
except IOError:
|
||||||
|
verbose_logger.error("Error saving access token to file")
|
||||||
|
return access_token
|
||||||
|
except (GetDeviceCodeError, GetAccessTokenError, RefreshAPIKeyError) as e:
|
||||||
|
verbose_logger.warning(f"Failed attempt {attempt + 1}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise GetAccessTokenError("Failed to get access token after 3 attempts")
|
||||||
|
|
||||||
|
def get_api_key(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the API key, refreshing if necessary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The GitHub Copilot API key.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GetAPIKeyError: If unable to obtain an API key.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(self.api_key_file, "r") as f:
|
||||||
|
api_key_info = json.load(f)
|
||||||
|
if api_key_info.get("expires_at", 0) > datetime.now().timestamp():
|
||||||
|
return api_key_info.get("token")
|
||||||
|
else:
|
||||||
|
verbose_logger.warning("API key expired, refreshing")
|
||||||
|
raise APIKeyExpiredError("API key expired")
|
||||||
|
except IOError:
|
||||||
|
verbose_logger.warning("No API key file found or error opening file")
|
||||||
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
|
verbose_logger.warning(f"Error reading API key from file: {str(e)}")
|
||||||
|
except APIKeyExpiredError:
|
||||||
|
pass # Already logged in the try block
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_key_info = self._refresh_api_key()
|
||||||
|
with open(self.api_key_file, "w") as f:
|
||||||
|
json.dump(api_key_info, f)
|
||||||
|
return api_key_info.get("token")
|
||||||
|
except IOError as e:
|
||||||
|
verbose_logger.error(f"Error saving API key to file: {str(e)}")
|
||||||
|
raise GetAPIKeyError(f"Failed to save API key: {str(e)}")
|
||||||
|
except RefreshAPIKeyError as e:
|
||||||
|
raise GetAPIKeyError(f"Failed to refresh API key: {str(e)}")
|
||||||
|
|
||||||
|
def _refresh_api_key(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Refresh the API key using the access token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: The API key information including token and expiration.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RefreshAPIKeyError: If unable to refresh the API key.
|
||||||
|
"""
|
||||||
|
access_token = self.get_access_token()
|
||||||
|
headers = self._get_github_headers(access_token)
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
sync_client = _get_httpx_client()
|
||||||
|
response = sync_client.get(
|
||||||
|
GITHUB_API_KEY_URL, headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
|
||||||
|
if "token" in response_json:
|
||||||
|
return response_json
|
||||||
|
else:
|
||||||
|
verbose_logger.warning(f"API key response missing token: {response_json}")
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
verbose_logger.error(f"HTTP error refreshing API key (attempt {attempt+1}/{max_retries}): {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error(f"Unexpected error refreshing API key: {str(e)}")
|
||||||
|
|
||||||
|
raise RefreshAPIKeyError("Failed to refresh API key after maximum retries")
|
||||||
|
|
||||||
|
def _ensure_token_dir(self) -> None:
|
||||||
|
"""Ensure the token directory exists."""
|
||||||
|
if not os.path.exists(self.token_dir):
|
||||||
|
os.makedirs(self.token_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def _get_github_headers(self, access_token: Optional[str] = None) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Generate standard GitHub headers for API requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: Optional access token to include in the headers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, str]: Headers for GitHub API requests.
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"accept": "application/json",
|
||||||
|
"editor-version": "vscode/1.85.1",
|
||||||
|
"editor-plugin-version": "copilot/1.155.0",
|
||||||
|
"user-agent": "GithubCopilot/1.155.0",
|
||||||
|
"accept-encoding": "gzip,deflate,br",
|
||||||
|
}
|
||||||
|
|
||||||
|
if access_token:
|
||||||
|
headers["authorization"] = f"token {access_token}"
|
||||||
|
|
||||||
|
if "content-type" not in headers:
|
||||||
|
headers["content-type"] = "application/json"
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _get_device_code(self) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Get a device code for GitHub authentication.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, str]: Device code information.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GetDeviceCodeError: If unable to get a device code.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
sync_client = _get_httpx_client()
|
||||||
|
resp = sync_client.post(
|
||||||
|
GITHUB_DEVICE_CODE_URL,
|
||||||
|
headers=self._get_github_headers(),
|
||||||
|
json={"client_id": GITHUB_CLIENT_ID, "scope": "read:user"},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
resp_json = resp.json()
|
||||||
|
|
||||||
|
required_fields = ["device_code", "user_code", "verification_uri"]
|
||||||
|
if not all(field in resp_json for field in required_fields):
|
||||||
|
verbose_logger.error(f"Response missing required fields: {resp_json}")
|
||||||
|
raise GetDeviceCodeError("Response missing required fields")
|
||||||
|
|
||||||
|
return resp_json
|
||||||
|
except httpx.HttpError as e:
|
||||||
|
verbose_logger.error(f"HTTP error getting device code: {str(e)}")
|
||||||
|
raise GetDeviceCodeError(f"Failed to get device code: {str(e)}")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
|
||||||
|
raise GetDeviceCodeError(f"Failed to decode device code response: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error(f"Unexpected error getting device code: {str(e)}")
|
||||||
|
raise GetDeviceCodeError(f"Failed to get device code: {str(e)}")
|
||||||
|
|
||||||
|
def _poll_for_access_token(self, device_code: str) -> str:
|
||||||
|
"""
|
||||||
|
Poll for an access token after user authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_code: The device code to use for polling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The access token.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GetAccessTokenError: If unable to get an access token.
|
||||||
|
"""
|
||||||
|
sync_client = _get_httpx_client()
|
||||||
|
max_attempts = 12 # 1 minute (12 * 5 seconds)
|
||||||
|
|
||||||
|
for attempt in range(max_attempts):
|
||||||
|
try:
|
||||||
|
resp = sync_client.post(
|
||||||
|
GITHUB_ACCESS_TOKEN_URL,
|
||||||
|
headers=self._get_github_headers(),
|
||||||
|
json={
|
||||||
|
"client_id": GITHUB_CLIENT_ID,
|
||||||
|
"device_code": device_code,
|
||||||
|
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
resp_json = resp.json()
|
||||||
|
|
||||||
|
if "access_token" in resp_json:
|
||||||
|
verbose_logger.info("Authentication successful!")
|
||||||
|
return resp_json["access_token"]
|
||||||
|
elif "error" in resp_json and resp_json.get("error") == "authorization_pending":
|
||||||
|
verbose_logger.debug(f"Authorization pending (attempt {attempt+1}/{max_attempts})")
|
||||||
|
else:
|
||||||
|
verbose_logger.warning(f"Unexpected response: {resp_json}")
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
verbose_logger.error(f"HTTP error polling for access token: {str(e)}")
|
||||||
|
raise GetAccessTokenError(f"Failed to get access token: {str(e)}")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
|
||||||
|
raise GetAccessTokenError(f"Failed to decode access token response: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error(f"Unexpected error polling for access token: {str(e)}")
|
||||||
|
raise GetAccessTokenError(f"Failed to get access token: {str(e)}")
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
raise GetAccessTokenError("Timed out waiting for user to authorize the device")
|
||||||
|
|
||||||
|
def _login(self) -> str:
|
||||||
|
"""
|
||||||
|
Login to GitHub Copilot using device code flow.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The GitHub access token.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
GetDeviceCodeError: If unable to get a device code.
|
||||||
|
GetAccessTokenError: If unable to get an access token.
|
||||||
|
"""
|
||||||
|
device_code_info = self._get_device_code()
|
||||||
|
|
||||||
|
device_code = device_code_info["device_code"]
|
||||||
|
user_code = device_code_info["user_code"]
|
||||||
|
verification_uri = device_code_info["verification_uri"]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Please visit {verification_uri} and enter code {user_code} to authenticate."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._poll_for_access_token(device_code)
|
37
litellm/llms/github_copilot/chat/transformation.py
Normal file
37
litellm/llms/github_copilot/chat/transformation.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
from litellm.llms.openai.openai import OpenAIConfig
|
||||||
|
|
||||||
|
from ..authenticator import Authenticator
|
||||||
|
from ..constants import GetAPIKeyError
|
||||||
|
from litellm.exceptions import AuthenticationError
|
||||||
|
|
||||||
|
|
||||||
|
class GithubCopilotConfig(OpenAIConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
custom_llm_provider: str = "openai",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.authenticator = Authenticator()
|
||||||
|
|
||||||
|
def _get_openai_compatible_provider_info(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_key: Optional[str],
|
||||||
|
custom_llm_provider: str,
|
||||||
|
) -> Tuple[Optional[str], Optional[str], str]:
|
||||||
|
api_base = "https://api.githubcopilot.com"
|
||||||
|
try:
|
||||||
|
dynamic_api_key = self.authenticator.get_api_key()
|
||||||
|
except GetAPIKeyError as e:
|
||||||
|
raise AuthenticationError(
|
||||||
|
model=model,
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
message=str(e),
|
||||||
|
)
|
||||||
|
return api_base, dynamic_api_key, custom_llm_provider
|
38
litellm/llms/github_copilot/constants.py
Normal file
38
litellm/llms/github_copilot/constants.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
"""
|
||||||
|
Constants for Copilot integration
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Copilot API endpoints
|
||||||
|
GITHUB_COPILOT_API_BASE = "https://api.github.com/copilot/v1"
|
||||||
|
CHAT_COMPLETION_ENDPOINT = "/chat/completions"
|
||||||
|
|
||||||
|
# Model names
|
||||||
|
GITHUB_COPILOT_MODEL = "gpt-4o" # The model identifier for Copilot
|
||||||
|
|
||||||
|
# Request headers
|
||||||
|
DEFAULT_HEADERS = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"User-Agent": "litellm",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GetDeviceCodeError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GetAccessTokenError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyExpiredError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshAPIKeyError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GetAPIKeyError(Exception):
|
||||||
|
pass
|
|
@ -1914,6 +1914,7 @@ class LlmProviders(str, Enum):
|
||||||
HUMANLOOP = "humanloop"
|
HUMANLOOP = "humanloop"
|
||||||
TOPAZ = "topaz"
|
TOPAZ = "topaz"
|
||||||
ASSEMBLYAI = "assemblyai"
|
ASSEMBLYAI = "assemblyai"
|
||||||
|
GITHUB_COPILOT = "github_copilot"
|
||||||
|
|
||||||
|
|
||||||
# Create a set of all provider values for quick lookup
|
# Create a set of all provider values for quick lookup
|
||||||
|
|
224
tests/llm_translation/test_github_copilot.py
Normal file
224
tests/llm_translation/test_github_copilot.py
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from unittest.mock import AsyncMock, patch, mock_open, MagicMock
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath("../.."))
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from respx import MockRouter
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import Choices, Message, ModelResponse, ModelResponse, Usage
|
||||||
|
from litellm import completion, acompletion
|
||||||
|
from litellm.llms.github_copilot.chat.transformation import GithubCopilotConfig
|
||||||
|
from litellm.llms.github_copilot.authenticator import Authenticator
|
||||||
|
from litellm.llms.github_copilot.constants import (
|
||||||
|
GetAccessTokenError,
|
||||||
|
GetDeviceCodeError,
|
||||||
|
RefreshAPIKeyError,
|
||||||
|
GetAPIKeyError,
|
||||||
|
APIKeyExpiredError,
|
||||||
|
)
|
||||||
|
from litellm.exceptions import AuthenticationError
|
||||||
|
|
||||||
|
# Import at the top to make the patch work correctly
|
||||||
|
import litellm.llms.github_copilot.chat.transformation
|
||||||
|
|
||||||
|
|
||||||
|
def test_github_copilot_config_get_openai_compatible_provider_info():
|
||||||
|
"""Test the GitHub Copilot configuration provider info retrieval."""
|
||||||
|
|
||||||
|
config = GithubCopilotConfig()
|
||||||
|
|
||||||
|
# Mock the authenticator to avoid actual API calls
|
||||||
|
mock_api_key = "gh.test-key-123456789"
|
||||||
|
config.authenticator = MagicMock()
|
||||||
|
config.authenticator.get_api_key.return_value = mock_api_key
|
||||||
|
|
||||||
|
# Test with default values
|
||||||
|
model = "github_copilot/gpt-4"
|
||||||
|
(
|
||||||
|
api_base,
|
||||||
|
dynamic_api_key,
|
||||||
|
custom_llm_provider,
|
||||||
|
) = config._get_openai_compatible_provider_info(
|
||||||
|
model=model,
|
||||||
|
api_base=None,
|
||||||
|
api_key=None,
|
||||||
|
custom_llm_provider="github_copilot",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert api_base == "https://api.githubcopilot.com"
|
||||||
|
assert dynamic_api_key == mock_api_key
|
||||||
|
assert custom_llm_provider == "github_copilot"
|
||||||
|
|
||||||
|
# Test with authentication failure
|
||||||
|
config.authenticator.get_api_key.side_effect = GetAPIKeyError(
|
||||||
|
"Failed to get API key"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(AuthenticationError) as excinfo:
|
||||||
|
config._get_openai_compatible_provider_info(
|
||||||
|
model=model,
|
||||||
|
api_base=None,
|
||||||
|
api_key=None,
|
||||||
|
custom_llm_provider="github_copilot",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Failed to get API key" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider")
|
||||||
|
@patch("litellm.llms.openai.openai.OpenAIChatCompletion.completion")
|
||||||
|
def test_completion_github_copilot_mock_response(mock_completion, mock_get_provider):
|
||||||
|
"""Test the completion function with GitHub Copilot provider."""
|
||||||
|
|
||||||
|
# Mock completion response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "Hello, I'm GitHub Copilot!"
|
||||||
|
mock_completion.return_value = mock_response
|
||||||
|
|
||||||
|
# Test non-streaming completion
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You're GitHub Copilot, an AI assistant."},
|
||||||
|
{"role": "user", "content": "Hello, who are you?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a properly formatted headers dictionary
|
||||||
|
headers = {
|
||||||
|
"editor-version": "Neovim/0.9.0",
|
||||||
|
"Copilot-Integration-Id": "vscode-chat",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patch the get_llm_provider function instead of the config method
|
||||||
|
# Make it return the expected tuple directly
|
||||||
|
mock_get_provider.return_value = (
|
||||||
|
"gpt-4",
|
||||||
|
"github_copilot",
|
||||||
|
"gh.test-key-123456789",
|
||||||
|
"https://api.githubcopilot.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="github_copilot/gpt-4",
|
||||||
|
messages=messages,
|
||||||
|
extra_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
|
||||||
|
# Verify the completion call was made with the expected params
|
||||||
|
mock_completion.assert_called_once()
|
||||||
|
args, kwargs = mock_completion.call_args
|
||||||
|
|
||||||
|
# Check that the proper authorization header is set
|
||||||
|
assert "headers" in kwargs
|
||||||
|
# Check that the model name is correctly formatted
|
||||||
|
assert (
|
||||||
|
kwargs.get("model") == "gpt-4"
|
||||||
|
) # Model name should be without provider prefix
|
||||||
|
assert kwargs.get("messages") == messages
|
||||||
|
|
||||||
|
|
||||||
|
@patch("litellm.llms.github_copilot.authenticator.Authenticator.get_api_key")
|
||||||
|
def test_authenticator_get_api_key(mock_get_api_key):
|
||||||
|
"""Test the Authenticator's get_api_key method."""
|
||||||
|
from litellm.llms.github_copilot.authenticator import Authenticator
|
||||||
|
|
||||||
|
# Test successful API key retrieval
|
||||||
|
mock_get_api_key.return_value = "gh.test-key-123456789"
|
||||||
|
authenticator = Authenticator()
|
||||||
|
api_key = authenticator.get_api_key()
|
||||||
|
|
||||||
|
assert api_key == "gh.test-key-123456789"
|
||||||
|
mock_get_api_key.assert_called_once()
|
||||||
|
|
||||||
|
# Test API key retrieval failure
|
||||||
|
mock_get_api_key.reset_mock()
|
||||||
|
mock_get_api_key.side_effect = GetAPIKeyError("Failed to get API key")
|
||||||
|
authenticator = Authenticator()
|
||||||
|
|
||||||
|
with pytest.raises(GetAPIKeyError) as excinfo:
|
||||||
|
authenticator.get_api_key()
|
||||||
|
|
||||||
|
assert "Failed to get API key" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_github_copilot(stream=False):
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are an AI programming assistant."},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Write a Python function to calculate fibonacci numbers",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
extra_headers = {
|
||||||
|
"editor-version": "Neovim/0.9.0",
|
||||||
|
"Copilot-Integration-Id": "vscode-chat",
|
||||||
|
}
|
||||||
|
response = completion(
|
||||||
|
model="github_copilot/gpt-4",
|
||||||
|
messages=messages,
|
||||||
|
stream=stream,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
if stream is True:
|
||||||
|
for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
assert chunk is not None
|
||||||
|
assert isinstance(chunk, litellm.ModelResponseStream)
|
||||||
|
assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert response is not None
|
||||||
|
assert isinstance(response, litellm.ModelResponse)
|
||||||
|
assert response.choices[0].message.content is not None
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
def test_completion_github_copilot_sonnet_3_7_thought(stream=False):
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are an AI programming assistant."},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Write a Python function to calculate fibonacci numbers",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
extra_headers = {
|
||||||
|
"editor-version": "Neovim/0.9.0",
|
||||||
|
"Copilot-Integration-Id": "vscode-chat",
|
||||||
|
}
|
||||||
|
response = completion(
|
||||||
|
model="github_copilot/claude-3.7-sonnet-thought",
|
||||||
|
messages=messages,
|
||||||
|
stream=stream,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
if stream is True:
|
||||||
|
for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
assert chunk is not None
|
||||||
|
assert isinstance(chunk, litellm.ModelResponseStream)
|
||||||
|
assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert response is not None
|
||||||
|
assert isinstance(response, litellm.ModelResponse)
|
||||||
|
assert response.choices[0].message.content is not None
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
179
tests/llm_translation/test_github_copilot_authenticator.py
Normal file
179
tests/llm_translation/test_github_copilot_authenticator.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, mock_open, MagicMock
|
||||||
|
|
||||||
|
from litellm.llms.github_copilot.authenticator import Authenticator
|
||||||
|
from litellm.llms.github_copilot.constants import (
|
||||||
|
GetAccessTokenError,
|
||||||
|
GetDeviceCodeError,
|
||||||
|
RefreshAPIKeyError,
|
||||||
|
GetAPIKeyError,
|
||||||
|
APIKeyExpiredError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitHubCopilotAuthenticator:
|
||||||
|
@pytest.fixture
|
||||||
|
def authenticator(self):
|
||||||
|
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
|
||||||
|
auth = Authenticator()
|
||||||
|
mock_makedirs.assert_called_once()
|
||||||
|
return auth
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_http_client(self):
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
return mock_client, mock_response
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test the initialization of the authenticator."""
|
||||||
|
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
|
||||||
|
auth = Authenticator()
|
||||||
|
assert auth.token_dir.endswith("/github_copilot")
|
||||||
|
assert auth.access_token_file.endswith("/access-token")
|
||||||
|
assert auth.api_key_file.endswith("/api-key.json")
|
||||||
|
mock_makedirs.assert_called_once()
|
||||||
|
|
||||||
|
def test_ensure_token_dir(self):
|
||||||
|
"""Test that the token directory is created if it doesn't exist."""
|
||||||
|
with patch("os.path.exists", return_value=False), patch("os.makedirs") as mock_makedirs:
|
||||||
|
auth = Authenticator()
|
||||||
|
mock_makedirs.assert_called_once_with(auth.token_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def test_get_github_headers(self, authenticator):
|
||||||
|
"""Test that GitHub headers are correctly generated."""
|
||||||
|
headers = authenticator._get_github_headers()
|
||||||
|
assert "accept" in headers
|
||||||
|
assert "editor-version" in headers
|
||||||
|
assert "user-agent" in headers
|
||||||
|
assert "content-type" in headers
|
||||||
|
|
||||||
|
headers_with_token = authenticator._get_github_headers("test-token")
|
||||||
|
assert headers_with_token["authorization"] == "token test-token"
|
||||||
|
|
||||||
|
def test_get_access_token_from_file(self, authenticator):
|
||||||
|
"""Test retrieving an access token from a file."""
|
||||||
|
mock_token = "mock-access-token"
|
||||||
|
|
||||||
|
with patch("builtins.open", mock_open(read_data=mock_token)):
|
||||||
|
token = authenticator.get_access_token()
|
||||||
|
assert token == mock_token
|
||||||
|
|
||||||
|
def test_get_access_token_login(self, authenticator):
|
||||||
|
"""Test logging in to get an access token."""
|
||||||
|
mock_token = "mock-access-token"
|
||||||
|
|
||||||
|
with patch.object(authenticator, "_login", return_value=mock_token), \
|
||||||
|
patch("builtins.open", mock_open()), \
|
||||||
|
patch("builtins.open", side_effect=IOError) as mock_read:
|
||||||
|
token = authenticator.get_access_token()
|
||||||
|
assert token == mock_token
|
||||||
|
authenticator._login.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_access_token_failure(self, authenticator):
|
||||||
|
"""Test that an exception is raised after multiple login failures."""
|
||||||
|
with patch.object(authenticator, "_login", side_effect=GetDeviceCodeError("Test error")), \
|
||||||
|
patch("builtins.open", side_effect=IOError):
|
||||||
|
with pytest.raises(GetAccessTokenError):
|
||||||
|
authenticator.get_access_token()
|
||||||
|
assert authenticator._login.call_count == 3
|
||||||
|
|
||||||
|
def test_get_api_key_from_file(self, authenticator):
|
||||||
|
"""Test retrieving an API key from a file."""
|
||||||
|
future_time = (datetime.now() + timedelta(hours=1)).timestamp()
|
||||||
|
mock_api_key_data = json.dumps({"token": "mock-api-key", "expires_at": future_time})
|
||||||
|
|
||||||
|
with patch("builtins.open", mock_open(read_data=mock_api_key_data)):
|
||||||
|
api_key = authenticator.get_api_key()
|
||||||
|
assert api_key == "mock-api-key"
|
||||||
|
|
||||||
|
def test_get_api_key_expired(self, authenticator):
|
||||||
|
"""Test refreshing an expired API key."""
|
||||||
|
past_time = (datetime.now() - timedelta(hours=1)).timestamp()
|
||||||
|
mock_expired_data = json.dumps({"token": "expired-api-key", "expires_at": past_time})
|
||||||
|
mock_new_data = {"token": "new-api-key", "expires_at": (datetime.now() + timedelta(hours=1)).timestamp()}
|
||||||
|
|
||||||
|
with patch("builtins.open", mock_open(read_data=mock_expired_data)), \
|
||||||
|
patch.object(authenticator, "_refresh_api_key", return_value=mock_new_data), \
|
||||||
|
patch("json.dump") as mock_json_dump:
|
||||||
|
api_key = authenticator.get_api_key()
|
||||||
|
assert api_key == "new-api-key"
|
||||||
|
authenticator._refresh_api_key.assert_called_once()
|
||||||
|
|
||||||
|
def test_refresh_api_key(self, authenticator, mock_http_client):
|
||||||
|
"""Test refreshing an API key."""
|
||||||
|
mock_client, mock_response = mock_http_client
|
||||||
|
mock_token = "mock-access-token"
|
||||||
|
mock_api_key_data = {"token": "new-api-key", "expires_at": 12345}
|
||||||
|
|
||||||
|
with patch.object(authenticator, "get_access_token", return_value=mock_token), \
|
||||||
|
patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
|
||||||
|
patch.object(mock_response, "json", return_value=mock_api_key_data):
|
||||||
|
result = authenticator._refresh_api_key()
|
||||||
|
assert result == mock_api_key_data
|
||||||
|
mock_client.get.assert_called_once()
|
||||||
|
authenticator.get_access_token.assert_called_once()
|
||||||
|
|
||||||
|
def test_refresh_api_key_failure(self, authenticator, mock_http_client):
|
||||||
|
"""Test failure to refresh an API key."""
|
||||||
|
mock_client, mock_response = mock_http_client
|
||||||
|
mock_token = "mock-access-token"
|
||||||
|
|
||||||
|
with patch.object(authenticator, "get_access_token", return_value=mock_token), \
|
||||||
|
patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
|
||||||
|
patch.object(mock_response, "json", return_value={}):
|
||||||
|
with pytest.raises(RefreshAPIKeyError):
|
||||||
|
authenticator._refresh_api_key()
|
||||||
|
assert mock_client.get.call_count == 3
|
||||||
|
|
||||||
|
def test_get_device_code(self, authenticator, mock_http_client):
|
||||||
|
"""Test getting a device code."""
|
||||||
|
mock_client, mock_response = mock_http_client
|
||||||
|
mock_device_code_data = {
|
||||||
|
"device_code": "mock-device-code",
|
||||||
|
"user_code": "ABCD-EFGH",
|
||||||
|
"verification_uri": "https://github.com/login/device"
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
|
||||||
|
patch.object(mock_response, "json", return_value=mock_device_code_data):
|
||||||
|
result = authenticator._get_device_code()
|
||||||
|
assert result == mock_device_code_data
|
||||||
|
mock_client.post.assert_called_once()
|
||||||
|
|
||||||
|
def test_poll_for_access_token(self, authenticator, mock_http_client):
|
||||||
|
"""Test polling for an access token."""
|
||||||
|
mock_client, mock_response = mock_http_client
|
||||||
|
mock_token_data = {"access_token": "mock-access-token"}
|
||||||
|
|
||||||
|
with patch("litellm.llms.github_copilot.authenticator._get_httpx_client", return_value=mock_client), \
|
||||||
|
patch.object(mock_response, "json", return_value=mock_token_data), \
|
||||||
|
patch("time.sleep"):
|
||||||
|
result = authenticator._poll_for_access_token("mock-device-code")
|
||||||
|
assert result == "mock-access-token"
|
||||||
|
mock_client.post.assert_called_once()
|
||||||
|
|
||||||
|
def test_login(self, authenticator):
|
||||||
|
"""Test the login process."""
|
||||||
|
mock_device_code_data = {
|
||||||
|
"device_code": "mock-device-code",
|
||||||
|
"user_code": "ABCD-EFGH",
|
||||||
|
"verification_uri": "https://github.com/login/device"
|
||||||
|
}
|
||||||
|
mock_token = "mock-access-token"
|
||||||
|
|
||||||
|
with patch.object(authenticator, "_get_device_code", return_value=mock_device_code_data), \
|
||||||
|
patch.object(authenticator, "_poll_for_access_token", return_value=mock_token), \
|
||||||
|
patch("builtins.print") as mock_print:
|
||||||
|
result = authenticator._login()
|
||||||
|
assert result == mock_token
|
||||||
|
authenticator._get_device_code.assert_called_once()
|
||||||
|
authenticator._poll_for_access_token.assert_called_once_with("mock-device-code")
|
||||||
|
mock_print.assert_called_once()
|
Loading…
Add table
Add a link
Reference in a new issue