mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge branch 'main' into litellm_banned_keywords_list
This commit is contained in:
commit
b6a05cb787
13 changed files with 259 additions and 60 deletions
|
@ -152,7 +152,14 @@ GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not e
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
#### Step 3. Test flow
|
#### Step 3. Set `PROXY_BASE_URL` in your .env
|
||||||
|
|
||||||
|
Set this in your .env (so the proxy can set the correct redirect url)
|
||||||
|
```shell
|
||||||
|
PROXY_BASE_URL=https://litellm-api.up.railway.app/
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 4. Test flow
|
||||||
<Image img={require('../../img/litellm_ui_3.gif')} />
|
<Image img={require('../../img/litellm_ui_3.gif')} />
|
||||||
|
|
||||||
### Set Admin view w/ SSO
|
### Set Admin view w/ SSO
|
||||||
|
|
|
@ -279,9 +279,9 @@ curl 'http://0.0.0.0:8000/key/generate' \
|
||||||
## Set Rate Limits
|
## Set Rate Limits
|
||||||
|
|
||||||
You can set:
|
You can set:
|
||||||
|
- tpm limits (tokens per minute)
|
||||||
|
- rpm limits (requests per minute)
|
||||||
- max parallel requests
|
- max parallel requests
|
||||||
- tpm limits
|
|
||||||
- rpm limits
|
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="per-user" label="Per User">
|
<TabItem value="per-user" label="Per User">
|
||||||
|
|
|
@ -559,8 +559,7 @@ def completion(
|
||||||
f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
|
f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
|
||||||
)
|
)
|
||||||
response = llm_model.predict(
|
response = llm_model.predict(
|
||||||
endpoint=endpoint_path,
|
endpoint=endpoint_path, instances=instances
|
||||||
instances=instances
|
|
||||||
).predictions
|
).predictions
|
||||||
|
|
||||||
completion_response = response[0]
|
completion_response = response[0]
|
||||||
|
@ -585,12 +584,8 @@ def completion(
|
||||||
"request_str": request_str,
|
"request_str": request_str,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
request_str += (
|
request_str += f"llm_model.predict(instances={instances})\n"
|
||||||
f"llm_model.predict(instances={instances})\n"
|
response = llm_model.predict(instances=instances).predictions
|
||||||
)
|
|
||||||
response = llm_model.predict(
|
|
||||||
instances=instances
|
|
||||||
).predictions
|
|
||||||
|
|
||||||
completion_response = response[0]
|
completion_response = response[0]
|
||||||
if (
|
if (
|
||||||
|
@ -614,7 +609,6 @@ def completion(
|
||||||
model_response["choices"][0]["message"]["content"] = str(
|
model_response["choices"][0]["message"]["content"] = str(
|
||||||
completion_response
|
completion_response
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"]["content"] = str(completion_response)
|
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = model
|
model_response["model"] = model
|
||||||
## CALCULATING USAGE
|
## CALCULATING USAGE
|
||||||
|
@ -766,6 +760,7 @@ async def async_completion(
|
||||||
Vertex AI Model Garden
|
Vertex AI Model Garden
|
||||||
"""
|
"""
|
||||||
from google.cloud import aiplatform
|
from google.cloud import aiplatform
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=prompt,
|
input=prompt,
|
||||||
|
@ -797,11 +792,9 @@ async def async_completion(
|
||||||
and "\nOutput:\n" in completion_response
|
and "\nOutput:\n" in completion_response
|
||||||
):
|
):
|
||||||
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
||||||
|
|
||||||
elif mode == "private":
|
elif mode == "private":
|
||||||
request_str += (
|
request_str += f"llm_model.predict_async(instances={instances})\n"
|
||||||
f"llm_model.predict_async(instances={instances})\n"
|
|
||||||
)
|
|
||||||
response_obj = await llm_model.predict_async(
|
response_obj = await llm_model.predict_async(
|
||||||
instances=instances,
|
instances=instances,
|
||||||
)
|
)
|
||||||
|
@ -826,7 +819,6 @@ async def async_completion(
|
||||||
model_response["choices"][0]["message"]["content"] = str(
|
model_response["choices"][0]["message"]["content"] = str(
|
||||||
completion_response
|
completion_response
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"]["content"] = str(completion_response)
|
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = model
|
model_response["model"] = model
|
||||||
## CALCULATING USAGE
|
## CALCULATING USAGE
|
||||||
|
@ -954,6 +946,7 @@ async def async_streaming(
|
||||||
response = llm_model.predict_streaming_async(prompt, **optional_params)
|
response = llm_model.predict_streaming_async(prompt, **optional_params)
|
||||||
elif mode == "custom":
|
elif mode == "custom":
|
||||||
from google.cloud import aiplatform
|
from google.cloud import aiplatform
|
||||||
|
|
||||||
stream = optional_params.pop("stream", None)
|
stream = optional_params.pop("stream", None)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -972,7 +965,9 @@ async def async_streaming(
|
||||||
endpoint_path = llm_model.endpoint_path(
|
endpoint_path = llm_model.endpoint_path(
|
||||||
project=vertex_project, location=vertex_location, endpoint=model
|
project=vertex_project, location=vertex_location, endpoint=model
|
||||||
)
|
)
|
||||||
request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
|
request_str += (
|
||||||
|
f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
|
||||||
|
)
|
||||||
response_obj = await llm_model.predict(
|
response_obj = await llm_model.predict(
|
||||||
endpoint=endpoint_path,
|
endpoint=endpoint_path,
|
||||||
instances=instances,
|
instances=instances,
|
||||||
|
|
|
@ -12,7 +12,6 @@ from typing import Any, Literal, Union
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import dotenv, traceback, random, asyncio, time, contextvars
|
import dotenv, traceback, random, asyncio, time, contextvars
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import litellm
|
import litellm
|
||||||
from ._logging import verbose_logger
|
from ._logging import verbose_logger
|
||||||
|
|
|
@ -424,6 +424,10 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
||||||
model_spend: Dict = {}
|
model_spend: Dict = {}
|
||||||
model_max_budget: Dict = {}
|
model_max_budget: Dict = {}
|
||||||
|
|
||||||
|
# hidden params used for parallel request limiting, not required to create a token
|
||||||
|
user_id_rate_limits: Optional[dict] = None
|
||||||
|
team_id_rate_limits: Optional[dict] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
|
@ -24,46 +24,21 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def async_pre_call_hook(
|
async def check_key_in_limits(
|
||||||
self,
|
self,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
cache: DualCache,
|
cache: DualCache,
|
||||||
data: dict,
|
data: dict,
|
||||||
call_type: str,
|
call_type: str,
|
||||||
|
max_parallel_requests: int,
|
||||||
|
tpm_limit: int,
|
||||||
|
rpm_limit: int,
|
||||||
|
request_count_api_key: str,
|
||||||
):
|
):
|
||||||
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
|
||||||
api_key = user_api_key_dict.api_key
|
|
||||||
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
|
||||||
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
|
|
||||||
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
|
|
||||||
|
|
||||||
if api_key is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
max_parallel_requests == sys.maxsize
|
|
||||||
and tpm_limit == sys.maxsize
|
|
||||||
and rpm_limit == sys.maxsize
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.user_api_key_cache = cache # save the api key cache for updating the value
|
|
||||||
# ------------
|
|
||||||
# Setup values
|
|
||||||
# ------------
|
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
current_hour = datetime.now().strftime("%H")
|
|
||||||
current_minute = datetime.now().strftime("%M")
|
|
||||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
|
||||||
|
|
||||||
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
|
|
||||||
|
|
||||||
# CHECK IF REQUEST ALLOWED
|
|
||||||
current = cache.get_cache(
|
current = cache.get_cache(
|
||||||
key=request_count_api_key
|
key=request_count_api_key
|
||||||
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
||||||
self.print_verbose(f"current: {current}")
|
# print(f"current: {current}")
|
||||||
if current is None:
|
if current is None:
|
||||||
new_val = {
|
new_val = {
|
||||||
"current_requests": 1,
|
"current_requests": 1,
|
||||||
|
@ -88,10 +63,107 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
status_code=429, detail="Max parallel request limit reached."
|
status_code=429, detail="Max parallel request limit reached."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_pre_call_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
cache: DualCache,
|
||||||
|
data: dict,
|
||||||
|
call_type: str,
|
||||||
|
):
|
||||||
|
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
||||||
|
api_key = user_api_key_dict.api_key
|
||||||
|
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
||||||
|
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
|
||||||
|
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
|
||||||
|
|
||||||
|
if api_key is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.user_api_key_cache = cache # save the api key cache for updating the value
|
||||||
|
# ------------
|
||||||
|
# Setup values
|
||||||
|
# ------------
|
||||||
|
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
current_hour = datetime.now().strftime("%H")
|
||||||
|
current_minute = datetime.now().strftime("%M")
|
||||||
|
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||||
|
|
||||||
|
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
|
||||||
|
|
||||||
|
# CHECK IF REQUEST ALLOWED for key
|
||||||
|
current = cache.get_cache(
|
||||||
|
key=request_count_api_key
|
||||||
|
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
||||||
|
self.print_verbose(f"current: {current}")
|
||||||
|
if (
|
||||||
|
max_parallel_requests == sys.maxsize
|
||||||
|
and tpm_limit == sys.maxsize
|
||||||
|
and rpm_limit == sys.maxsize
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
elif current is None:
|
||||||
|
new_val = {
|
||||||
|
"current_requests": 1,
|
||||||
|
"current_tpm": 0,
|
||||||
|
"current_rpm": 0,
|
||||||
|
}
|
||||||
|
cache.set_cache(request_count_api_key, new_val)
|
||||||
|
elif (
|
||||||
|
int(current["current_requests"]) < max_parallel_requests
|
||||||
|
and current["current_tpm"] < tpm_limit
|
||||||
|
and current["current_rpm"] < rpm_limit
|
||||||
|
):
|
||||||
|
# Increase count for this token
|
||||||
|
new_val = {
|
||||||
|
"current_requests": current["current_requests"] + 1,
|
||||||
|
"current_tpm": current["current_tpm"],
|
||||||
|
"current_rpm": current["current_rpm"],
|
||||||
|
}
|
||||||
|
cache.set_cache(request_count_api_key, new_val)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429, detail="Max parallel request limit reached."
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if REQUEST ALLOWED for user_id
|
||||||
|
user_id = user_api_key_dict.user_id
|
||||||
|
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
|
||||||
|
|
||||||
|
# get user tpm/rpm limits
|
||||||
|
if _user_id_rate_limits is None or _user_id_rate_limits == {}:
|
||||||
|
return
|
||||||
|
user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
|
||||||
|
user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
|
||||||
|
if user_tpm_limit is None:
|
||||||
|
user_tpm_limit = sys.maxsize
|
||||||
|
if user_rpm_limit is None:
|
||||||
|
user_rpm_limit = sys.maxsize
|
||||||
|
|
||||||
|
# now do the same tpm/rpm checks
|
||||||
|
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||||
|
|
||||||
|
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||||
|
await self.check_key_in_limits(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=cache,
|
||||||
|
data=data,
|
||||||
|
call_type=call_type,
|
||||||
|
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
||||||
|
request_count_api_key=request_count_api_key,
|
||||||
|
tpm_limit=user_tpm_limit,
|
||||||
|
rpm_limit=user_rpm_limit,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
||||||
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
||||||
|
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
"user_api_key_user_id", None
|
||||||
|
)
|
||||||
|
|
||||||
if user_api_key is None:
|
if user_api_key is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -121,7 +193,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
}
|
}
|
||||||
|
|
||||||
# ------------
|
# ------------
|
||||||
# Update usage
|
# Update usage - API Key
|
||||||
# ------------
|
# ------------
|
||||||
|
|
||||||
new_val = {
|
new_val = {
|
||||||
|
@ -136,6 +208,41 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
self.user_api_key_cache.set_cache(
|
self.user_api_key_cache.set_cache(
|
||||||
request_count_api_key, new_val, ttl=60
|
request_count_api_key, new_val, ttl=60
|
||||||
) # store in cache for 1 min.
|
) # store in cache for 1 min.
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Update usage - User
|
||||||
|
# ------------
|
||||||
|
if user_api_key_user_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
if isinstance(response_obj, ModelResponse):
|
||||||
|
total_tokens = response_obj.usage.total_tokens
|
||||||
|
|
||||||
|
request_count_api_key = (
|
||||||
|
f"{user_api_key_user_id}::{precise_minute}::request_count"
|
||||||
|
)
|
||||||
|
|
||||||
|
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
|
||||||
|
"current_requests": 1,
|
||||||
|
"current_tpm": total_tokens,
|
||||||
|
"current_rpm": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
new_val = {
|
||||||
|
"current_requests": max(current["current_requests"] - 1, 0),
|
||||||
|
"current_tpm": current["current_tpm"] + total_tokens,
|
||||||
|
"current_rpm": current["current_rpm"] + 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.print_verbose(
|
||||||
|
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||||
|
)
|
||||||
|
self.user_api_key_cache.set_cache(
|
||||||
|
request_count_api_key, new_val, ttl=60
|
||||||
|
) # store in cache for 1 min.
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.print_verbose(e) # noqa
|
self.print_verbose(e) # noqa
|
||||||
|
|
||||||
|
|
|
@ -4388,7 +4388,20 @@ async def update_team(
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
add new members to the team
|
You can now add / delete users from a team via /team/update
|
||||||
|
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:8000/team/update' \
|
||||||
|
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
|
||||||
|
--data-raw '{
|
||||||
|
"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849",
|
||||||
|
"members_with_roles": [{"role": "admin", "user_id": "5c4a0aa3-a1e1-43dc-bd87-3c2da8382a3a"}, {"role": "user", "user_id": "krrish247652@berri.ai"}]
|
||||||
|
}'
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
global prisma_client
|
global prisma_client
|
||||||
|
|
||||||
|
@ -4469,6 +4482,18 @@ async def delete_team(
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
delete team and associated team keys
|
delete team and associated team keys
|
||||||
|
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:8000/team/delete' \
|
||||||
|
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
|
||||||
|
--data-raw '{
|
||||||
|
"team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
global prisma_client
|
global prisma_client
|
||||||
|
|
||||||
|
|
|
@ -318,7 +318,7 @@ def test_gemini_pro_vision():
|
||||||
# test_gemini_pro_vision()
|
# test_gemini_pro_vision()
|
||||||
|
|
||||||
|
|
||||||
def gemini_pro_function_calling():
|
def test_gemini_pro_function_calling():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
@ -345,12 +345,15 @@ def gemini_pro_function_calling():
|
||||||
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
||||||
)
|
)
|
||||||
print(f"completion: {completion}")
|
print(f"completion: {completion}")
|
||||||
|
assert completion.choices[0].message.content is None
|
||||||
|
assert len(completion.choices[0].message.tool_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
# gemini_pro_function_calling()
|
# gemini_pro_function_calling()
|
||||||
|
|
||||||
|
|
||||||
async def gemini_pro_async_function_calling():
|
@pytest.mark.asyncio
|
||||||
|
async def test_gemini_pro_async_function_calling():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
|
||||||
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
||||||
)
|
)
|
||||||
print(f"completion: {completion}")
|
print(f"completion: {completion}")
|
||||||
|
assert completion.choices[0].message.content is None
|
||||||
|
assert len(completion.choices[0].message.tool_calls) == 1
|
||||||
|
# raise Exception("it worked!")
|
||||||
|
|
||||||
|
|
||||||
# asyncio.run(gemini_pro_async_function_calling())
|
# asyncio.run(gemini_pro_async_function_calling())
|
||||||
|
|
|
@ -1320,6 +1320,7 @@ def test_completion_together_ai():
|
||||||
max_tokens=256,
|
max_tokens=256,
|
||||||
n=1,
|
n=1,
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
|
timeout=1,
|
||||||
)
|
)
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
print(response)
|
print(response)
|
||||||
|
@ -1330,6 +1331,7 @@ def test_completion_together_ai():
|
||||||
f"${float(cost):.10f}",
|
f"${float(cost):.10f}",
|
||||||
)
|
)
|
||||||
except litellm.Timeout as e:
|
except litellm.Timeout as e:
|
||||||
|
print("got a timeout error")
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
|
@ -139,6 +139,56 @@ async def test_pre_call_hook_tpm_limits():
|
||||||
assert e.status_code == 429
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_call_hook_user_tpm_limits():
|
||||||
|
"""
|
||||||
|
Test if error raised on hitting tpm limits
|
||||||
|
"""
|
||||||
|
# create user with tpm/rpm limits
|
||||||
|
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
api_key=_api_key,
|
||||||
|
user_id="ishaan",
|
||||||
|
user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
|
||||||
|
)
|
||||||
|
res = dict(user_api_key_dict)
|
||||||
|
print("dict user", res)
|
||||||
|
local_cache = DualCache()
|
||||||
|
parallel_request_handler = MaxParallelRequestsHandler()
|
||||||
|
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"litellm_params": {
|
||||||
|
"metadata": {"user_api_key_user_id": "ishaan", "user_api_key": "gm"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await parallel_request_handler.async_log_success_event(
|
||||||
|
kwargs=kwargs,
|
||||||
|
response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=10)),
|
||||||
|
start_time="",
|
||||||
|
end_time="",
|
||||||
|
)
|
||||||
|
|
||||||
|
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
data={},
|
||||||
|
call_type="",
|
||||||
|
)
|
||||||
|
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_success_call_hook():
|
async def test_success_call_hook():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -38,7 +38,9 @@ import time
|
||||||
# test_promptlayer_logging()
|
# test_promptlayer_logging()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="ci/cd issues. works locally")
|
@pytest.mark.skip(
|
||||||
|
reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
|
||||||
|
)
|
||||||
def test_promptlayer_logging_with_metadata():
|
def test_promptlayer_logging_with_metadata():
|
||||||
try:
|
try:
|
||||||
# Redirect stdout
|
# Redirect stdout
|
||||||
|
@ -67,7 +69,9 @@ def test_promptlayer_logging_with_metadata():
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="ci/cd issues. works locally")
|
@pytest.mark.skip(
|
||||||
|
reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
|
||||||
|
)
|
||||||
def test_promptlayer_logging_with_metadata_tags():
|
def test_promptlayer_logging_with_metadata_tags():
|
||||||
try:
|
try:
|
||||||
# Redirect stdout
|
# Redirect stdout
|
||||||
|
|
|
@ -4274,8 +4274,8 @@ def get_optional_params(
|
||||||
optional_params["stop_sequences"] = stop
|
optional_params["stop_sequences"] = stop
|
||||||
if max_tokens is not None:
|
if max_tokens is not None:
|
||||||
optional_params["max_output_tokens"] = max_tokens
|
optional_params["max_output_tokens"] = max_tokens
|
||||||
elif custom_llm_provider == "vertex_ai" and model in (
|
elif custom_llm_provider == "vertex_ai" and (
|
||||||
litellm.vertex_chat_models
|
model in litellm.vertex_chat_models
|
||||||
or model in litellm.vertex_code_chat_models
|
or model in litellm.vertex_code_chat_models
|
||||||
or model in litellm.vertex_text_models
|
or model in litellm.vertex_text_models
|
||||||
or model in litellm.vertex_code_text_models
|
or model in litellm.vertex_code_text_models
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.26.8"
|
version = "1.26.10"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
@ -74,7 +74,7 @@ requires = ["poetry-core", "wheel"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.commitizen]
|
[tool.commitizen]
|
||||||
version = "1.26.8"
|
version = "1.26.10"
|
||||||
version_files = [
|
version_files = [
|
||||||
"pyproject.toml:^version"
|
"pyproject.toml:^version"
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue