LiteLLM Minor Fixes & Improvements (09/20/2024) (#5807)

* fix(vertex_llm_base.py): Handle api_base = ""

Fixes https://github.com/BerriAI/litellm/issues/5798
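The base now uses a truthiness check, so an empty `api_base` (`""`) falls back to the default Vertex/Gemini URL instead of being treated as a custom proxy base. A minimal sketch of that behaviour (function name and default URL are illustrative, not LiteLLM's actual API):

```python
from typing import Optional


def resolve_vertex_url(api_base: Optional[str], endpoint: str, default_url: str) -> str:
    # Truthiness check: both None and "" keep the default URL;
    # only a non-empty api_base is treated as a custom proxy base.
    if api_base:
        return "{}:{}".format(api_base, endpoint)
    return default_url


assert resolve_vertex_url("", "generateContent", "https://example-default-url") == "https://example-default-url"
```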

* fix(o1_transformation.py): handle stream_options not being supported

https://github.com/BerriAI/litellm/issues/5803
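`stream_options` is added to the parameters the o1 config reports as unsupported, so LiteLLM can drop it (e.g. with `drop_params`) rather than forwarding it to the API. A rough sketch of the filtering idea (list and helper below are illustrative, not the exact `OpenAIO1Config` code):

```python
# Illustrative subset of params the o1 transformation treats as unsupported.
NON_SUPPORTED_O1_PARAMS = ["logprobs", "top_logprobs", "response_format", "stop", "stream_options"]


def drop_unsupported_o1_params(optional_params: dict) -> dict:
    """Return the request params with unsupported keys removed."""
    return {k: v for k, v in optional_params.items() if k not in NON_SUPPORTED_O1_PARAMS}


print(drop_unsupported_o1_params({"max_tokens": 100, "stream_options": {"include_usage": True}}))
# -> {'max_tokens': 100}
```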

* docs(routing.md): fix docs

Closes https://github.com/BerriAI/litellm/issues/5808
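The main correction is the misspelled router flag: `enable_pre_call_check` becomes `enable_pre_call_checks`. A condensed version of the corrected snippet (deployment list trimmed; Redis settings assumed to come from the environment):

```python
import os

from litellm import Router

model_list = [{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {
        "model": "gpt-3.5-turbo",
        "api_key": os.getenv("OPENAI_API_KEY"),
    },
}]

router = Router(
    model_list=model_list,
    redis_host=os.environ["REDIS_HOST"],
    redis_password=os.environ["REDIS_PASSWORD"],
    redis_port=int(os.environ["REDIS_PORT"]),
    routing_strategy="usage-based-routing-v2",
    enable_pre_call_checks=True,  # note the plural: enables router rate limits for concurrent calls
)
```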

* perf(internal_user_endpoints.py): reduce db calls for getting team_alias for a key

Reuse the team list already fetched earlier in the `/user/info` endpoint.

Reduces the UI Keys tab load time to 800ms (previously 28s+)
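In short: instead of one `prisma_client.get_data(team_id=...)` call per key, the endpoint now resolves `team_alias` from the team list it has already fetched. The helper added below boils down to:

```python
from typing import List, Optional


def get_team_from_list(team_list: Optional[List], team_id: str):
    """Find a team by id in the list /user/info already fetched (no extra DB query per key)."""
    if team_list is None:
        return None
    for team in team_list:
        if team.team_id == team_id:
            return team
    return None
```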

* feat(proxy_server.py): support CONFIG_FILE_PATH as env var

Closes https://github.com/BerriAI/litellm/issues/5744
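The proxy's startup now checks a `CONFIG_FILE_PATH` environment variable before falling back to the existing `WORKER_CONFIG` behaviour. A sketch of that precedence (helper name and YAML check are illustrative; the real logic lives in `startup_event`, see the proxy_server.py diff below):

```python
import os
from typing import Optional


def resolve_proxy_config_path(worker_config: Optional[str]) -> Optional[str]:
    # CONFIG_FILE_PATH wins when it points at an existing YAML file,
    # e.g. CONFIG_FILE_PATH=/app/proxy_config.yaml (illustrative path).
    env_config_yaml = os.getenv("CONFIG_FILE_PATH")
    if env_config_yaml and os.path.isfile(env_config_yaml) and env_config_yaml.endswith((".yaml", ".yml")):
        return env_config_yaml
    return worker_config
```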

* feat(get_llm_provider_logic.py): add `litellm_proxy/` as a known openai-compatible route

Simplifies calling the LiteLLM proxy.

Reduces confusion when calling models on the LiteLLM proxy from the LiteLLM SDK.
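Example of the simplified SDK call against a proxy (key, base URL, and model name below are placeholders):

```python
import os

from litellm import completion

os.environ["LITELLM_PROXY_API_KEY"] = "sk-1234"                 # your proxy key (placeholder)
os.environ["LITELLM_PROXY_API_BASE"] = "http://localhost:4000"  # your proxy URL (placeholder)

response = completion(
    model="litellm_proxy/your-model-name",  # `litellm_proxy/` routes the call through your proxy as an OpenAI-compatible provider
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
```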

* docs(litellm_proxy.md): clean up docs

* fix(internal_user_endpoints.py): fix pydantic obj
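`/user/info` now returns a typed `UserInfoResponse` (added alongside the other proxy pydantic models, see the diff below), so callers use attribute access instead of dict keys. Illustrative construction (values are placeholders; import path assumed):

```python
from litellm.proxy._types import UserInfoResponse  # import path assumed

resp = UserInfoResponse(user_id="user-123", user_info=None, keys=[], teams=[])
print(resp.user_id)  # attribute access, e.g. resp.user_info instead of resp["user_info"]
```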

* test(test_key_generate_prisma.py): fix test
Krish Dholakia, 2024-09-20 20:21:32 -07:00 (committed by GitHub)
parent c9ceab0f1e
commit 7ed6938a3f
14 changed files with 204 additions and 84 deletions

View file

@@ -9,13 +9,13 @@ import TabItem from '@theme/TabItem';
:::
**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `openai/` prefix before the model
**[LiteLLM Proxy](../simple_proxy) is OpenAI compatible**, you just need the `litellm_proxy/` prefix before the model
## Required Variables
```python
os.environ["OPENAI_API_KEY"] = "" # "sk-1234" your litellm proxy api key
os.environ["OPENAI_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
os.environ["LITELLM_PROXY_API_KEY"] = "" # "sk-1234" your litellm proxy api key
os.environ["LITELLM_PROXY_API_BASE"] = "" # "http://localhost:4000" your litellm proxy api base
```
@@ -25,18 +25,18 @@ import os
import litellm
from litellm import completion
os.environ["OPENAI_API_KEY"] = ""
os.environ["LITELLM_PROXY_API_KEY"] = ""
# set custom api base to your proxy
# either set .env or litellm.api_base
# os.environ["OPENAI_API_BASE"] = ""
# os.environ["LITELLM_PROXY_API_BASE"] = ""
litellm.api_base = "your-openai-proxy-url"
messages = [{ "content": "Hello, how are you?","role": "user"}]
# openai call
response = completion(model="openai/your-model-name", messages)
# litellm proxy call
response = completion(model="litellm_proxy/your-model-name", messages)
```
## Usage - passing `api_base`, `api_key` per request
@@ -48,13 +48,13 @@ import os
import litellm
from litellm import completion
os.environ["OPENAI_API_KEY"] = ""
os.environ["LITELLM_PROXY_API_KEY"] = ""
messages = [{ "content": "Hello, how are you?","role": "user"}]
# openai call
# litellm proxy call
response = completion(
model="openai/your-model-name",
model="litellm_proxy/your-model-name",
messages,
api_base = "your-litellm-proxy-url",
api_key = "your-litellm-proxy-api-key"
@@ -67,13 +67,13 @@ import os
import litellm
from litellm import completion
os.environ["OPENAI_API_KEY"] = ""
os.environ["LITELLM_PROXY_API_KEY"] = ""
messages = [{ "content": "Hello, how are you?","role": "user"}]
# openai call
response = completion(
model="openai/your-model-name",
model="litellm_proxy/your-model-name",
messages,
api_base = "your-litellm-proxy-url",
stream=True

View file

@@ -183,9 +183,9 @@ model_list = [{ # list of model deployments
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"tpm": 100000,
"rpm": 10000,
},
"tpm": 100000,
"rpm": 10000,
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
@@ -193,24 +193,24 @@ model_list = [{ # list of model deployments
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
"tpm": 100000,
"rpm": 1000,
},
"tpm": 100000,
"rpm": 1000,
}, {
"model_name": "gpt-3.5-turbo",
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
"tpm": 100000,
"rpm": 1000,
},
"tpm": 100000,
"rpm": 1000,
}]
router = Router(model_list=model_list,
redis_host=os.environ["REDIS_HOST"],
redis_password=os.environ["REDIS_PASSWORD"],
redis_port=os.environ["REDIS_PORT"],
routing_strategy="usage-based-routing-v2" # 👈 KEY CHANGE
enable_pre_call_check=True, # enables router rate limits for concurrent calls
enable_pre_call_checks=True, # enables router rate limits for concurrent calls
)
response = await router.acompletion(model="gpt-3.5-turbo",

View file

@@ -495,6 +495,7 @@ openai_compatible_providers: List = [
"friendliai",
"azure_ai",
"github",
"litellm_proxy",
]
openai_text_completion_compatible_providers: List = (
[ # providers that support `/v1/completions`
@@ -748,6 +749,7 @@ class LlmProviders(str, Enum):
EMPOWER = "empower"
GITHUB = "github"
CUSTOM = "custom"
LITELLM_PROXY = "litellm_proxy"
provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)

View file

@@ -243,6 +243,9 @@ def get_llm_provider(
elif custom_llm_provider == "github":
api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
elif custom_llm_provider == "litellm_proxy":
api_base = api_base or get_secret("LITELLM_PROXY_API_BASE") or "https://models.inference.ai.azure.com" # type: ignore
dynamic_api_key = api_key or get_secret("LITELLM_PROXY_API_KEY")
elif custom_llm_provider == "mistral":
# mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai

View file

@@ -65,6 +65,7 @@ class OpenAIO1Config(OpenAIGPTConfig):
"top_logprobs",
"response_format",
"stop",
"stream_options",
]
return [

View file

@@ -16,6 +16,7 @@ from litellm.types.llms.vertex_ai import (
from litellm.utils import ModelResponse
from ..common_utils import VertexAIError
from ..vertex_llm_base import VertexBase
from .transformation import (
separate_cached_messages,
transform_openai_messages_to_gemini_context_caching,
@@ -24,7 +25,7 @@ from .transformation import (
local_cache_obj = Cache(type="local") # only used for calling 'get_cache_key' function
class ContextCachingEndpoints:
class ContextCachingEndpoints(VertexBase):
"""
Covers context caching endpoints for Vertex AI + Google AI Studio
@@ -34,7 +35,7 @@ class ContextCachingEndpoints:
def __init__(self) -> None:
pass
def _get_token_and_url(
def _get_token_and_url_context_caching(
self,
gemini_api_key: Optional[str],
custom_llm_provider: Literal["gemini"],
@@ -57,18 +58,16 @@ class ContextCachingEndpoints:
else:
raise NotImplementedError
if (
api_base is not None
): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
if custom_llm_provider == "gemini":
url = "{}/{}".format(api_base, endpoint)
auth_header = (
gemini_api_key # cloudflare expects api key as bearer token
)
else:
url = "{}:{}".format(api_base, endpoint)
return auth_header, url
return self._check_custom_proxy(
api_base=api_base,
custom_llm_provider=custom_llm_provider,
gemini_api_key=gemini_api_key,
endpoint=endpoint,
stream=None,
auth_header=auth_header,
url=url,
)
def check_cache(
self,
@@ -90,7 +89,7 @@ class ContextCachingEndpoints:
- None
"""
_, url = self._get_token_and_url(
_, url = self._get_token_and_url_context_caching(
gemini_api_key=api_key,
custom_llm_provider="gemini",
api_base=api_base,
@@ -125,11 +124,12 @@ class ContextCachingEndpoints:
all_cached_items = CachedContentListAllResponseBody(**raw_response)
if "cachedContents" not in all_cached_items:
return None
for cached_item in all_cached_items["cachedContents"]:
if (
cached_item.get("displayName") is not None
and cached_item["displayName"] == cache_key
):
display_name = cached_item.get("displayName")
if display_name is not None and display_name == cache_key:
return cached_item.get("name")
return None
@@ -154,7 +154,7 @@ class ContextCachingEndpoints:
- None
"""
_, url = self._get_token_and_url(
_, url = self._get_token_and_url_context_caching(
gemini_api_key=api_key,
custom_llm_provider="gemini",
api_base=api_base,
@@ -189,11 +189,12 @@ class ContextCachingEndpoints:
all_cached_items = CachedContentListAllResponseBody(**raw_response)
if "cachedContents" not in all_cached_items:
return None
for cached_item in all_cached_items["cachedContents"]:
if (
cached_item.get("displayName") is not None
and cached_item["displayName"] == cache_key
):
display_name = cached_item.get("displayName")
if display_name is not None and display_name == cache_key:
return cached_item.get("name")
return None
@@ -224,7 +225,7 @@ class ContextCachingEndpoints:
return messages, cached_content
## AUTHORIZATION ##
token, url = self._get_token_and_url(
token, url = self._get_token_and_url_context_caching(
gemini_api_key=api_key,
custom_llm_provider="gemini",
api_base=api_base,
@@ -329,7 +330,7 @@ class ContextCachingEndpoints:
return messages, cached_content
## AUTHORIZATION ##
token, url = self._get_token_and_url(
token, url = self._get_token_and_url_context_caching(
gemini_api_key=api_key,
custom_llm_provider="gemini",
api_base=api_base,

View file

@@ -86,7 +86,7 @@ class VertexBase(BaseLLM):
)
if project_id is None:
project_id = creds.project_id
project_id = getattr(creds, "project_id", None)
else:
creds, creds_project_id = google_auth.default(
quota_project_id=project_id,
@@ -95,7 +95,7 @@ class VertexBase(BaseLLM):
if project_id is None:
project_id = creds_project_id
creds.refresh(Request())
creds.refresh(Request()) # type: ignore
if not project_id:
raise ValueError("Could not resolve project_id")
@@ -169,6 +169,39 @@ class VertexBase(BaseLLM):
return True
return False
def _check_custom_proxy(
self,
api_base: Optional[str],
custom_llm_provider: str,
gemini_api_key: Optional[str],
endpoint: str,
stream: Optional[bool],
auth_header: Optional[str],
url: str,
) -> Tuple[Optional[str], str]:
"""
for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
## Returns
- (auth_header, url) - Tuple[Optional[str], str]
"""
if api_base:
if custom_llm_provider == "gemini":
url = "{}:{}".format(api_base, endpoint)
if gemini_api_key is None:
raise ValueError(
"Missing gemini_api_key, please set `GEMINI_API_KEY`"
)
auth_header = (
gemini_api_key # cloudflare expects api key as bearer token
)
else:
url = "{}:{}".format(api_base, endpoint)
if stream is True:
url = url + "?alt=sse"
return auth_header, url
def _get_token_and_url(
self,
model: str,
@@ -215,25 +248,15 @@
vertex_api_version=version,
)
if (
api_base is not None
): # for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
if custom_llm_provider == "gemini":
url = "{}:{}".format(api_base, endpoint)
if gemini_api_key is None:
raise ValueError(
"Missing gemini_api_key, please set `GEMINI_API_KEY`"
)
auth_header = (
gemini_api_key # cloudflare expects api key as bearer token
)
else:
url = "{}:{}".format(api_base, endpoint)
if stream is True:
url = url + "?alt=sse"
return auth_header, url
return self._check_custom_proxy(
api_base=api_base,
auth_header=auth_header,
custom_llm_provider=custom_llm_provider,
gemini_api_key=gemini_api_key,
endpoint=endpoint,
stream=stream,
url=url,
)
async def _ensure_access_token_async(
self,

View file

@@ -1422,6 +1422,13 @@ class UserAPIKeyAuth(
arbitrary_types_allowed = True
class UserInfoResponse(LiteLLMBase):
user_id: Optional[str]
user_info: Optional[Union[dict, BaseModel]]
keys: List
teams: List
class LiteLLM_Config(LiteLLMBase):
param_name: str
param_value: Dict

View file

@@ -15,6 +15,7 @@ import copy
import json
import re
import secrets
import time
import traceback
import uuid
from datetime import datetime, timedelta, timezone
@@ -274,10 +275,23 @@ async def ui_get_available_role(
return _data_to_return
def get_team_from_list(
team_list: Optional[List[LiteLLM_TeamTable]], team_id: str
) -> Optional[LiteLLM_TeamTable]:
if team_list is None:
return None
for team in team_list:
if team.team_id == team_id:
return team
return None
@router.get(
"/user/info",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
response_model=UserInfoResponse,
)
@management_endpoint_wrapper
async def user_info(
@@ -337,7 +351,7 @@ async def user_info(
## GET ALL TEAMS ##
team_list = []
team_id_list = []
# _DEPRECATED_ check if user in 'member' field
# get all teams user belongs to
teams_1 = await prisma_client.get_data(
user_id=user_id, table_name="team", query_type="find_all"
)
@@ -414,13 +428,13 @@ async def user_info(
if (
key.token == litellm_master_key_hash
and general_settings.get("disable_master_key_return", False)
== True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
is True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
):
continue
try:
key = key.model_dump() # noqa
except:
except Exception:
# if using pydantic v1
key = key.dict()
if (
@@ -428,29 +442,29 @@ async def user_info(
and key["team_id"] is not None
and key["team_id"] != "litellm-dashboard"
):
team_info = await prisma_client.get_data(
team_id=key["team_id"], table_name="team"
team_info = get_team_from_list(
team_list=teams_1, team_id=key["team_id"]
)
team_alias = getattr(team_info, "team_alias", None)
key["team_alias"] = team_alias
if team_info is not None:
team_alias = getattr(team_info, "team_alias", None)
key["team_alias"] = team_alias
else:
key["team_alias"] = None
else:
key["team_alias"] = "None"
returned_keys.append(key)
response_data = {
"user_id": user_id,
"user_info": user_info,
"keys": returned_keys,
"teams": team_list,
}
response_data = UserInfoResponse(
user_id=user_id, user_info=user_info, keys=returned_keys, teams=team_list
)
return response_data
except Exception as e:
verbose_proxy_logger.error(
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.user_info(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", f"Authentication Error({str(e)})"),

View file

@@ -248,7 +248,7 @@ from litellm.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import get_secret, str_to_bool
from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
@@ -2728,9 +2728,21 @@ async def startup_event():
### LOAD CONFIG ###
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
verbose_proxy_logger.debug("worker_config: %s", worker_config)
# check if it's a valid file path
if worker_config is not None:
if env_config_yaml is not None:
if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
config_file_path=env_config_yaml
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=env_config_yaml
)
elif worker_config is not None:
if (
isinstance(worker_config, str)
and os.path.isfile(worker_config)

View file

@@ -31,6 +31,7 @@ from litellm import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
_gemini_convert_messages_with_history,
)
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
from litellm.tests.test_streaming import streaming_format_tests
litellm.num_retries = 3
@@ -2869,3 +2870,24 @@ def test_gemini_finetuned_endpoint(base_model, metadata):
assert mock_client.call_args.kwargs["url"].endswith(
"endpoints/4965075652664360960:generateContent"
)
@pytest.mark.parametrize("api_base", ["", None, "my-custom-proxy-base"])
def test_custom_api_base(api_base):
stream = None
test_endpoint = "my-fake-endpoint"
vertex_base = VertexBase()
auth_header, url = vertex_base._check_custom_proxy(
api_base=api_base,
custom_llm_provider="gemini",
gemini_api_key="12324",
endpoint="",
stream=stream,
auth_header=None,
url="my-fake-endpoint",
)
if api_base:
assert url == api_base + ":"
else:
assert url == test_endpoint

View file

@@ -1933,6 +1933,41 @@ async def test_openai_compatible_custom_api_base(provider):
assert "hello" in mock_call.call_args.kwargs["extra_body"]
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk():
litellm.set_verbose = True
messages = [
{
"role": "user",
"content": "Hello world",
}
]
from openai import OpenAI
openai_client = OpenAI(api_key="fake-key")
with patch.object(
openai_client.chat.completions, "create", new=MagicMock()
) as mock_call:
try:
completion(
model="litellm_proxy/my-vllm-model",
messages=messages,
response_format={"type": "json_object"},
client=openai_client,
api_base="my-custom-api-base",
hello="world",
)
except Exception as e:
print(e)
mock_call.assert_called_once()
print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
assert "hello" in mock_call.call_args.kwargs["extra_body"]
# ################### Hugging Face Conversational models ########################
# def hf_test_completion_conv():
# try:

View file

@@ -256,7 +256,7 @@ def test_generate_and_call_with_valid_key(prisma_client, api_route):
# check /user/info to verify user_role was set correctly
new_user_info = await user_info(user_id=user_id)
new_user_info = new_user_info["user_info"]
new_user_info = new_user_info.user_info
print("new_user_info=", new_user_info)
assert new_user_info.user_role == LitellmUserRoles.INTERNAL_USER
assert new_user_info.user_id == user_id

View file

@@ -2154,7 +2154,7 @@ def test_openai_chat_completion_complete_response_call():
# test_openai_chat_completion_complete_response_call()
@pytest.mark.parametrize(
"model",
["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307"], #
["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307", "o1-preview"], #
)
@pytest.mark.parametrize(
"sync",