mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
feat: add support for copilot provider
This commit is contained in:
parent
d918b089c6
commit
e394d45513
9 changed files with 296 additions and 1 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -77,3 +77,5 @@ litellm/proxy/_experimental/out/404.html
|
|||
litellm/proxy/_experimental/out/model_hub.html
|
||||
.mypy_cache/*
|
||||
litellm/proxy/application.log
|
||||
|
||||
**/__pycache__
|
|
@ -969,6 +969,7 @@ from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
|
|||
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
|
||||
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
|
||||
from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
|
||||
from .llms.github_copilot.chat.transformation import GithubCopilotConfig
|
||||
from .main import * # type: ignore
|
||||
from .integrations import *
|
||||
from .exceptions import (
|
||||
|
|
|
@ -79,6 +79,7 @@ LITELLM_CHAT_PROVIDERS = [
|
|||
"hosted_vllm",
|
||||
"lm_studio",
|
||||
"galadriel",
|
||||
"github_copilot", # GitHub Copilot Chat API
|
||||
]
|
||||
|
||||
|
||||
|
@ -137,7 +138,7 @@ openai_compatible_endpoints: List = [
|
|||
"https://api.friendli.ai/serverless/v1",
|
||||
"api.sambanova.ai/v1",
|
||||
"api.x.ai/v1",
|
||||
"api.galadriel.ai/v1",
|
||||
"api.galadriel.ai/v1"
|
||||
]
|
||||
|
||||
|
||||
|
@ -167,6 +168,7 @@ openai_compatible_providers: List = [
|
|||
"hosted_vllm",
|
||||
"lm_studio",
|
||||
"galadriel",
|
||||
"github_copilot", # GitHub Copilot Chat API
|
||||
]
|
||||
openai_text_completion_compatible_providers: List = (
|
||||
[ # providers that support `/v1/completions`
|
||||
|
|
|
@ -571,6 +571,14 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
|
|||
or "https://api.galadriel.com/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
|
||||
elif custom_llm_provider == "github_copilot":
|
||||
(
|
||||
api_base,
|
||||
dynamic_api_key,
|
||||
custom_llm_provider,
|
||||
) = litellm.GithubCopilotConfig()._get_openai_compatible_provider_info(
|
||||
model, api_base, api_key, custom_llm_provider
|
||||
)
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
||||
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
|
||||
|
|
10
litellm/llms/github_copilot/__init__.py
Normal file
10
litellm/llms/github_copilot/__init__.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from .constants import GITHUB_COPILOT_API_BASE, CHAT_COMPLETION_ENDPOINT, GITHUB_COPILOT_MODEL, GetAccessTokenError, GetAPIKeyError, RefreshAPIKeyError
|
||||
|
||||
__all__ = [
|
||||
"GITHUB_COPILOT_API_BASE",
|
||||
"CHAT_COMPLETION_ENDPOINT",
|
||||
"GITHUB_COPILOT_MODEL",
|
||||
"GetAccessTokenError",
|
||||
"GetAPIKeyError",
|
||||
"RefreshAPIKeyError"
|
||||
]
|
217
litellm/llms/github_copilot/authenticator.py
Normal file
217
litellm/llms/github_copilot/authenticator.py
Normal file
|
@ -0,0 +1,217 @@
|
|||
import os
|
||||
|
||||
import httpx
|
||||
import json
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
from .constants import (
|
||||
GetAccessTokenError,
|
||||
GetDeviceCodeError,
|
||||
RefreshAPIKeyError,
|
||||
GetAPIKeyError,
|
||||
APIKeyExpiredError,
|
||||
)
|
||||
|
||||
class Authenticator:
|
||||
def __init__(self) -> None:
|
||||
# Token storage paths
|
||||
self.token_dir = os.getenv("GITHUB_COPILOT_TOKEN_DIR", os.path.expanduser("~/.config/litellm/github_copilot"))
|
||||
self.access_token_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_ACCESS_TOKEN_FILE", "access-token"))
|
||||
self.api_key_file = os.path.join(self.token_dir, os.getenv("GITHUB_COPILOT_API_KEY_FILE", "api-key.json"))
|
||||
self._ensure_token_dir()
|
||||
|
||||
self.get_access_token()
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""
|
||||
Login to Copilot with retry 3 times
|
||||
|
||||
Returns:
|
||||
access_token: str
|
||||
|
||||
"""
|
||||
try:
|
||||
with open(self.access_token_file, 'r') as f:
|
||||
access_token = f.read().strip()
|
||||
return access_token
|
||||
except IOError:
|
||||
verbose_logger.warning("Error loading access token from file")
|
||||
|
||||
for _ in range(3):
|
||||
try:
|
||||
access_token = self._login()
|
||||
except GetDeviceCodeError | GetAccessTokenError | RefreshAPIKeyError:
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
with open(self.access_token_file, 'w') as f:
|
||||
f.write(access_token)
|
||||
except IOError:
|
||||
verbose_logger.error("Error saving access token to file")
|
||||
return access_token
|
||||
|
||||
raise GetAccessTokenError("Failed to get access token")
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""Get the API key"""
|
||||
try:
|
||||
with open(self.api_key_file, 'r') as f:
|
||||
api_key_info = json.load(f)
|
||||
if api_key_info.get('expires_at') > datetime.now().timestamp():
|
||||
return api_key_info.get('token')
|
||||
else:
|
||||
raise APIKeyExpiredError("API key expired")
|
||||
except IOError:
|
||||
verbose_logger.warning("Error opening API key file")
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
verbose_logger.warning("Error reading API key from file")
|
||||
except APIKeyExpiredError:
|
||||
verbose_logger.warning("API key expired")
|
||||
|
||||
try:
|
||||
api_key_info = self._refresh_api_key()
|
||||
with open(self.api_key_file, 'w') as f:
|
||||
json.dump(api_key_info, f)
|
||||
except IOError:
|
||||
verbose_logger.error("Error saving API key to file")
|
||||
except RefreshAPIKeyError:
|
||||
raise GetAPIKeyError("Failed to refresh API key")
|
||||
|
||||
return api_key_info.get('token')
|
||||
|
||||
def _refresh_api_key(self) -> dict:
|
||||
"""
|
||||
Refresh the API key using the access token
|
||||
|
||||
Returns:
|
||||
api_key_info: dict
|
||||
"""
|
||||
|
||||
access_token = self.get_access_token()
|
||||
headers = {
|
||||
'authorization': f'token {access_token}',
|
||||
'editor-version': 'vscode/1.85.1',
|
||||
'editor-plugin-version': 'copilot/1.155.0',
|
||||
'user-agent': 'GithubCopilot/1.155.0'
|
||||
}
|
||||
|
||||
max_retries = 3
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
sync_client = _get_httpx_client()
|
||||
response = sync_client.get(
|
||||
'https://api.github.com/copilot_internal/v2/token',
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
if 'token' in response_json:
|
||||
return response_json
|
||||
except httpx.HTTPStatusError as e:
|
||||
verbose_logger.error(f"Error refreshing API key: {str(e)}")
|
||||
|
||||
raise RefreshAPIKeyError("Failed to refresh API key")
|
||||
|
||||
|
||||
def _ensure_token_dir(self) -> None:
|
||||
"""Ensure the token directory exists"""
|
||||
if not os.path.exists(self.token_dir):
|
||||
os.makedirs(self.token_dir, exist_ok=True)
|
||||
|
||||
def _login(self) -> str:
|
||||
"""
|
||||
Login to GitHub Copilot using device code flow
|
||||
|
||||
Returns:
|
||||
access_token: str
|
||||
"""
|
||||
|
||||
try:
|
||||
sync_client = _get_httpx_client()
|
||||
# Get device code
|
||||
resp = sync_client.post(
|
||||
'https://github.com/login/device/code',
|
||||
headers={
|
||||
'accept': 'application/json',
|
||||
'editor-version': 'vscode/1.85.1',
|
||||
'editor-plugin-version': 'copilot/1.155.0',
|
||||
'content-type': 'application/json',
|
||||
'user-agent': 'GithubCopilot/1.155.0',
|
||||
'accept-encoding': 'gzip,deflate,br'
|
||||
},
|
||||
json={"client_id": "Iv1.b507a08c87ecfe98", "scope": "read:user"}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp_json = resp.json()
|
||||
|
||||
device_code = resp_json.get('device_code')
|
||||
user_code = resp_json.get('user_code')
|
||||
verification_uri = resp_json.get('verification_uri')
|
||||
|
||||
if not all([device_code, user_code, verification_uri]):
|
||||
verbose_logger.error("Response missing required fields")
|
||||
return None
|
||||
except httpx.HttpError as e:
|
||||
verbose_logger.error(f"Error getting device code: {str(e)}")
|
||||
raise GetDeviceCodeError("Failed to get device code")
|
||||
except json.JSONDecodeError as e:
|
||||
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
|
||||
raise GetDeviceCodeError("Failed to get device code")
|
||||
except RuntimeError as e:
|
||||
verbose_logger.error(f"Error getting device code: {str(e)}")
|
||||
raise GetDeviceCodeError("Failed to get device code")
|
||||
|
||||
print(f'Please visit {verification_uri} and enter code {user_code} to authenticate.')
|
||||
|
||||
while True:
|
||||
time.sleep(5)
|
||||
|
||||
# Get access token
|
||||
try:
|
||||
resp = sync_client.post(
|
||||
'https://github.com/login/oauth/access_token',
|
||||
headers={
|
||||
'accept': 'application/json',
|
||||
'editor-version': 'vscode/1.85.1',
|
||||
'editor-plugin-version': 'copilot/1.155.0',
|
||||
'content-type': 'application/json',
|
||||
'user-agent': 'GithubCopilot/1.155.0',
|
||||
'accept-encoding': 'gzip,deflate,br'
|
||||
},
|
||||
json={
|
||||
"client_id": "Iv1.b507a08c87ecfe98",
|
||||
"device_code": device_code,
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
|
||||
}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
resp_json = resp.json()
|
||||
|
||||
if "access_token" in resp_json:
|
||||
verbose_logger.info("Authentication success!")
|
||||
return resp_json['access_token']
|
||||
else:
|
||||
continue
|
||||
except httpx.HTTPStatusError as e:
|
||||
verbose_logger.error(f"Error getting access token: {str(e)}")
|
||||
raise GetAccessTokenError("Failed to get access token")
|
||||
except json.JSONDecodeError as e:
|
||||
verbose_logger.error(f"Error decoding JSON response: {str(e)}")
|
||||
raise GetAccessTokenError("Failed to get access token")
|
||||
|
||||
|
||||
|
||||
|
29
litellm/llms/github_copilot/chat/transformation.py
Normal file
29
litellm/llms/github_copilot/chat/transformation.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
from litellm.llms.openai.openai import OpenAIConfig
|
||||
|
||||
from ..authenticator import Authenticator
|
||||
|
||||
class GithubCopilotConfig(OpenAIConfig):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
custom_llm_provider: str = "openai",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.authenticator = Authenticator()
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
custom_llm_provider: str,
|
||||
) -> Tuple[Optional[str], Optional[str], str]:
|
||||
api_base = "https://api.githubcopilot.com"
|
||||
dynamic_api_key = self.authenticator.get_api_key()
|
||||
|
||||
return api_base, dynamic_api_key, custom_llm_provider
|
25
litellm/llms/github_copilot/constants.py
Normal file
25
litellm/llms/github_copilot/constants.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
"""
|
||||
Constants for Copilot integration
|
||||
"""
|
||||
import os
|
||||
|
||||
# Copilot API endpoints
|
||||
GITHUB_COPILOT_API_BASE = "https://api.github.com/copilot/v1"
|
||||
CHAT_COMPLETION_ENDPOINT = "/chat/completions"
|
||||
|
||||
# Model names
|
||||
GITHUB_COPILOT_MODEL = "gpt-4o" # The model identifier for Copilot
|
||||
|
||||
# Request headers
|
||||
DEFAULT_HEADERS = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "litellm",
|
||||
}
|
||||
|
||||
class GetDeviceCodeError(Exception): pass
|
||||
class GetAccessTokenError(Exception): pass
|
||||
class APIKeyExpiredError(Exception): pass
|
||||
class RefreshAPIKeyError(Exception): pass
|
||||
class GetAPIKeyError(Exception): pass
|
||||
|
|
@ -1874,6 +1874,7 @@ class LlmProviders(str, Enum):
|
|||
HUMANLOOP = "humanloop"
|
||||
TOPAZ = "topaz"
|
||||
ASSEMBLYAI = "assemblyai"
|
||||
GITHUB_COPILOT = "github_copilot"
|
||||
|
||||
|
||||
# Create a set of all provider values for quick lookup
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue