Merge branch 'main' into litellm_banned_keywords_list

Krish Dholakia, 2024-02-22 22:20:59 -08:00 (committed by GitHub)
commit b6a05cb787
13 changed files with 259 additions and 60 deletions


@@ -152,7 +152,14 @@ GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not e
 </Tabs>
-#### Step 3. Test flow
+#### Step 3. Set `PROXY_BASE_URL` in your .env
+
+Set this in your .env (so the proxy can set the correct redirect url)
+
+```shell
+PROXY_BASE_URL=https://litellm-api.up.railway.app/
+```
+
+#### Step 4. Test flow
 <Image img={require('../../img/litellm_ui_3.gif')} />
 ### Set Admin view w/ SSO


@@ -279,9 +279,9 @@ curl 'http://0.0.0.0:8000/key/generate' \
 ## Set Rate Limits
 You can set:
+- tpm limits (tokens per minute)
+- rpm limits (requests per minute)
 - max parallel requests
-- tpm limits
-- rpm limits
 <Tabs>
 <TabItem value="per-user" label="Per User">
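For context, a minimal Python sketch of generating a key with these limits against a locally running proxy. It assumes `/key/generate` accepts `tpm_limit`, `rpm_limit`, and `max_parallel_requests` as top-level JSON fields (the field names mirror the limiter code later in this diff); the URL and master key are placeholders.

```python
import requests

# Hypothetical values: proxy URL and master key are placeholders.
resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={"Authorization": "Bearer sk-1234", "Content-Type": "application/json"},
    json={
        "tpm_limit": 1000,            # tokens per minute
        "rpm_limit": 60,              # requests per minute
        "max_parallel_requests": 10,  # concurrent in-flight requests
    },
)
print(resp.json())  # contains the generated key on success
```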


@@ -559,8 +559,7 @@ def completion(
 f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
 )
 response = llm_model.predict(
-endpoint=endpoint_path,
-instances=instances
+endpoint=endpoint_path, instances=instances
 ).predictions
 completion_response = response[0]
@@ -585,12 +584,8 @@ def completion(
 "request_str": request_str,
 },
 )
-request_str += (
-f"llm_model.predict(instances={instances})\n"
-)
-response = llm_model.predict(
-instances=instances
-).predictions
+request_str += f"llm_model.predict(instances={instances})\n"
+response = llm_model.predict(instances=instances).predictions
 completion_response = response[0]
 if (
@@ -614,7 +609,6 @@ def completion(
 model_response["choices"][0]["message"]["content"] = str(
 completion_response
 )
-model_response["choices"][0]["message"]["content"] = str(completion_response)
 model_response["created"] = int(time.time())
 model_response["model"] = model
 ## CALCULATING USAGE
@@ -766,6 +760,7 @@ async def async_completion(
 Vertex AI Model Garden
 """
 from google.cloud import aiplatform
+
 ## LOGGING
 logging_obj.pre_call(
 input=prompt,
@@ -797,11 +792,9 @@ async def async_completion(
 and "\nOutput:\n" in completion_response
 ):
 completion_response = completion_response.split("\nOutput:\n", 1)[1]
 elif mode == "private":
-request_str += (
-f"llm_model.predict_async(instances={instances})\n"
-)
+request_str += f"llm_model.predict_async(instances={instances})\n"
 response_obj = await llm_model.predict_async(
 instances=instances,
 )
@@ -826,7 +819,6 @@ async def async_completion(
 model_response["choices"][0]["message"]["content"] = str(
 completion_response
 )
-model_response["choices"][0]["message"]["content"] = str(completion_response)
 model_response["created"] = int(time.time())
 model_response["model"] = model
 ## CALCULATING USAGE
@@ -954,6 +946,7 @@ async def async_streaming(
 response = llm_model.predict_streaming_async(prompt, **optional_params)
 elif mode == "custom":
 from google.cloud import aiplatform
+
 stream = optional_params.pop("stream", None)
 ## LOGGING
@@ -972,7 +965,9 @@ async def async_streaming(
 endpoint_path = llm_model.endpoint_path(
 project=vertex_project, location=vertex_location, endpoint=model
 )
-request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+request_str += (
+f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+)
 response_obj = await llm_model.predict(
 endpoint=endpoint_path,
 instances=instances,
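The hunks above are mostly formatter reflows of the `custom`/`private` Model Garden paths. For orientation, a hedged sketch of how a caller might reach one of these endpoints through litellm; the `vertex_ai/<endpoint-id>` model string is an assumption, while `vertex_project`/`vertex_location` are the parameter names visible in the diff.

```python
import litellm

# Sketch only: endpoint id, project, and location are placeholders.
response = litellm.completion(
    model="vertex_ai/1234567890",      # assumed format: vertex_ai/<endpoint-id>
    messages=[{"role": "user", "content": "hello"}],
    vertex_project="my-gcp-project",   # parameter names taken from the diff
    vertex_location="us-central1",
)
print(response.choices[0].message.content)
```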


@@ -12,7 +12,6 @@ from typing import Any, Literal, Union
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
-
 import httpx
 import litellm
 from ._logging import verbose_logger


@@ -424,6 +424,10 @@ class LiteLLM_VerificationToken(LiteLLMBase):
 model_spend: Dict = {}
 model_max_budget: Dict = {}
+
+# hidden params used for parallel request limiting, not required to create a token
+user_id_rate_limits: Optional[dict] = None
+team_id_rate_limits: Optional[dict] = None
 class Config:
 protected_namespaces = ()
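These fields are what the parallel request limiter reads. The new test later in this diff populates them like so (a trimmed sketch, assuming `UserAPIKeyAuth` is importable from `litellm.proxy._types`):

```python
from litellm.proxy._types import UserAPIKeyAuth

# Per-user limits ride along on the auth object as a plain dict; the limiter
# falls back to sys.maxsize for any key that is missing.
user_api_key_dict = UserAPIKeyAuth(
    api_key="sk-12345",
    user_id="ishaan",
    user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
)
print(dict(user_api_key_dict))
```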


@@ -24,46 +24,21 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 except:
 pass
-async def async_pre_call_hook(
+async def check_key_in_limits(
 self,
 user_api_key_dict: UserAPIKeyAuth,
 cache: DualCache,
 data: dict,
 call_type: str,
+max_parallel_requests: int,
+tpm_limit: int,
+rpm_limit: int,
+request_count_api_key: str,
 ):
-self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
-api_key = user_api_key_dict.api_key
-max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
-tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
-rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
-if api_key is None:
-return
-if (
-max_parallel_requests == sys.maxsize
-and tpm_limit == sys.maxsize
-and rpm_limit == sys.maxsize
-):
-return
-self.user_api_key_cache = cache # save the api key cache for updating the value
-# ------------
-# Setup values
-# ------------
-current_date = datetime.now().strftime("%Y-%m-%d")
-current_hour = datetime.now().strftime("%H")
-current_minute = datetime.now().strftime("%M")
-precise_minute = f"{current_date}-{current_hour}-{current_minute}"
-request_count_api_key = f"{api_key}::{precise_minute}::request_count"
-# CHECK IF REQUEST ALLOWED
 current = cache.get_cache(
 key=request_count_api_key
 ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
-self.print_verbose(f"current: {current}")
+# print(f"current: {current}")
 if current is None:
 new_val = {
 "current_requests": 1,
@@ -88,10 +63,107 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 status_code=429, detail="Max parallel request limit reached."
 )
+async def async_pre_call_hook(
+self,
+user_api_key_dict: UserAPIKeyAuth,
+cache: DualCache,
+data: dict,
+call_type: str,
+):
+self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
+api_key = user_api_key_dict.api_key
+max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
+tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
+rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
+if api_key is None:
+return
+self.user_api_key_cache = cache # save the api key cache for updating the value
+# ------------
+# Setup values
+# ------------
+current_date = datetime.now().strftime("%Y-%m-%d")
+current_hour = datetime.now().strftime("%H")
+current_minute = datetime.now().strftime("%M")
+precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+# CHECK IF REQUEST ALLOWED for key
+current = cache.get_cache(
+key=request_count_api_key
+) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
+self.print_verbose(f"current: {current}")
+if (
+max_parallel_requests == sys.maxsize
+and tpm_limit == sys.maxsize
+and rpm_limit == sys.maxsize
+):
+pass
+elif current is None:
+new_val = {
+"current_requests": 1,
+"current_tpm": 0,
+"current_rpm": 0,
+}
+cache.set_cache(request_count_api_key, new_val)
+elif (
+int(current["current_requests"]) < max_parallel_requests
+and current["current_tpm"] < tpm_limit
+and current["current_rpm"] < rpm_limit
+):
+# Increase count for this token
+new_val = {
+"current_requests": current["current_requests"] + 1,
+"current_tpm": current["current_tpm"],
+"current_rpm": current["current_rpm"],
+}
+cache.set_cache(request_count_api_key, new_val)
+else:
+raise HTTPException(
+status_code=429, detail="Max parallel request limit reached."
+)
+# check if REQUEST ALLOWED for user_id
+user_id = user_api_key_dict.user_id
+_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
+# get user tpm/rpm limits
+if _user_id_rate_limits is None or _user_id_rate_limits == {}:
+return
+user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
+user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
+if user_tpm_limit is None:
+user_tpm_limit = sys.maxsize
+if user_rpm_limit is None:
+user_rpm_limit = sys.maxsize
+# now do the same tpm/rpm checks
+request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+await self.check_key_in_limits(
+user_api_key_dict=user_api_key_dict,
+cache=cache,
+data=data,
+call_type=call_type,
+max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
+request_count_api_key=request_count_api_key,
+tpm_limit=user_tpm_limit,
+rpm_limit=user_rpm_limit,
+)
+return
 async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
 try:
 self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
 user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
+"user_api_key_user_id", None
+)
 if user_api_key is None:
 return
@@ -121,7 +193,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 }
 # ------------
-# Update usage
+# Update usage - API Key
 # ------------
 new_val = {
@@ -136,6 +208,41 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
 self.user_api_key_cache.set_cache(
 request_count_api_key, new_val, ttl=60
 ) # store in cache for 1 min.
+# ------------
+# Update usage - User
+# ------------
+if user_api_key_user_id is None:
+return
+total_tokens = 0
+if isinstance(response_obj, ModelResponse):
+total_tokens = response_obj.usage.total_tokens
+request_count_api_key = (
+f"{user_api_key_user_id}::{precise_minute}::request_count"
+)
+current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
+"current_requests": 1,
+"current_tpm": total_tokens,
+"current_rpm": 1,
+}
+new_val = {
+"current_requests": max(current["current_requests"] - 1, 0),
+"current_tpm": current["current_tpm"] + total_tokens,
+"current_rpm": current["current_rpm"] + 1,
+}
+self.print_verbose(
+f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+)
+self.user_api_key_cache.set_cache(
+request_count_api_key, new_val, ttl=60
+) # store in cache for 1 min.
 except Exception as e:
 self.print_verbose(e) # noqa
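The limiter's state is a per-minute counter dict stored under a `{identifier}::{YYYY-MM-DD-HH-MM}::request_count` cache key, with the same shape used for API keys and for user ids. A standalone sketch of that key scheme (the function name is illustrative):

```python
from datetime import datetime

def request_count_cache_key(identifier: str) -> str:
    # Same layout as the limiter above: date, hour, and minute joined by dashes.
    precise_minute = datetime.now().strftime("%Y-%m-%d-%H-%M")
    return f"{identifier}::{precise_minute}::request_count"

# e.g. "sk-12345::2024-02-22-22-20::request_count"
print(request_count_cache_key("sk-12345"))

# The cached value is a plain dict, e.g.:
# {"current_requests": 1, "current_tpm": 0, "current_rpm": 0}
```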


@@ -4388,7 +4388,20 @@ async def update_team(
 user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
 """
-add new members to the team
+You can now add / delete users from a team via /team/update
+
+```
+curl --location 'http://0.0.0.0:8000/team/update' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data-raw '{
+"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849",
+"members_with_roles": [{"role": "admin", "user_id": "5c4a0aa3-a1e1-43dc-bd87-3c2da8382a3a"}, {"role": "user", "user_id": "krrish247652@berri.ai"}]
+}'
+```
 """
 global prisma_client
@@ -4469,6 +4482,18 @@ async def delete_team(
 ):
 """
 delete team and associated team keys
+
+```
+curl --location 'http://0.0.0.0:8000/team/delete' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data-raw '{
+"team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"]
+}'
+```
 """
 global prisma_client
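A Python equivalent of the curl examples in the new docstrings, kept consistent with the other snippets here (URL, master key, and ids are the same placeholders used above):

```python
import requests

headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}

# Add/update members on a team (mirrors the /team/update curl above).
requests.post(
    "http://0.0.0.0:8000/team/update",
    headers=headers,
    json={
        "team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849",
        "members_with_roles": [
            {"role": "admin", "user_id": "5c4a0aa3-a1e1-43dc-bd87-3c2da8382a3a"},
            {"role": "user", "user_id": "krrish247652@berri.ai"},
        ],
    },
)

# Delete teams and their associated keys (mirrors the /team/delete curl above).
requests.post(
    "http://0.0.0.0:8000/team/delete",
    headers=headers,
    json={"team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"]},
)
```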


@@ -318,7 +318,7 @@ def test_gemini_pro_vision():
 # test_gemini_pro_vision()
-def gemini_pro_function_calling():
+def test_gemini_pro_function_calling():
 load_vertex_ai_credentials()
 tools = [
 {
@@ -345,12 +345,15 @@ def gemini_pro_function_calling():
 model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
 )
 print(f"completion: {completion}")
+assert completion.choices[0].message.content is None
+assert len(completion.choices[0].message.tool_calls) == 1
 # gemini_pro_function_calling()
-async def gemini_pro_async_function_calling():
+@pytest.mark.asyncio
+async def test_gemini_pro_async_function_calling():
 load_vertex_ai_credentials()
 tools = [
 {
@@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
 model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
 )
 print(f"completion: {completion}")
+assert completion.choices[0].message.content is None
+assert len(completion.choices[0].message.tool_calls) == 1
+# raise Exception("it worked!")
 # asyncio.run(gemini_pro_async_function_calling())


@@ -1320,6 +1320,7 @@ def test_completion_together_ai():
 max_tokens=256,
 n=1,
 logger_fn=logger_fn,
+timeout=1,
 )
 # Add any assertions here to check the response
 print(response)
@@ -1330,6 +1331,7 @@ def test_completion_together_ai():
 f"${float(cost):.10f}",
 )
 except litellm.Timeout as e:
+print("got a timeout error")
 pass
 except Exception as e:
 pytest.fail(f"Error occurred: {e}")
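The added `timeout=1` is what makes the `litellm.Timeout` branch reachable. A minimal sketch of the same pattern outside the test (the model string and prompt are placeholders):

```python
import litellm

try:
    response = litellm.completion(
        model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder model
        messages=[{"role": "user", "content": "hello"}],
        timeout=1,  # seconds; small enough that slow providers trip the Timeout path
    )
    print(response)
except litellm.Timeout:
    print("got a timeout error")
```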


@@ -139,6 +139,56 @@ async def test_pre_call_hook_tpm_limits():
 assert e.status_code == 429
+@pytest.mark.asyncio
+async def test_pre_call_hook_user_tpm_limits():
+"""
+Test if error raised on hitting tpm limits
+"""
+# create user with tpm/rpm limits
+_api_key = "sk-12345"
+user_api_key_dict = UserAPIKeyAuth(
+api_key=_api_key,
+user_id="ishaan",
+user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
+)
+res = dict(user_api_key_dict)
+print("dict user", res)
+local_cache = DualCache()
+parallel_request_handler = MaxParallelRequestsHandler()
+await parallel_request_handler.async_pre_call_hook(
+user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+)
+kwargs = {
+"litellm_params": {
+"metadata": {"user_api_key_user_id": "ishaan", "user_api_key": "gm"}
+}
+}
+await parallel_request_handler.async_log_success_event(
+kwargs=kwargs,
+response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=10)),
+start_time="",
+end_time="",
+)
+## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+try:
+await parallel_request_handler.async_pre_call_hook(
+user_api_key_dict=user_api_key_dict,
+cache=local_cache,
+data={},
+call_type="",
+)
+pytest.fail(f"Expected call to fail")
+except Exception as e:
+assert e.status_code == 429
 @pytest.mark.asyncio
 async def test_success_call_hook():
 """


@@ -38,7 +38,9 @@ import time
 # test_promptlayer_logging()
-@pytest.mark.skip(reason="ci/cd issues. works locally")
+@pytest.mark.skip(
+reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
+)
 def test_promptlayer_logging_with_metadata():
 try:
 # Redirect stdout
@@ -67,7 +69,9 @@ def test_promptlayer_logging_with_metadata():
 pytest.fail(f"Error occurred: {e}")
-@pytest.mark.skip(reason="ci/cd issues. works locally")
+@pytest.mark.skip(
+reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
+)
 def test_promptlayer_logging_with_metadata_tags():
 try:
 # Redirect stdout


@@ -4274,8 +4274,8 @@ def get_optional_params(
 optional_params["stop_sequences"] = stop
 if max_tokens is not None:
 optional_params["max_output_tokens"] = max_tokens
-elif custom_llm_provider == "vertex_ai" and model in (
-litellm.vertex_chat_models
+elif custom_llm_provider == "vertex_ai" and (
+model in litellm.vertex_chat_models
 or model in litellm.vertex_code_chat_models
 or model in litellm.vertex_text_models
 or model in litellm.vertex_code_text_models
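For context on this fix: in the old condition, `litellm.vertex_chat_models or model in ...` is evaluated first and short-circuits to the (non-empty) chat-model list, so membership was only ever tested against that one list; the rewrite tests each list separately. A standalone illustration with hypothetical list contents:

```python
chat_models = ["chat-bison"]
text_models = ["text-bison"]
model = "text-bison"

# Old shape: the parenthesized `or` collapses to chat_models (truthy), so only
# the chat list is checked and text models are missed.
old = model in (chat_models or model in text_models)

# New shape: each membership test runs against its own list.
new = model in chat_models or model in text_models

print(old, new)  # False True
```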


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.26.8"
+version = "1.26.10"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -74,7 +74,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 [tool.commitizen]
-version = "1.26.8"
+version = "1.26.10"
 version_files = [
 "pyproject.toml:^version"
 ]