Merge branch 'main' into litellm_imp_mem_use

Ishaan Jaff 2024-03-11 19:00:56 -07:00 committed by GitHub
commit 89ef2023e9
10 changed files with 296 additions and 65 deletions

View file

@@ -12,7 +12,6 @@ from typing import Any, Literal, Union, BinaryIO
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm
 from ._logging import verbose_logger

View file

@@ -2103,12 +2103,14 @@ async def generate_key_helper_fn(
     return key_data


-async def delete_verification_token(tokens: List):
+async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
     global prisma_client
     try:
         if prisma_client:
             # Assuming 'db' is your Prisma Client instance
-            deleted_tokens = await prisma_client.delete_data(tokens=tokens)
+            deleted_tokens = await prisma_client.delete_data(
+                tokens=tokens, user_id=user_id
+            )
         else:
             raise Exception
     except Exception as e:
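
With the new signature, key deletion can be scoped to the key's owner. A minimal usage sketch (token and user values are placeholders, not from this commit):

    # non-admin path: only deletes the key if it belongs to user-1234
    await delete_verification_token(tokens=["sk-example-token"], user_id="user-1234")
    # admin path: user_id=None deletes the keys regardless of owner
    await delete_verification_token(tokens=["sk-example-token"], user_id=None)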
@@ -3744,7 +3746,10 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
 @router.post(
     "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
 )
-async def delete_key_fn(data: KeyRequest):
+async def delete_key_fn(
+    data: KeyRequest,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
     """
     Delete a key from the key management system.
@@ -3769,11 +3774,33 @@ async def delete_key_fn(data: KeyRequest):
                 code=status.HTTP_400_BAD_REQUEST,
             )

-        result = await delete_verification_token(tokens=keys)
-        verbose_proxy_logger.debug("/key/delete - deleted_keys=", result)
+        ## only allow user to delete keys they own
+        user_id = user_api_key_dict.user_id
+        verbose_proxy_logger.debug(
+            f"user_api_key_dict.user_role: {user_api_key_dict.user_role}"
+        )
+        if (
+            user_api_key_dict.user_role is not None
+            and user_api_key_dict.user_role == "proxy_admin"
+        ):
+            user_id = None  # unless they're admin

-        number_deleted_keys = len(result["deleted_keys"])
-        assert len(keys) == number_deleted_keys
+        number_deleted_keys = await delete_verification_token(
+            tokens=keys, user_id=user_id
+        )
+        verbose_proxy_logger.debug(
+            f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}"
+        )
+
+        try:
+            assert len(keys) == number_deleted_keys["deleted_keys"]
+        except Exception as e:
+            raise HTTPException(
+                status_code=400,
+                detail={
+                    "error": "Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in."
+                },
+            )

         for key in keys:
             user_api_key_cache.delete_cache(key)
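
In effect, a non-admin caller can now only delete keys tied to its own user_id; if any requested key is not deleted, the endpoint raises a 400. A hedged sketch of exercising the endpoint (the proxy URL and key values are hypothetical placeholders):

    import requests

    # assumes a locally running proxy; sk-caller-key / sk-target-key are placeholders
    resp = requests.post(
        "http://0.0.0.0:4000/key/delete",
        headers={"Authorization": "Bearer sk-caller-key"},
        json={"keys": ["sk-target-key"]},
    )
    # expect 200 if sk-caller-key owns sk-target-key (or the caller is proxy_admin);
    # expect 400 "Not all keys passed in were deleted..." otherwise
    print(resp.status_code, resp.text)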
@@ -6529,8 +6556,6 @@ async def login(request: Request):
             algorithm="HS256",
         )
         litellm_dashboard_ui += "?userID=" + user_id + "&token=" + jwt_token
-        # if a user has logged in they should be allowed to create keys - this ensures that it's set to True
-        general_settings["allow_user_auth"] = True
         return RedirectResponse(url=litellm_dashboard_ui, status_code=303)
     else:
         raise ProxyException(

View file

@@ -1356,9 +1356,12 @@ class PrismaClient:
         tokens: Optional[List] = None,
         team_id_list: Optional[List] = None,
         table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
+        user_id: Optional[str] = None,
     ):
         """
         Allow user to delete a key(s)
+
+        Ensure user owns that key, unless admin.
         """
         try:
             if tokens is not None and isinstance(tokens, List):
@@ -1369,15 +1372,25 @@ class PrismaClient:
                     else:
                         hashed_token = token
                     hashed_tokens.append(hashed_token)
-                await self.db.litellm_verificationtoken.delete_many(
-                    where={"token": {"in": hashed_tokens}}
+                filter_query: dict = {}
+                if user_id is not None:
+                    filter_query = {
+                        "AND": [{"token": {"in": hashed_tokens}}, {"user_id": user_id}]
+                    }
+                else:
+                    filter_query = {"token": {"in": hashed_tokens}}
+
+                deleted_tokens = await self.db.litellm_verificationtoken.delete_many(
+                    where=filter_query  # type: ignore
                 )
-                return {"deleted_keys": tokens}
+                verbose_proxy_logger.debug(f"deleted_tokens: {deleted_tokens}")
+                return {"deleted_keys": deleted_tokens}
             elif (
                 table_name == "team"
                 and team_id_list is not None
                 and isinstance(team_id_list, List)
             ):
+                # admin only endpoint -> `/team/delete`
                 await self.db.litellm_teamtable.delete_many(
                     where={"team_id": {"in": team_id_list}}
                 )
@@ -1387,6 +1400,7 @@ class PrismaClient:
                 and team_id_list is not None
                 and isinstance(team_id_list, List)
             ):
+                # admin only endpoint -> `/team/delete`
                 await self.db.litellm_verificationtoken.delete_many(
                     where={"team_id": {"in": team_id_list}}
                 )
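
The ownership check is expressed as a Prisma `AND` filter, so the same delete_many call serves both paths. A minimal sketch of the filter construction (the helper name is illustrative, not part of this commit):

    from typing import Optional

    def build_token_filter(hashed_tokens: list, user_id: Optional[str]) -> dict:
        # admin path (user_id=None): match on token alone
        if user_id is None:
            return {"token": {"in": hashed_tokens}}
        # non-admin path: the token must ALSO belong to the caller
        return {"AND": [{"token": {"in": hashed_tokens}}, {"user_id": user_id}]}

    print(build_token_filter(["hash1", "hash2"], None))         # {'token': {'in': [...]}}
    print(build_token_filter(["hash1", "hash2"], "user-1234"))  # {'AND': [...]}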

View file

@@ -967,44 +967,81 @@ class Router:
         is_async: Optional[bool] = False,
         **kwargs,
     ) -> Union[List[float], None]:
-        # pick the one that is available (lowest TPM/RPM)
-        deployment = self.get_available_deployment(
-            model=model,
-            input=input,
-            specific_deployment=kwargs.pop("specific_deployment", None),
-        )
-        kwargs.setdefault("model_info", {})
-        kwargs.setdefault("metadata", {}).update(
-            {"model_group": model, "deployment": deployment["litellm_params"]["model"]}
-        )  # [TODO]: move to using async_function_with_fallbacks
-        data = deployment["litellm_params"].copy()
-        for k, v in self.default_litellm_params.items():
-            if (
-                k not in kwargs
-            ):  # prioritize model-specific params > default router params
-                kwargs[k] = v
-            elif k == "metadata":
-                kwargs[k].update(v)
-        potential_model_client = self._get_client(deployment=deployment, kwargs=kwargs)
-        # check if provided keys == client keys #
-        dynamic_api_key = kwargs.get("api_key", None)
-        if (
-            dynamic_api_key is not None
-            and potential_model_client is not None
-            and dynamic_api_key != potential_model_client.api_key
-        ):
-            model_client = None
-        else:
-            model_client = potential_model_client
-        return litellm.embedding(
-            **{
-                **data,
-                "input": input,
-                "caching": self.cache_responses,
-                "client": model_client,
-                **kwargs,
-            }
-        )
+        try:
+            kwargs["model"] = model
+            kwargs["input"] = input
+            kwargs["original_function"] = self._embedding
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = self.function_with_fallbacks(**kwargs)
+            return response
+        except Exception as e:
+            raise e
+
+    def _embedding(self, input: Union[str, List], model: str, **kwargs):
+        try:
+            verbose_router_logger.debug(
+                f"Inside embedding()- model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                input=input,
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            kwargs.setdefault("metadata", {}).update(
+                {
+                    "deployment": deployment["litellm_params"]["model"],
+                    "model_info": deployment.get("model_info", {}),
+                }
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if (
+                    k not in kwargs
+                ):  # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+            potential_model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="sync"
+            )
+            # check if provided keys == client keys #
+            dynamic_api_key = kwargs.get("api_key", None)
+            if (
+                dynamic_api_key is not None
+                and potential_model_client is not None
+                and dynamic_api_key != potential_model_client.api_key
+            ):
+                model_client = None
+            else:
+                model_client = potential_model_client
+
+            self.total_calls[model_name] += 1
+            response = litellm.embedding(
+                **{
+                    **data,
+                    "input": input,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    **kwargs,
+                }
+            )
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"litellm.embedding(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"litellm.embedding(model={model_name})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model_name is not None:
+                self.fail_calls[model_name] += 1
+            raise e

     async def aembedding(
         self,

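Since embedding() now routes through function_with_fallbacks(), synchronous embedding calls pick up the router's retry and fallback behavior. A hedged usage sketch (deployment names and keys are placeholders):

    from litellm import Router

    # two hypothetical deployments; the first has a bad key, so calls fall back
    router = Router(
        model_list=[
            {
                "model_name": "embed-primary",
                "litellm_params": {"model": "text-embedding-ada-002", "api_key": "bad-key"},
            },
            {
                "model_name": "embed-backup",
                "litellm_params": {"model": "text-embedding-ada-002", "api_key": "sk-real-key"},
            },
        ],
        fallbacks=[{"embed-primary": ["embed-backup"]}],
    )

    # embedding() wraps _embedding via function_with_fallbacks, so this call
    # retries on embed-backup after embed-primary fails auth
    response = router.embedding(model="embed-primary", input=["hello world"])
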
View file

@@ -722,6 +722,7 @@ def test_delete_key(prisma_client):
     setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
     setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "user_custom_auth", None)
     try:

         async def test():
@@ -737,8 +738,19 @@ def test_delete_key(prisma_client):
             delete_key_request = KeyRequest(keys=[generated_key])

+            bearer_token = "Bearer sk-1234"
+
+            request = Request(scope={"type": "http"})
+            request._url = URL(url="/key/delete")
+
+            # use generated key to auth in
+            result = await user_api_key_auth(request=request, api_key=bearer_token)
+            print(f"result: {result}")
+            result.user_role = "proxy_admin"
             # delete the key
-            result_delete_key = await delete_key_fn(data=delete_key_request)
+            result_delete_key = await delete_key_fn(
+                data=delete_key_request, user_api_key_dict=result
+            )
             print("result from delete key", result_delete_key)
             assert result_delete_key == {"deleted_keys": [generated_key]}
@@ -776,7 +788,19 @@ def test_delete_key_auth(prisma_client):
             delete_key_request = KeyRequest(keys=[generated_key])

             # delete the key
-            result_delete_key = await delete_key_fn(data=delete_key_request)
+            bearer_token = "Bearer sk-1234"
+
+            request = Request(scope={"type": "http"})
+            request._url = URL(url="/key/delete")
+
+            # use generated key to auth in
+            result = await user_api_key_auth(request=request, api_key=bearer_token)
+            print(f"result: {result}")
+            result.user_role = "proxy_admin"
+
+            result_delete_key = await delete_key_fn(
+                data=delete_key_request, user_api_key_dict=result
+            )
             print("result from delete key", result_delete_key)
             assert result_delete_key == {"deleted_keys": [generated_key]}
@@ -791,6 +815,7 @@ def test_delete_key_auth(prisma_client):
             )

             # use generated key to auth in
+            bearer_token = "Bearer " + generated_key
             result = await user_api_key_auth(request=request, api_key=bearer_token)
             print("got result", result)
             pytest.fail(f"This should have failed!. IT's an invalid key")
@@ -835,9 +860,19 @@ def test_generate_and_call_key_info(prisma_client):
             # cleanup - delete key
             delete_key_request = KeyRequest(keys=[generated_key])
+            bearer_token = "Bearer sk-1234"

-            # delete the key
-            await delete_key_fn(data=delete_key_request)
+            request = Request(scope={"type": "http"})
+            request._url = URL(url="/key/delete")
+
+            # use generated key to auth in
+            result = await user_api_key_auth(request=request, api_key=bearer_token)
+            print(f"result: {result}")
+            result.user_role = "proxy_admin"
+
+            result_delete_key = await delete_key_fn(
+                data=delete_key_request, user_api_key_dict=result
+            )

         asyncio.run(test())
     except Exception as e:
@@ -916,7 +951,19 @@ def test_generate_and_update_key(prisma_client):
             delete_key_request = KeyRequest(keys=[generated_key])

             # delete the key
-            await delete_key_fn(data=delete_key_request)
+            bearer_token = "Bearer sk-1234"
+
+            request = Request(scope={"type": "http"})
+            request._url = URL(url="/key/delete")
+
+            # use generated key to auth in
+            result = await user_api_key_auth(request=request, api_key=bearer_token)
+            print(f"result: {result}")
+            result.user_role = "proxy_admin"
+
+            result_delete_key = await delete_key_fn(
+                data=delete_key_request, user_api_key_dict=result
+            )

         asyncio.run(test())
     except Exception as e:
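
The same admin-auth-then-delete sequence now repeats across these tests; a hypothetical helper capturing the pattern could look like this (the helper name is illustrative, not part of the test suite):

    async def delete_key_as_admin(generated_key: str):
        # authenticate with the master key, then elevate to proxy_admin
        request = Request(scope={"type": "http"})
        request._url = URL(url="/key/delete")
        result = await user_api_key_auth(request=request, api_key="Bearer sk-1234")
        result.user_role = "proxy_admin"
        return await delete_key_fn(
            data=KeyRequest(keys=[generated_key]), user_api_key_dict=result
        )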

View file

@@ -85,7 +85,7 @@
 # async def main():
 #     for i in range(1):
 #         start = time.time()
-#         n = 20  # Number of concurrent tasks
+#         n = 15  # Number of concurrent tasks
 #         tasks = [router_acompletion() for _ in range(n)]
 #         chat_completions = await asyncio.gather(*tasks)

View file

@@ -227,6 +227,57 @@ async def test_async_fallbacks():
 # test_async_fallbacks()


+def test_sync_fallbacks_embeddings():
+    litellm.set_verbose = False
+    model_list = [
+        {  # list of model deployments
+            "model_name": "bad-azure-embedding-model",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/azure-embedding-model",
+                "api_key": "bad-key",
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+        {  # list of model deployments
+            "model_name": "good-azure-embedding-model",  # openai model name
+            "litellm_params": {  # params for litellm completion/embedding call
+                "model": "azure/azure-embedding-model",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+            "tpm": 240000,
+            "rpm": 1800,
+        },
+    ]
+
+    router = Router(
+        model_list=model_list,
+        fallbacks=[{"bad-azure-embedding-model": ["good-azure-embedding-model"]}],
+        set_verbose=False,
+    )
+    customHandler = MyCustomHandler()
+    litellm.callbacks = [customHandler]
+    user_message = "Hello, how are you?"
+    input = [user_message]
+    try:
+        kwargs = {"model": "bad-azure-embedding-model", "input": input}
+        response = router.embedding(**kwargs)
+        print(f"customHandler.previous_models: {customHandler.previous_models}")
+        time.sleep(0.05)  # allow a delay as success_callbacks are on a separate thread
+        assert customHandler.previous_models == 1  # 0 retries, 1 fallback
+        router.reset()
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {e}")
+    finally:
+        router.reset()
+
+
 @pytest.mark.asyncio
 async def test_async_fallbacks_embeddings():
     litellm.set_verbose = False

View file

@@ -0,0 +1,56 @@
+# [LOCAL TEST] - runs against mock openai proxy
+# # What this tests?
+# ## This tests if fallbacks works for 429 errors
+
+# import sys, os, time
+# import traceback, asyncio
+# import pytest
+
+# sys.path.insert(
+#     0, os.path.abspath("../..")
+# )  # Adds the parent directory to the system path
+# import litellm
+# from litellm import Router
+
+# model_list = [
+#     {  # list of model deployments
+#         "model_name": "text-embedding-ada-002",  # model alias
+#         "litellm_params": {  # params for litellm completion/embedding call
+#             "model": "text-embedding-ada-002",  # actual model name
+#             "api_key": "sk-fakekey",
+#             "api_base": "http://0.0.0.0:8080",
+#         },
+#         "tpm": 1000,
+#         "rpm": 6,
+#     },
+#     {
+#         "model_name": "text-embedding-ada-002-fallback",
+#         "litellm_params": {  # params for litellm completion/embedding call
+#             "model": "openai/text-embedding-ada-002-anything-else",  # actual model name
+#             "api_key": "sk-fakekey2",
+#             "api_base": "http://0.0.0.0:8080",
+#         },
+#         "tpm": 1000,
+#         "rpm": 6,
+#     },
+# ]
+
+# router = Router(
+#     model_list=model_list,
+#     fallbacks=[
+#         {"text-embedding-ada-002": ["text-embedding-ada-002-fallback"]},
+#         {"text-embedding-ada-002-fallback": ["text-embedding-ada-002"]},
+#     ],
+#     set_verbose=True,
+#     num_retries=0,
+#     debug_level="INFO",
+#     routing_strategy="usage-based-routing",
+# )
+
+
+# def test_embedding_with_fallbacks():
+#     response = router.embedding(model="text-embedding-ada-002", input=["Hello world"])
+#     print(f"response: {response}")
+
+
+# test_embedding_with_fallbacks()

View file

@@ -1,18 +1,16 @@
 model_list:
-  # NOTE: This is the default config users use with Dockerfile.
-  # DO not expect users to pass os.environ/<> vars here, this will lead to proxy startup failing for them if they don't have the expected env vars
   - model_name: gpt-3.5-turbo
     litellm_params:
       model: azure/chatgpt-v-2
       api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
       api_version: "2023-05-15"
-      api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
   - model_name: gpt-4
     litellm_params:
       model: azure/chatgpt-v-2
       api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
       api_version: "2023-05-15"
-      api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
   - model_name: sagemaker-completion-model
     litellm_params:
       model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
@@ -20,7 +18,7 @@ model_list:
   - model_name: text-embedding-ada-002
     litellm_params:
       model: azure/azure-embedding-model
-      api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+      api_key: os.environ/AZURE_API_KEY
       api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
       api_version: "2023-05-15"
     model_info:
@@ -28,10 +26,13 @@ model_list:
       base_model: text-embedding-ada-002
   - model_name: dall-e-2
     litellm_params:
-      model: azure/dall-e-2
+      model: azure/
       api_version: 2023-06-01-preview
       api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_key: sk-defaultKey # use `os.environ/AZURE_API_KEY` for production. The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+      api_key: os.environ/AZURE_API_KEY
+  - model_name: openai-dall-e-3
+    litellm_params:
+      model: dall-e-3

 litellm_settings:
   drop_params: True
@@ -39,7 +40,7 @@ litellm_settings:
   budget_duration: 30d
   num_retries: 5
   request_timeout: 600

 general_settings:
   master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
   proxy_budget_rescheduler_min_time: 60
   proxy_budget_rescheduler_max_time: 64
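
For reference, the `os.environ/` prefix means the value is looked up in the process environment at load time rather than stored in the file. A minimal sketch of that resolution, assuming this simplified behavior (not necessarily litellm's actual loader):

    import os

    def resolve_config_value(value):
        # values prefixed with "os.environ/" are read from the environment
        if isinstance(value, str) and value.startswith("os.environ/"):
            env_var = value.split("/", 1)[1]  # e.g. "AZURE_API_KEY"
            return os.environ[env_var]
        return value

    # export AZURE_API_KEY=... before starting the proxy, then:
    # resolve_config_value("os.environ/AZURE_API_KEY") -> the real key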

View file

@@ -22,7 +22,6 @@ click = "*"
 jinja2 = "^3.1.2"
 aiohttp = "*"
 requests = "^2.31.0"
-argon2-cffi = "^23.1.0"

 uvicorn = {version = "^0.22.0", optional = true}
 gunicorn = {version = "^21.2.0", optional = true}
@@ -36,6 +35,7 @@ streamlit = {version = "^1.29.0", optional = true}
 fastapi-sso = { version = "^0.10.0", optional = true }
 PyJWT = { version = "^2.8.0", optional = true }
 python-multipart = { version = "^0.0.6", optional = true }
+argon2-cffi = { version = "^23.1.0", optional = true }

 [tool.poetry.extras]
 proxy = [
@@ -50,6 +50,7 @@ proxy = [
     "fastapi-sso",
     "PyJWT",
     "python-multipart",
+    "argon2-cffi",
 ]

 extra_proxy = [
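
Net effect of the argon2-cffi move: it is now optional and part of the `proxy` extra, so a plain `pip install litellm` no longer pulls it in, while `pip install "litellm[proxy]"` does.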