diff --git a/docs/my-website/docs/proxy/ui.md b/docs/my-website/docs/proxy/ui.md
index 1f98bc077..ff45f9569 100644
--- a/docs/my-website/docs/proxy/ui.md
+++ b/docs/my-website/docs/proxy/ui.md
@@ -152,7 +152,14 @@ GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not e
-#### Step 3. Test flow
+#### Step 3. Set `PROXY_BASE_URL` in your .env
+
+Set this in your .env (so the proxy can set the correct redirect url)
+```shell
+PROXY_BASE_URL=https://litellm-api.up.railway.app/
+```
+
+#### Step 4. Test flow
### Set Admin view w/ SSO
diff --git a/docs/my-website/docs/proxy/users.md b/docs/my-website/docs/proxy/users.md
index 3eb0cb808..159b311a9 100644
--- a/docs/my-website/docs/proxy/users.md
+++ b/docs/my-website/docs/proxy/users.md
@@ -279,9 +279,9 @@ curl 'http://0.0.0.0:8000/key/generate' \
## Set Rate Limits
You can set:
+- tpm limits (tokens per minute)
+- rpm limits (requests per minute)
- max parallel requests
-- tpm limits
-- rpm limits
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 603bd3c22..fdbc1625e 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -559,8 +559,7 @@ def completion(
f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
)
response = llm_model.predict(
- endpoint=endpoint_path,
- instances=instances
+ endpoint=endpoint_path, instances=instances
).predictions
completion_response = response[0]
@@ -585,12 +584,8 @@ def completion(
"request_str": request_str,
},
)
- request_str += (
- f"llm_model.predict(instances={instances})\n"
- )
- response = llm_model.predict(
- instances=instances
- ).predictions
+ request_str += f"llm_model.predict(instances={instances})\n"
+ response = llm_model.predict(instances=instances).predictions
completion_response = response[0]
if (
@@ -614,7 +609,6 @@ def completion(
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
- model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
@@ -766,6 +760,7 @@ async def async_completion(
Vertex AI Model Garden
"""
from google.cloud import aiplatform
+
## LOGGING
logging_obj.pre_call(
input=prompt,
@@ -797,11 +792,9 @@ async def async_completion(
and "\nOutput:\n" in completion_response
):
completion_response = completion_response.split("\nOutput:\n", 1)[1]
-
+
elif mode == "private":
- request_str += (
- f"llm_model.predict_async(instances={instances})\n"
- )
+ request_str += f"llm_model.predict_async(instances={instances})\n"
response_obj = await llm_model.predict_async(
instances=instances,
)
@@ -826,7 +819,6 @@ async def async_completion(
model_response["choices"][0]["message"]["content"] = str(
completion_response
)
- model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = int(time.time())
model_response["model"] = model
## CALCULATING USAGE
@@ -954,6 +946,7 @@ async def async_streaming(
response = llm_model.predict_streaming_async(prompt, **optional_params)
elif mode == "custom":
from google.cloud import aiplatform
+
stream = optional_params.pop("stream", None)
## LOGGING
@@ -972,7 +965,9 @@ async def async_streaming(
endpoint_path = llm_model.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
- request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+ request_str += (
+ f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+ )
response_obj = await llm_model.predict(
endpoint=endpoint_path,
instances=instances,
diff --git a/litellm/main.py b/litellm/main.py
index 136611066..1ee36504f 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -12,7 +12,6 @@ from typing import Any, Literal, Union
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
-
import httpx
import litellm
from ._logging import verbose_logger
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index f0f384094..7f453980f 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -424,6 +424,10 @@ class LiteLLM_VerificationToken(LiteLLMBase):
model_spend: Dict = {}
model_max_budget: Dict = {}
+ # hidden params used for parallel request limiting, not required to create a token
+ user_id_rate_limits: Optional[dict] = None
+ team_id_rate_limits: Optional[dict] = None
+
class Config:
protected_namespaces = ()
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 67f8d1ad2..fb61fe3da 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -24,46 +24,21 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
except:
pass
- async def async_pre_call_hook(
+ async def check_key_in_limits(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
+ max_parallel_requests: int,
+ tpm_limit: int,
+ rpm_limit: int,
+ request_count_api_key: str,
):
- self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
- api_key = user_api_key_dict.api_key
- max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
- tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
- rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
-
- if api_key is None:
- return
-
- if (
- max_parallel_requests == sys.maxsize
- and tpm_limit == sys.maxsize
- and rpm_limit == sys.maxsize
- ):
- return
-
- self.user_api_key_cache = cache # save the api key cache for updating the value
- # ------------
- # Setup values
- # ------------
-
- current_date = datetime.now().strftime("%Y-%m-%d")
- current_hour = datetime.now().strftime("%H")
- current_minute = datetime.now().strftime("%M")
- precise_minute = f"{current_date}-{current_hour}-{current_minute}"
-
- request_count_api_key = f"{api_key}::{precise_minute}::request_count"
-
- # CHECK IF REQUEST ALLOWED
current = cache.get_cache(
key=request_count_api_key
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
- self.print_verbose(f"current: {current}")
+ # print(f"current: {current}")
if current is None:
new_val = {
"current_requests": 1,
@@ -88,10 +63,107 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
status_code=429, detail="Max parallel request limit reached."
)
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str,
+ ):
+ self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
+ api_key = user_api_key_dict.api_key
+ max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
+ tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
+ rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
+
+ if api_key is None:
+ return
+
+ self.user_api_key_cache = cache # save the api key cache for updating the value
+ # ------------
+ # Setup values
+ # ------------
+
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ current_hour = datetime.now().strftime("%H")
+ current_minute = datetime.now().strftime("%M")
+ precise_minute = f"{current_date}-{current_hour}-{current_minute}"
+
+ request_count_api_key = f"{api_key}::{precise_minute}::request_count"
+
+ # CHECK IF REQUEST ALLOWED for key
+ current = cache.get_cache(
+ key=request_count_api_key
+ ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
+ self.print_verbose(f"current: {current}")
+ if (
+ max_parallel_requests == sys.maxsize
+ and tpm_limit == sys.maxsize
+ and rpm_limit == sys.maxsize
+ ):
+ pass
+ elif current is None:
+ new_val = {
+ "current_requests": 1,
+ "current_tpm": 0,
+ "current_rpm": 0,
+ }
+ cache.set_cache(request_count_api_key, new_val)
+ elif (
+ int(current["current_requests"]) < max_parallel_requests
+ and current["current_tpm"] < tpm_limit
+ and current["current_rpm"] < rpm_limit
+ ):
+ # Increase count for this token
+ new_val = {
+ "current_requests": current["current_requests"] + 1,
+ "current_tpm": current["current_tpm"],
+ "current_rpm": current["current_rpm"],
+ }
+ cache.set_cache(request_count_api_key, new_val)
+ else:
+ raise HTTPException(
+ status_code=429, detail="Max parallel request limit reached."
+ )
+
+ # check if REQUEST ALLOWED for user_id
+ user_id = user_api_key_dict.user_id
+ _user_id_rate_limits = user_api_key_dict.user_id_rate_limits
+
+ # get user tpm/rpm limits
+ if _user_id_rate_limits is None or _user_id_rate_limits == {}:
+ return
+ user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
+ user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
+ if user_tpm_limit is None:
+ user_tpm_limit = sys.maxsize
+ if user_rpm_limit is None:
+ user_rpm_limit = sys.maxsize
+
+ # now do the same tpm/rpm checks
+ request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+
+ # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+ await self.check_key_in_limits(
+ user_api_key_dict=user_api_key_dict,
+ cache=cache,
+ data=data,
+ call_type=call_type,
+ max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
+ request_count_api_key=request_count_api_key,
+ tpm_limit=user_tpm_limit,
+ rpm_limit=user_rpm_limit,
+ )
+ return
+
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+ user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
+ "user_api_key_user_id", None
+ )
+
if user_api_key is None:
return
@@ -121,7 +193,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
}
# ------------
- # Update usage
+ # Update usage - API Key
# ------------
new_val = {
@@ -136,6 +208,41 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
self.user_api_key_cache.set_cache(
request_count_api_key, new_val, ttl=60
) # store in cache for 1 min.
+
+ # ------------
+ # Update usage - User
+ # ------------
+ if user_api_key_user_id is None:
+ return
+
+ total_tokens = 0
+
+ if isinstance(response_obj, ModelResponse):
+ total_tokens = response_obj.usage.total_tokens
+
+ request_count_api_key = (
+ f"{user_api_key_user_id}::{precise_minute}::request_count"
+ )
+
+ current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
+ "current_requests": 1,
+ "current_tpm": total_tokens,
+ "current_rpm": 1,
+ }
+
+ new_val = {
+ "current_requests": max(current["current_requests"] - 1, 0),
+ "current_tpm": current["current_tpm"] + total_tokens,
+ "current_rpm": current["current_rpm"] + 1,
+ }
+
+ self.print_verbose(
+ f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+ )
+ self.user_api_key_cache.set_cache(
+ request_count_api_key, new_val, ttl=60
+ ) # store in cache for 1 min.
+
except Exception as e:
self.print_verbose(e) # noqa
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 030af777a..d4318f134 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -4388,7 +4388,20 @@ async def update_team(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
- add new members to the team
+ You can now add / delete users from a team via /team/update
+
+ ```
+ curl --location 'http://0.0.0.0:8000/team/update' \
+
+ --header 'Authorization: Bearer sk-1234' \
+
+ --header 'Content-Type: application/json' \
+
+ --data-raw '{
+ "team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849",
+ "members_with_roles": [{"role": "admin", "user_id": "5c4a0aa3-a1e1-43dc-bd87-3c2da8382a3a"}, {"role": "user", "user_id": "krrish247652@berri.ai"}]
+ }'
+ ```
"""
global prisma_client
@@ -4469,6 +4482,18 @@ async def delete_team(
):
"""
delete team and associated team keys
+
+ ```
+ curl --location 'http://0.0.0.0:8000/team/delete' \
+
+ --header 'Authorization: Bearer sk-1234' \
+
+ --header 'Content-Type: application/json' \
+
+ --data-raw '{
+ "team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"]
+ }'
+ ```
"""
global prisma_client
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 9b7473ea2..76ebde7ae 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -318,7 +318,7 @@ def test_gemini_pro_vision():
# test_gemini_pro_vision()
-def gemini_pro_function_calling():
+def test_gemini_pro_function_calling():
load_vertex_ai_credentials()
tools = [
{
@@ -345,12 +345,15 @@ def gemini_pro_function_calling():
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
print(f"completion: {completion}")
+ assert completion.choices[0].message.content is None
+ assert len(completion.choices[0].message.tool_calls) == 1
# gemini_pro_function_calling()
-async def gemini_pro_async_function_calling():
+@pytest.mark.asyncio
+async def test_gemini_pro_async_function_calling():
load_vertex_ai_credentials()
tools = [
{
@@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
)
print(f"completion: {completion}")
+ assert completion.choices[0].message.content is None
+ assert len(completion.choices[0].message.tool_calls) == 1
+ # raise Exception("it worked!")
# asyncio.run(gemini_pro_async_function_calling())
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 7816c3918..c9924273d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1320,6 +1320,7 @@ def test_completion_together_ai():
max_tokens=256,
n=1,
logger_fn=logger_fn,
+ timeout=1,
)
# Add any assertions here to check the response
print(response)
@@ -1330,6 +1331,7 @@ def test_completion_together_ai():
f"${float(cost):.10f}",
)
except litellm.Timeout as e:
+ print("got a timeout error")
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
index 17d79c36c..e402b617b 100644
--- a/litellm/tests/test_parallel_request_limiter.py
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -139,6 +139,56 @@ async def test_pre_call_hook_tpm_limits():
assert e.status_code == 429
+@pytest.mark.asyncio
+async def test_pre_call_hook_user_tpm_limits():
+ """
+ Test if error raised on hitting tpm limits
+ """
+ # create user with tpm/rpm limits
+
+ _api_key = "sk-12345"
+ user_api_key_dict = UserAPIKeyAuth(
+ api_key=_api_key,
+ user_id="ishaan",
+ user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
+ )
+ res = dict(user_api_key_dict)
+ print("dict user", res)
+ local_cache = DualCache()
+ parallel_request_handler = MaxParallelRequestsHandler()
+
+ await parallel_request_handler.async_pre_call_hook(
+ user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+ )
+
+ kwargs = {
+ "litellm_params": {
+ "metadata": {"user_api_key_user_id": "ishaan", "user_api_key": "gm"}
+ }
+ }
+
+ await parallel_request_handler.async_log_success_event(
+ kwargs=kwargs,
+ response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=10)),
+ start_time="",
+ end_time="",
+ )
+
+ ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+ try:
+ await parallel_request_handler.async_pre_call_hook(
+ user_api_key_dict=user_api_key_dict,
+ cache=local_cache,
+ data={},
+ call_type="",
+ )
+
+ pytest.fail(f"Expected call to fail")
+ except Exception as e:
+ assert e.status_code == 429
+
+
@pytest.mark.asyncio
async def test_success_call_hook():
"""
diff --git a/litellm/tests/test_promptlayer_integration.py b/litellm/tests/test_promptlayer_integration.py
index 518d64bbd..9eff88f0d 100644
--- a/litellm/tests/test_promptlayer_integration.py
+++ b/litellm/tests/test_promptlayer_integration.py
@@ -38,7 +38,9 @@ import time
# test_promptlayer_logging()
-@pytest.mark.skip(reason="ci/cd issues. works locally")
+@pytest.mark.skip(
+ reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
+)
def test_promptlayer_logging_with_metadata():
try:
# Redirect stdout
@@ -67,7 +69,9 @@ def test_promptlayer_logging_with_metadata():
pytest.fail(f"Error occurred: {e}")
-@pytest.mark.skip(reason="ci/cd issues. works locally")
+@pytest.mark.skip(
+ reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
+)
def test_promptlayer_logging_with_metadata_tags():
try:
# Redirect stdout
diff --git a/litellm/utils.py b/litellm/utils.py
index b2a618660..2a5d40c8f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4274,8 +4274,8 @@ def get_optional_params(
optional_params["stop_sequences"] = stop
if max_tokens is not None:
optional_params["max_output_tokens"] = max_tokens
- elif custom_llm_provider == "vertex_ai" and model in (
- litellm.vertex_chat_models
+ elif custom_llm_provider == "vertex_ai" and (
+ model in litellm.vertex_chat_models
or model in litellm.vertex_code_chat_models
or model in litellm.vertex_text_models
or model in litellm.vertex_code_text_models
diff --git a/pyproject.toml b/pyproject.toml
index 80381ac1a..4311cd98e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "1.26.8"
+version = "1.26.10"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@@ -74,7 +74,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
-version = "1.26.8"
+version = "1.26.10"
version_files = [
"pyproject.toml:^version"
]