Merge branch 'main' into litellm_no_store_cache_control

Krish Dholakia, 2024-01-30 21:44:57 -08:00 (committed by GitHub)
commit ce415a243d
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
7 changed files with 266 additions and 31 deletions


@@ -31,6 +31,18 @@ general_settings:
 ## 2. Setup SSO/Auth for UI
 <Tabs>
+<TabItem value="username" label="Quick Start - Username, Password">
+
+Set the following in your .env on the Proxy
+
+```shell
+UI_USERNAME=ishaan-litellm
+UI_PASSWORD=langchain
+```
+
+On accessing the LiteLLM UI, you will be prompted to enter your username and password
+
+</TabItem>
 <TabItem value="google" label="Google SSO">
@@ -73,6 +85,7 @@ MICROSOFT_TENANT="5a39737
 ```
 </TabItem>
 </Tabs>
+
 ## 4. Use UI
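Read together with the proxy changes below, this gives the proxy a username/password fallback when no SSO provider is configured. As a quick sanity check, here is a sketch of posting the form credentials to the new `/login` endpoint from Python; the proxy base URL is an assumption for illustration, not part of this commit:

```python
# Hypothetical smoke test for the new /login endpoint; the base URL is an
# assumption, and UI_USERNAME/UI_PASSWORD must match the proxy's .env values.
import os

import requests

resp = requests.post(
    "http://localhost:4000/login",  # assumed proxy address
    data={
        "username": os.environ["UI_USERNAME"],
        "password": os.environ["UI_PASSWORD"],
    },
    allow_redirects=False,  # a successful login answers with a redirect
)

# On a credential match the proxy generates a 24-hour key and redirects to the
# dashboard with userID, accessToken, and proxyBaseUrl query parameters.
print(resp.status_code, resp.headers.get("location"))
```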


@@ -76,6 +76,7 @@ from litellm.proxy.utils import (
     get_logging_payload,
     reset_budget,
     hash_token,
+    html_form,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 import pydantic
@@ -94,6 +95,7 @@ from fastapi import (
     BackgroundTasks,
     Header,
     Response,
+    Form,
 )
 from fastapi.routing import APIRouter
 from fastapi.security import OAuth2PasswordBearer
@@ -1268,7 +1270,7 @@ async def generate_key_helper_fn(
     key_alias: Optional[str] = None,
     allowed_cache_controls: Optional[list] = [],
 ):
-    global prisma_client, custom_db_client
+    global prisma_client, custom_db_client, user_api_key_cache
     if prisma_client is None and custom_db_client is None:
         raise Exception(
@@ -1361,6 +1363,18 @@ async def generate_key_helper_fn(
         }
         if general_settings.get("allow_user_auth", False) == True:
             key_data["key_name"] = f"sk-...{token[-4:]}"
+        saved_token = copy.deepcopy(key_data)
+        if isinstance(saved_token["aliases"], str):
+            saved_token["aliases"] = json.loads(saved_token["aliases"])
+        if isinstance(saved_token["config"], str):
+            saved_token["config"] = json.loads(saved_token["config"])
+        if isinstance(saved_token["metadata"], str):
+            saved_token["metadata"] = json.loads(saved_token["metadata"])
+        user_api_key_cache.set_cache(
+            key=key_data["token"],
+            value=LiteLLM_VerificationToken(**saved_token),  # type: ignore
+            ttl=60,
+        )
         if prisma_client is not None:
             ## CREATE USER (If necessary)
             verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
@@ -1675,7 +1689,8 @@ async def startup_event():
     if prisma_client is not None and master_key is not None:
         # add master key to db
-        await generate_key_helper_fn(
+        asyncio.create_task(
+            generate_key_helper_fn(
                 duration=None,
                 models=[],
                 aliases={},
@@ -1684,6 +1699,7 @@ async def startup_event():
                 token=master_key,
                 user_id="default_user_id",
             )
+        )
     if prisma_client is not None and litellm.max_budget > 0:
         if litellm.budget_duration is None:
@@ -1692,7 +1708,8 @@ async def startup_event():
             )
         # add proxy budget to db in the user table
-        await generate_key_helper_fn(
+        asyncio.create_task(
+            generate_key_helper_fn(
                 user_id=litellm_proxy_budget_name,
                 duration=None,
                 models=[],
@@ -1707,6 +1724,7 @@ async def startup_event():
                 "budget_duration": litellm.budget_duration,
             },
         )
+        )
     verbose_proxy_logger.debug(
         f"custom_db_client client {custom_db_client}. Master_key: {master_key}"
@@ -2962,6 +2980,60 @@ async def google_login(request: Request):
         )
         with microsoft_sso:
             return await microsoft_sso.get_login_redirect()
+    else:
+        # No Google, Microsoft SSO
+        # Use UI Credentials set in .env
+        from fastapi.responses import HTMLResponse
+
+        return HTMLResponse(content=html_form, status_code=200)
+
+
+@router.post(
+    "/login", include_in_schema=False
+)  # hidden since this is a helper for UI sso login
+async def login(request: Request):
+    try:
+        import multipart
+    except ImportError:
+        subprocess.run(["pip", "install", "python-multipart"])
+    form = await request.form()
+    username = str(form.get("username"))
+    password = form.get("password")
+    ui_username = os.getenv("UI_USERNAME")
+    ui_password = os.getenv("UI_PASSWORD")
+    if username == ui_username and password == ui_password:
+        user_id = username
+        response = await generate_key_helper_fn(
+            **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id, "team_id": "litellm-dashboard"}  # type: ignore
+        )
+        key = response["token"]  # type: ignore
+        user_id = response["user_id"]  # type: ignore
+        litellm_dashboard_ui = "https://litellm-dashboard.vercel.app/"
+        # if user set LITELLM_UI_LINK in .env, use that
+        litellm_ui_link_in_env = os.getenv("LITELLM_UI_LINK", None)
+        if litellm_ui_link_in_env is not None:
+            litellm_dashboard_ui = litellm_ui_link_in_env
+        litellm_dashboard_ui += (
+            "?userID="
+            + user_id
+            + "&accessToken="
+            + key
+            + "&proxyBaseUrl="
+            + os.getenv("PROXY_BASE_URL")
+        )
+        return RedirectResponse(url=litellm_dashboard_ui)
+    else:
+        raise ProxyException(
+            message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
+            type="auth_error",
+            param="invalid_credentials",
+            code=status.HTTP_401_UNAUTHORIZED,
+        )
+
+
 @app.get("/sso/callback", tags=["experimental"])
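Two details in the proxy changes above are worth connecting. The startup inserts for the master key and proxy budget now run via `asyncio.create_task` instead of being awaited, so startup no longer blocks on the database; to compensate, `generate_key_helper_fn` writes each freshly generated key into `user_api_key_cache` with a 60-second TTL, making the key usable before the database write lands. Below is a self-contained sketch of that write-through-with-TTL pattern; the `TTLCache` class is a simplified stand-in for illustration, not LiteLLM's actual cache:

```python
# Sketch of the write-through-with-TTL pattern added to generate_key_helper_fn:
# the new key is placed in an in-process cache immediately, so lookups succeed
# even before the (now background) database write completes. TTLCache is a
# hypothetical stand-in for LiteLLM's user_api_key_cache, not its real code.
import time
from typing import Any, Optional


class TTLCache:
    def __init__(self) -> None:
        self._store: dict[str, tuple[Any, float]] = {}

    def set_cache(self, key: str, value: Any, ttl: float) -> None:
        # Store the value together with its expiry timestamp.
        self._store[key] = (value, time.monotonic() + ttl)

    def get_cache(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.monotonic() > expires_at:
            # Entry expired; by now the DB write should have landed,
            # so callers fall through to the database.
            del self._store[key]
            return None
        return value


cache = TTLCache()
cache.set_cache(key="sk-...abcd", value={"user_id": "default_user_id"}, ttl=60)
assert cache.get_cache("sk-...abcd") is not None
```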


@@ -1211,3 +1211,67 @@ async def reset_budget(prisma_client: PrismaClient):
         await prisma_client.update_data(
             query_type="update_many", data_list=users_to_reset, table_name="user"
         )
+
+
+# LiteLLM Admin UI - Non SSO Login
+html_form = """
+<!DOCTYPE html>
+<html>
+<head>
+    <title>LiteLLM Login</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            background-color: #f4f4f4;
+            margin: 0;
+            padding: 0;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            height: 100vh;
+        }
+
+        form {
+            background-color: #fff;
+            padding: 20px;
+            border-radius: 8px;
+            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+        }
+
+        label {
+            display: block;
+            margin-bottom: 8px;
+        }
+
+        input {
+            width: 100%;
+            padding: 8px;
+            margin-bottom: 16px;
+            box-sizing: border-box;
+            border: 1px solid #ccc;
+            border-radius: 4px;
+        }
+
+        input[type="submit"] {
+            background-color: #4caf50;
+            color: #fff;
+            cursor: pointer;
+        }
+
+        input[type="submit"]:hover {
+            background-color: #45a049;
+        }
+    </style>
+</head>
+<body>
+    <form action="/login" method="post">
+        <h2>LiteLLM Login</h2>
+        <label for="username">Username:</label>
+        <input type="text" id="username" name="username" required>
+        <label for="password">Password:</label>
+        <input type="password" id="password" name="password" required>
+        <input type="submit" value="Submit">
+    </form>
+</body>
+</html>
+"""


@@ -289,11 +289,7 @@ class Router:
             timeout = kwargs.get("request_timeout", self.timeout)
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
-                # Submit the function to the executor with a timeout
-                future = executor.submit(self.function_with_fallbacks, **kwargs)
-                response = future.result(timeout=timeout)  # type: ignore
+            response = self.function_with_fallbacks(**kwargs)
             return response
         except Exception as e:
             raise e
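For context on what was removed: the old code bounded each routed call by submitting it to a single-worker thread pool and waiting on `future.result(timeout=...)`. A standalone sketch of that pattern shows its main drawback: the timeout only stops the wait, while the submitted work keeps running (and leaving the `with` block still joins the worker thread). Calling `function_with_fallbacks` directly avoids the per-call thread overhead; the commit itself states no rationale, so this reading is inferred.

```python
# Sketch of the executor-based timeout pattern this commit removes.
import concurrent.futures
import time


def slow_call() -> str:
    time.sleep(2)  # stands in for a slow LLM request
    return "done"


with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(slow_call)
    try:
        print(future.result(timeout=0.5))
    except concurrent.futures.TimeoutError:
        # We stopped waiting, but slow_call still runs to completion,
        # and ThreadPoolExecutor.__exit__ blocks until it finishes.
        print("timed out waiting on the future")
```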


@@ -0,0 +1,87 @@
+#### What this tests ####
+# This tests the router's timeout error handling during fallbacks
+
+import sys, os, time
+import traceback, asyncio
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import os
+
+import litellm
+from litellm import Router
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+def test_router_timeouts():
+    # Model list for OpenAI and Anthropic models
+    model_list = [
+        {
+            "model_name": "openai-gpt-4",
+            "litellm_params": {
+                "model": "azure/chatgpt-v-2",
+                "api_key": "os.environ/AZURE_API_KEY",
+                "api_base": "os.environ/AZURE_API_BASE",
+                "api_version": "os.environ/AZURE_API_VERSION",
+            },
+            "tpm": 80000,
+        },
+        {
+            "model_name": "anthropic-claude-instant-1.2",
+            "litellm_params": {
+                "model": "claude-instant-1",
+                "api_key": "os.environ/ANTHROPIC_API_KEY",
+            },
+            "tpm": 20000,
+        },
+    ]
+
+    fallbacks_list = [
+        {"openai-gpt-4": ["anthropic-claude-instant-1.2"]},
+    ]
+
+    # Configure router
+    router = Router(
+        model_list=model_list,
+        fallbacks=fallbacks_list,
+        routing_strategy="usage-based-routing",
+        debug_level="INFO",
+        set_verbose=True,
+        redis_host=os.getenv("REDIS_HOST"),
+        redis_password=os.getenv("REDIS_PASSWORD"),
+        redis_port=int(os.getenv("REDIS_PORT")),
+        timeout=10,
+    )
+
+    print("***** TPM SETTINGS *****")
+    for model_object in model_list:
+        print(f"{model_object['model_name']}: {model_object['tpm']} TPM")
+
+    # Sample list of questions
+    questions_list = [
+        {"content": "Tell me a very long joke.", "modality": "voice"},
+    ]
+
+    total_tokens_used = 0
+
+    # Process each question
+    for question in questions_list:
+        messages = [{"content": question["content"], "role": "user"}]
+        prompt_tokens = litellm.token_counter(text=question["content"], model="gpt-4")
+        print("prompt_tokens = ", prompt_tokens)
+
+        response = router.completion(
+            model="openai-gpt-4", messages=messages, timeout=5, num_retries=0
+        )
+
+        total_tokens_used += response.usage.total_tokens
+
+        print("Response:", response)
+        print("********** TOKENS USED SO FAR = ", total_tokens_used)


@@ -7490,7 +7490,10 @@ class CustomStreamWrapper:
             logprobs = None
             original_chunk = None  # this is used for function/tool calling
             if len(str_line.choices) > 0:
-                if str_line.choices[0].delta.content is not None:
+                if (
+                    str_line.choices[0].delta is not None
+                    and str_line.choices[0].delta.content is not None
+                ):
                     text = str_line.choices[0].delta.content
                 else:  # function/tool calling chunk - when content is None. in this case we just return the original chunk from openai
                     original_chunk = str_line
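The widened guard reflects that streamed chunks are not uniform: function/tool-calling chunks arrive with `content=None`, and a chunk's `delta` itself may be missing. Below is a minimal sketch of the same check using simplified stand-in types; these dataclasses are illustrative, not litellm's real response models:

```python
# Minimal sketch of the guard added above: extract text only when both the
# delta and its content are present. Delta/Choice are simplified stand-ins.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Delta:
    content: Optional[str] = None


@dataclass
class Choice:
    delta: Optional[Delta] = None


def extract_text(choices: list[Choice]) -> Optional[str]:
    if len(choices) > 0:
        if choices[0].delta is not None and choices[0].delta.content is not None:
            return choices[0].delta.content
    return None  # function/tool-calling chunk, or no delta at all


assert extract_text([Choice(Delta("hi"))]) == "hi"
assert extract_text([Choice(None)]) is None
```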


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.20.6"
+version = "1.20.7"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 [tool.commitizen]
-version = "1.20.6"
+version = "1.20.7"
 version_files = [
     "pyproject.toml:^version"
 ]