Merge branch 'main' into litellm_fix_azure_function_calling_streaming

Krish Dholakia 2024-02-22 22:36:38 -08:00 committed by GitHub
commit dd4439b6a8
23 changed files with 869 additions and 173 deletions


@@ -238,9 +238,11 @@ chat_completion = client.chat.completions.create(
        }
    ],
    model="gpt-3.5-turbo",
-   cache={
-       "no-cache": True  # will not return a cached response
-   }
+   extra_body = {  # OpenAI python accepts extra args in extra_body
+       "cache": {
+           "no-cache": True  # will not return a cached response
+       }
+   }
)
```
@@ -264,9 +266,11 @@ chat_completion = client.chat.completions.create(
        }
    ],
    model="gpt-3.5-turbo",
-   cache={
-       "ttl": 600  # caches response for 10 minutes
-   }
+   extra_body = {  # OpenAI python accepts extra args in extra_body
+       "cache": {
+           "ttl": 600  # caches response for 10 minutes
+       }
+   }
)
```
@@ -288,13 +292,15 @@ chat_completion = client.chat.completions.create(
        }
    ],
    model="gpt-3.5-turbo",
-   cache={
-       "s-maxage": 600  # only get responses cached within last 10 minutes
-   }
+   extra_body = {  # OpenAI python accepts extra args in extra_body
+       "cache": {
+           "s-maxage": 600  # only get responses cached within last 10 minutes
+       }
+   }
)
```
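Taken together, the three cache controls ride in `extra_body` on the standard OpenAI client; a minimal sketch, assuming a LiteLLM proxy with caching enabled at `http://0.0.0.0:8000` (the key and model name are placeholders):

```python
# Hypothetical end-to-end sketch: base_url, api_key and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:8000", api_key="sk-1234")
messages = [{"role": "user", "content": "what llm are you"}]

# Skip the cache entirely for this call
fresh = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=messages,
    extra_body={"cache": {"no-cache": True}},
)

# Cache the response for 10 minutes
cached = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=messages,
    extra_body={"cache": {"ttl": 600}},
)

# Only accept cache hits created within the last 10 minutes
recent = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=messages,
    extra_body={"cache": {"s-maxage": 600}},
)
```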
-## Supported `cache_params`
+## Supported `cache_params` on proxy config.yaml

```yaml
cache_params:


@@ -1,7 +1,7 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

-# ✨ Enterprise Features - Content Moderation
+# ✨ Enterprise Features - Content Moderation, Blocked Users

Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)

@@ -15,6 +15,7 @@ Features:
- [ ] Content Moderation with LlamaGuard
- [ ] Content Moderation with Google Text Moderations
- [ ] Content Moderation with LLM Guard
+- [ ] Reject calls from Blocked User list
- [ ] Tracking Spend for Custom Tags

## Content Moderation with LlamaGuard
@@ -132,6 +133,39 @@ Here are the category specific values:
## Enable Blocked User Lists
If any call is made to the proxy with this user id, it will be rejected - use this if you want to let users opt out of AI features
```yaml
litellm_settings:
callbacks: ["blocked_user_check"]
blocked_user_id_list: ["user_id_1", "user_id_2", ...] # can also be a .txt filepath e.g. `/relative/path/blocked_list.txt`
```
### How to test
```bash
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
"user_id": "user_id_1" # this is also an openai supported param
}
'
```
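An equivalent check from Python, assuming the same proxy deployment; a hedged sketch using the OpenAI client, which passes the extra `user_id` field through `extra_body`:

```python
# Hypothetical sketch: base_url and api_key are placeholders for your proxy deployment.
import openai

client = openai.OpenAI(base_url="http://0.0.0.0:8000", api_key="anything")

try:
    client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "what llm are you"}],
        extra_body={"user_id": "user_id_1"},  # id on the blocked list
    )
except openai.BadRequestError as e:
    # the blocked_user_check hook rejects the call with a 400
    print("rejected:", e)
```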
:::info
[Suggest a way to improve this](https://github.com/BerriAI/litellm/issues/new/choose)
:::
## Tracking Spend for Custom Tags

Requirements:


@@ -133,8 +133,12 @@ The following can be used to customize attribute names when interacting with the

```shell
GENERIC_USER_ID_ATTRIBUTE = "given_name"
GENERIC_USER_EMAIL_ATTRIBUTE = "family_name"
GENERIC_USER_DISPLAY_NAME_ATTRIBUTE = "display_name"
GENERIC_USER_FIRST_NAME_ATTRIBUTE = "first_name"
GENERIC_USER_LAST_NAME_ATTRIBUTE = "last_name"
GENERIC_USER_ROLE_ATTRIBUTE = "given_role"
GENERIC_CLIENT_STATE = "some-state" # if the provider needs a state parameter
GENERIC_INCLUDE_CLIENT_ID = "false" # some providers enforce that the client_id is not in the body
GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not enough to retrieve basic user info like first_name and last_name located in profile scope
```
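For reference, the sort of mapping these variables drive server-side; a simplified, hypothetical sketch (not the proxy's actual code) that pulls each attribute name from the environment and applies it to an OIDC userinfo payload:

```python
# Hypothetical sketch of how the attribute-name env vars might be applied
# to a userinfo response; the sample payload below is illustrative only.
import os

os.environ["GENERIC_USER_ID_ATTRIBUTE"] = "given_name"
os.environ["GENERIC_USER_EMAIL_ATTRIBUTE"] = "family_name"
os.environ["GENERIC_USER_ROLE_ATTRIBUTE"] = "given_role"

def map_userinfo(userinfo: dict) -> dict:
    """Map a raw userinfo payload onto the fields LiteLLM cares about."""
    return {
        "id": userinfo.get(os.getenv("GENERIC_USER_ID_ATTRIBUTE", "preferred_username")),
        "email": userinfo.get(os.getenv("GENERIC_USER_EMAIL_ATTRIBUTE", "email")),
        "display_name": userinfo.get(os.getenv("GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub")),
        "first_name": userinfo.get(os.getenv("GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name")),
        "last_name": userinfo.get(os.getenv("GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name")),
        "role": userinfo.get(os.getenv("GENERIC_USER_ROLE_ATTRIBUTE", "role")),
    }

print(map_userinfo({"given_name": "jane", "family_name": "jane@corp.com", "given_role": "admin"}))
```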
@@ -148,7 +152,14 @@ GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not e

</Tabs>

-#### Step 3. Test flow
+#### Step 3. Set `PROXY_BASE_URL` in your .env
Set this in your .env (so the proxy can set the correct redirect URL)
```shell
PROXY_BASE_URL=https://litellm-api.up.railway.app/
```
#### Step 4. Test flow
<Image img={require('../../img/litellm_ui_3.gif')} />

### Set Admin view w/ SSO

@@ -183,7 +194,21 @@ We allow you to
- Customize the UI color scheme

<Image img={require('../../img/litellm_custom_ai.png')} />
-#### Usage
+#### Set Custom Logo
We allow you to pass a local image or an http/https URL of your image.

Set `UI_LOGO_PATH` in your env. We recommend using a hosted image; it's a lot easier to set up and debug.

Example: setting a hosted image
```shell
UI_LOGO_PATH="https://litellm-logo-aws-marketplace.s3.us-west-2.amazonaws.com/berriai-logo-github.png"
```
Example: setting a local image (on your container)
```shell
UI_LOGO_PATH="ui_images/logo.jpg"
```
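For reference, the resolution the proxy applies is roughly the following; a simplified sketch of the `get_image` behaviour changed later in this commit (paths and filenames are illustrative, not the exact code):

```python
# Simplified sketch of resolving UI_LOGO_PATH: remote URLs are downloaded and
# cached locally, anything else is treated as a path inside the container.
import os
import requests

def resolve_logo(default_logo: str = "logo.jpg") -> str:
    logo_path = os.getenv("UI_LOGO_PATH", default_logo)
    if logo_path.startswith(("http://", "https://")):
        resp = requests.get(logo_path)
        if resp.status_code == 200:
            with open("cached_logo.jpg", "wb") as f:
                f.write(resp.content)
            return "cached_logo.jpg"
        return default_logo  # fall back if the download fails
    return logo_path  # local file inside the container
```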
#### Set Custom Color Theme
- Navigate to [/enterprise/enterprise_ui](https://github.com/BerriAI/litellm/blob/main/enterprise/enterprise_ui/_enterprise_colors.json)
- Inside the `enterprise_ui` directory, rename `_enterprise_colors.json` to `enterprise_colors.json`
- Set your companies custom color scheme in `enterprise_colors.json`

@@ -202,8 +227,6 @@ Set your colors to any of the following colors: https://www.tremor.so/docs/layou
}
```
- Set the path to your custom png/jpg logo as `UI_LOGO_PATH` in your .env
- Deploy LiteLLM Proxy Server


@@ -279,9 +279,9 @@ curl 'http://0.0.0.0:8000/key/generate' \

## Set Rate Limits

You can set:
- tpm limits (tokens per minute)
- rpm limits (requests per minute)
- max parallel requests

<Tabs>
<TabItem value="per-user" label="Per User">


@@ -18,6 +18,62 @@ const sidebars = {
  // But you can create a sidebar manually
  tutorialSidebar: [
    { type: "doc", id: "index" }, // NEW
{
type: "category",
label: "💥 OpenAI Proxy Server",
link: {
type: 'generated-index',
title: '💥 OpenAI Proxy Server',
description: `Proxy Server to call 100+ LLMs in a unified interface & track spend, set budgets per virtual key/user`,
slug: '/simple_proxy',
},
items: [
"proxy/quick_start",
"proxy/configs",
{
type: 'link',
label: '📖 All Endpoints',
href: 'https://litellm-api.up.railway.app/',
},
"proxy/enterprise",
"proxy/user_keys",
"proxy/virtual_keys",
"proxy/users",
"proxy/ui",
"proxy/model_management",
"proxy/health",
"proxy/debugging",
"proxy/pii_masking",
{
"type": "category",
"label": "🔥 Load Balancing",
"items": [
"proxy/load_balancing",
"proxy/reliability",
]
},
"proxy/caching",
{
"type": "category",
"label": "Logging, Alerting",
"items": [
"proxy/logging",
"proxy/alerting",
"proxy/streaming_logging",
]
},
{
"type": "category",
"label": "Content Moderation",
"items": [
"proxy/call_hooks",
"proxy/rules",
]
},
"proxy/deploy",
"proxy/cli",
]
},
    {
      type: "category",
      label: "Completion()",
@@ -92,62 +148,6 @@ const sidebars = {
        "providers/petals",
      ]
    },
{
type: "category",
label: "💥 OpenAI Proxy Server",
link: {
type: 'generated-index',
title: '💥 OpenAI Proxy Server',
description: `Proxy Server to call 100+ LLMs in a unified interface & track spend, set budgets per virtual key/user`,
slug: '/simple_proxy',
},
items: [
"proxy/quick_start",
"proxy/configs",
{
type: 'link',
label: '📖 All Endpoints',
href: 'https://litellm-api.up.railway.app/',
},
"proxy/enterprise",
"proxy/user_keys",
"proxy/virtual_keys",
"proxy/users",
"proxy/ui",
"proxy/model_management",
"proxy/health",
"proxy/debugging",
"proxy/pii_masking",
{
"type": "category",
"label": "🔥 Load Balancing",
"items": [
"proxy/load_balancing",
"proxy/reliability",
]
},
"proxy/caching",
{
"type": "category",
"label": "Logging, Alerting",
"items": [
"proxy/logging",
"proxy/alerting",
"proxy/streaming_logging",
]
},
{
"type": "category",
"label": "Content Moderation",
"items": [
"proxy/call_hooks",
"proxy/rules",
]
},
"proxy/deploy",
"proxy/cli",
]
},
"proxy/custom_pricing", "proxy/custom_pricing",
"routing", "routing",
"rules", "rules",


@@ -0,0 +1,103 @@
# +------------------------------+
#
# Banned Keywords
#
# +------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
## Reject a call / response if it contains certain keywords
from typing import Optional, Literal
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException
import json, traceback
class _ENTERPRISE_BannedKeywords(CustomLogger):
# Class variables or attributes
def __init__(self):
banned_keywords_list = litellm.banned_keywords_list
if banned_keywords_list is None:
raise Exception(
"`banned_keywords_list` can either be a list or filepath. None set."
)
if isinstance(banned_keywords_list, list):
self.banned_keywords_list = banned_keywords_list
if isinstance(banned_keywords_list, str): # assume it's a filepath
try:
with open(banned_keywords_list, "r") as file:
data = file.read()
self.banned_keywords_list = data.split("\n")
except FileNotFoundError:
raise Exception(
f"File not found. banned_keywords_list={banned_keywords_list}"
)
except Exception as e:
raise Exception(
f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}"
)
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
if level == "INFO":
verbose_proxy_logger.info(print_statement)
elif level == "DEBUG":
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose is True:
print(print_statement) # noqa
def test_violation(self, test_str: str):
for word in self.banned_keywords_list:
if word in test_str.lower():
raise HTTPException(
status_code=400,
detail={"error": f"Keyword banned. Keyword={word}"},
)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
try:
"""
- check if user id part of call
- check if user id part of blocked list
"""
self.print_verbose(f"Inside Banned Keyword List Pre-Call Hook")
if call_type == "completion" and "messages" in data:
for m in data["messages"]:
if "content" in m and isinstance(m["content"], str):
self.test_violation(test_str=m["content"])
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()
async def async_post_call_success_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response,
):
if isinstance(response, litellm.ModelResponse) and isinstance(
response.choices[0], litellm.utils.Choices
):
for word in self.banned_keywords_list:
self.test_violation(test_str=response.choices[0].message.content)
async def async_post_call_streaming_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: str,
):
self.test_violation(test_str=response)


@@ -0,0 +1,80 @@
# +------------------------------+
#
# Blocked User List
#
# +------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
## This accepts a list of user id's for whom calls will be rejected
from typing import Optional, Literal
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException
import json, traceback
class _ENTERPRISE_BlockedUserList(CustomLogger):
# Class variables or attributes
def __init__(self):
blocked_user_list = litellm.blocked_user_list
if blocked_user_list is None:
raise Exception(
"`blocked_user_list` can either be a list or filepath. None set."
)
if isinstance(blocked_user_list, list):
self.blocked_user_list = blocked_user_list
if isinstance(blocked_user_list, str): # assume it's a filepath
try:
with open(blocked_user_list, "r") as file:
data = file.read()
self.blocked_user_list = data.split("\n")
except FileNotFoundError:
raise Exception(
f"File not found. blocked_user_list={blocked_user_list}"
)
except Exception as e:
raise Exception(
f"An error occurred: {str(e)}, blocked_user_list={blocked_user_list}"
)
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
if level == "INFO":
verbose_proxy_logger.info(print_statement)
elif level == "DEBUG":
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
"""
- check if user id part of call
- check if user id part of blocked list
"""
self.print_verbose(f"Inside Blocked User List Pre-Call Hook")
if "user_id" in data:
if data["user_id"] in self.blocked_user_list:
raise HTTPException(
status_code=400,
detail={
"error": f"User blocked from making LLM API Calls. User={data['user_id']}"
},
)
except HTTPException as e:
raise e
except Exception as e:
traceback.print_exc()


@@ -60,6 +60,8 @@ llamaguard_model_name: Optional[str] = None
presidio_ad_hoc_recognizers: Optional[str] = None
google_moderation_confidence_threshold: Optional[float] = None
llamaguard_unsafe_content_categories: Optional[str] = None
+blocked_user_list: Optional[Union[str, List]] = None
+banned_keywords_list: Optional[Union[str, List]] = None
##################
logging: bool = True
caching: bool = (
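These module-level settings are what the new enterprise hooks read on startup; an illustrative snippet (the values shown are placeholders):

```python
# Illustrative only: the new hooks read these module-level values.
import litellm

litellm.blocked_user_list = ["user_id_1", "user_id_2"]   # or a path to a .txt file
litellm.banned_keywords_list = ["secret_project_name"]   # or a path to a .txt file
```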


@@ -2,12 +2,11 @@
# On success, logs events to Promptlayer
import dotenv, os
import requests
-import requests
+from pydantic import BaseModel

dotenv.load_dotenv()  # Loading env variables using dotenv
import traceback


class PromptLayerLogger:
    # Class variables or attributes
    def __init__(self):
@@ -25,16 +24,30 @@ class PromptLayerLogger:
        for optional_param in kwargs["optional_params"]:
            new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
# Extract PromptLayer tags from metadata, if such exists
tags = []
metadata = {}
if "metadata" in kwargs["litellm_params"]:
if "pl_tags" in kwargs["litellm_params"]["metadata"]:
tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
# Remove "pl_tags" from metadata
metadata = {k:v for k, v in kwargs["litellm_params"]["metadata"].items() if k != "pl_tags"}
        print_verbose(
            f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
        )
# python-openai >= 1.0.0 returns Pydantic objects instead of jsons
if isinstance(response_obj, BaseModel):
response_obj = response_obj.model_dump()
        request_response = requests.post(
            "https://api.promptlayer.com/rest/track-request",
            json={
                "function_name": "openai.ChatCompletion.create",
                "kwargs": new_kwargs,
-               "tags": ["hello", "world"],
+               "tags": tags,
                "request_response": dict(response_obj),
                "request_start_time": int(start_time.timestamp()),
                "request_end_time": int(end_time.timestamp()),
@@ -45,22 +58,23 @@ class PromptLayerLogger:
                # "prompt_version":1,
            },
        )

+       response_json = request_response.json()
+       if not request_response.json().get("success", False):
+           raise Exception("Promptlayer did not successfully log the response!")

        print_verbose(
            f"Prompt Layer Logging: success - final response object: {request_response.text}"
        )
-       response_json = request_response.json()
-       if "success" not in request_response.json():
-           raise Exception("Promptlayer did not successfully log the response!")

        if "request_id" in response_json:
-           print(kwargs["litellm_params"]["metadata"])
-           if kwargs["litellm_params"]["metadata"] is not None:
+           if metadata:
                response = requests.post(
                    "https://api.promptlayer.com/rest/track-metadata",
                    json={
                        "request_id": response_json["request_id"],
                        "api_key": self.key,
-                       "metadata": kwargs["litellm_params"]["metadata"],
+                       "metadata": metadata,
                    },
                )
                print_verbose(
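With this change, PromptLayer tags come from request metadata via `pl_tags` instead of being hard-coded; a minimal usage sketch, assuming `PROMPTLAYER_API_KEY` is set in the environment (the tag values are placeholders):

```python
# Minimal sketch of the new pl_tags flow; requires PROMPTLAYER_API_KEY in the env.
import litellm

litellm.success_callback = ["promptlayer"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
    # "pl_tags" becomes the PromptLayer tags; the remaining keys are sent as metadata
    metadata={"pl_tags": ["env:dev"], "team": "search"},
    mock_response="this is a mock response",  # avoids a real LLM call
)
```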


@@ -559,8 +559,7 @@ def completion(
                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
            )
            response = llm_model.predict(
-               endpoint=endpoint_path,
-               instances=instances
+               endpoint=endpoint_path, instances=instances
            ).predictions

            completion_response = response[0]

@@ -585,12 +584,8 @@ def completion(
                    "request_str": request_str,
                },
            )
-           request_str += (
-               f"llm_model.predict(instances={instances})\n"
-           )
-           response = llm_model.predict(
-               instances=instances
-           ).predictions
+           request_str += f"llm_model.predict(instances={instances})\n"
+           response = llm_model.predict(instances=instances).predictions

            completion_response = response[0]

            if (
@@ -614,7 +609,6 @@ def completion(
            model_response["choices"][0]["message"]["content"] = str(
                completion_response
            )
-           model_response["choices"][0]["message"]["content"] = str(completion_response)
            model_response["created"] = int(time.time())
            model_response["model"] = model
            ## CALCULATING USAGE

@@ -766,6 +760,7 @@ async def async_completion(
    Vertex AI Model Garden
    """
    from google.cloud import aiplatform
+
    ## LOGGING
    logging_obj.pre_call(
        input=prompt,

@@ -797,11 +792,9 @@ async def async_completion(
            and "\nOutput:\n" in completion_response
        ):
            completion_response = completion_response.split("\nOutput:\n", 1)[1]
        elif mode == "private":
-           request_str += (
-               f"llm_model.predict_async(instances={instances})\n"
-           )
+           request_str += f"llm_model.predict_async(instances={instances})\n"
            response_obj = await llm_model.predict_async(
                instances=instances,
            )
@@ -826,7 +819,6 @@ async def async_completion(
        model_response["choices"][0]["message"]["content"] = str(
            completion_response
        )
-       model_response["choices"][0]["message"]["content"] = str(completion_response)
        model_response["created"] = int(time.time())
        model_response["model"] = model
        ## CALCULATING USAGE

@@ -954,6 +946,7 @@ async def async_streaming(
        response = llm_model.predict_streaming_async(prompt, **optional_params)
    elif mode == "custom":
        from google.cloud import aiplatform
+
        stream = optional_params.pop("stream", None)

        ## LOGGING

@@ -972,7 +965,9 @@ async def async_streaming(
        endpoint_path = llm_model.endpoint_path(
            project=vertex_project, location=vertex_location, endpoint=model
        )
-       request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+       request_str += (
+           f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+       )
        response_obj = await llm_model.predict(
            endpoint=endpoint_path,
            instances=instances,


@@ -12,7 +12,6 @@ from typing import Any, Literal, Union
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy

import httpx
import litellm
from ._logging import verbose_logger


@@ -424,6 +424,10 @@ class LiteLLM_VerificationToken(LiteLLMBase):
    model_spend: Dict = {}
    model_max_budget: Dict = {}

+   # hidden params used for parallel request limiting, not required to create a token
+   user_id_rate_limits: Optional[dict] = None
+   team_id_rate_limits: Optional[dict] = None

    class Config:
        protected_namespaces = ()
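The new `user_id_rate_limits` field is what the parallel request limiter (changed further down in this commit) reads for per-user tpm/rpm checks; an illustrative construction, mirroring the test added later in this diff:

```python
# Illustrative only: building an auth object that carries per-user rate limits.
from litellm.proxy._types import UserAPIKeyAuth

user_api_key_dict = UserAPIKeyAuth(
    api_key="sk-12345",
    user_id="ishaan",
    user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
)
```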

(Binary image file added, 16 KiB, not shown)


@@ -24,46 +24,21 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
        except:
            pass

-   async def async_pre_call_hook(
+   async def check_key_in_limits(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
+       max_parallel_requests: int,
+       tpm_limit: int,
+       rpm_limit: int,
+       request_count_api_key: str,
    ):
-       self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
-       api_key = user_api_key_dict.api_key
-       max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
-       tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
-       rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
-
-       if api_key is None:
-           return
-
-       if (
-           max_parallel_requests == sys.maxsize
-           and tpm_limit == sys.maxsize
-           and rpm_limit == sys.maxsize
-       ):
-           return
-
-       self.user_api_key_cache = cache  # save the api key cache for updating the value
-
-       # ------------
-       # Setup values
-       # ------------
-
-       current_date = datetime.now().strftime("%Y-%m-%d")
-       current_hour = datetime.now().strftime("%H")
-       current_minute = datetime.now().strftime("%M")
-       precise_minute = f"{current_date}-{current_hour}-{current_minute}"
-
-       request_count_api_key = f"{api_key}::{precise_minute}::request_count"
-
-       # CHECK IF REQUEST ALLOWED
        current = cache.get_cache(
            key=request_count_api_key
        )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
-       self.print_verbose(f"current: {current}")
+       # print(f"current: {current}")
        if current is None:
            new_val = {
                "current_requests": 1,

@@ -88,10 +63,107 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                status_code=429, detail="Max parallel request limit reached."
            )
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
if api_key is None:
return
self.user_api_key_cache = cache # save the api key cache for updating the value
# ------------
# Setup values
# ------------
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# CHECK IF REQUEST ALLOWED for key
current = cache.get_cache(
key=request_count_api_key
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
self.print_verbose(f"current: {current}")
if (
max_parallel_requests == sys.maxsize
and tpm_limit == sys.maxsize
and rpm_limit == sys.maxsize
):
pass
elif current is None:
new_val = {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 0,
}
cache.set_cache(request_count_api_key, new_val)
elif (
int(current["current_requests"]) < max_parallel_requests
and current["current_tpm"] < tpm_limit
and current["current_rpm"] < rpm_limit
):
# Increase count for this token
new_val = {
"current_requests": current["current_requests"] + 1,
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"],
}
cache.set_cache(request_count_api_key, new_val)
else:
raise HTTPException(
status_code=429, detail="Max parallel request limit reached."
)
# check if REQUEST ALLOWED for user_id
user_id = user_api_key_dict.user_id
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
# get user tpm/rpm limits
if _user_id_rate_limits is None or _user_id_rate_limits == {}:
return
user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
if user_tpm_limit is None:
user_tpm_limit = sys.maxsize
if user_rpm_limit is None:
user_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
request_count_api_key=request_count_api_key,
tpm_limit=user_tpm_limit,
rpm_limit=user_rpm_limit,
)
return
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")

            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+           user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
+               "user_api_key_user_id", None
+           )

            if user_api_key is None:
                return
@@ -121,7 +193,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            }

            # ------------
-           # Update usage
+           # Update usage - API Key
            # ------------

            new_val = {
@@ -136,6 +208,41 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            self.user_api_key_cache.set_cache(
                request_count_api_key, new_val, ttl=60
            )  # store in cache for 1 min.
# ------------
# Update usage - User
# ------------
if user_api_key_user_id is None:
return
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens
request_count_api_key = (
f"{user_api_key_user_id}::{precise_minute}::request_count"
)
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"] + 1,
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
self.user_api_key_cache.set_cache(
request_count_api_key, new_val, ttl=60
) # store in cache for 1 min.
        except Exception as e:
            self.print_verbose(e)  # noqa


@@ -1479,6 +1479,26 @@ class ProxyConfig:
                        llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
                        imported_list.append(llm_guard_moderation_obj)
elif (
isinstance(callback, str)
and callback == "blocked_user_check"
):
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
_ENTERPRISE_BlockedUserList,
)
blocked_user_list = _ENTERPRISE_BlockedUserList()
imported_list.append(blocked_user_list)
elif (
isinstance(callback, str)
and callback == "banned_keywords"
):
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
_ENTERPRISE_BannedKeywords,
)
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
imported_list.append(banned_keywords_obj)
                    else:
                        imported_list.append(
                            get_instance_fn(
@@ -4368,7 +4388,20 @@ async def update_team(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
-   add new members to the team
+   You can now add / delete users from a team via /team/update
```
curl --location 'http://0.0.0.0:8000/team/update' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data-raw '{
"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849",
"members_with_roles": [{"role": "admin", "user_id": "5c4a0aa3-a1e1-43dc-bd87-3c2da8382a3a"}, {"role": "user", "user_id": "krrish247652@berri.ai"}]
}'
```
""" """
global prisma_client global prisma_client
@@ -4449,6 +4482,18 @@ async def delete_team(
):
    """
    delete team and associated team keys
```
curl --location 'http://0.0.0.0:8000/team/delete' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data-raw '{
"team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"]
}'
```
""" """
global prisma_client global prisma_client
@@ -5097,7 +5142,15 @@ async def google_login(request: Request):
            scope=generic_scope,
        )
        with generic_sso:
-           return await generic_sso.get_login_redirect()
+           # TODO: state should be a random string and added to the user session with cookie
# or a cryptographicly signed state that we can verify stateless
# For simplification we are using a static state, this is not perfect but some
# SSO providers do not allow stateless verification
redirect_params = {}
state = os.getenv("GENERIC_CLIENT_STATE", None)
if state:
redirect_params["state"] = state
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
    elif ui_username is not None:
        # No Google, Microsoft SSO
        # Use UI Credentials set in .env
@@ -5203,7 +5256,25 @@ def get_image():
    logo_path = os.getenv("UI_LOGO_PATH", default_logo)
    verbose_proxy_logger.debug(f"Reading logo from {logo_path}")

-   return FileResponse(path=logo_path)
# Check if the logo path is an HTTP/HTTPS URL
if logo_path.startswith(("http://", "https://")):
# Download the image and cache it
response = requests.get(logo_path)
if response.status_code == 200:
# Save the image to a local file
cache_path = os.path.join(current_dir, "cached_logo.jpg")
with open(cache_path, "wb") as f:
f.write(response.content)
# Return the cached image as a FileResponse
return FileResponse(cache_path, media_type="image/jpeg")
else:
# Handle the case when the image cannot be downloaded
return FileResponse(default_logo, media_type="image/jpeg")
else:
# Return the local image file if the logo path is not an HTTP/HTTPS URL
return FileResponse(logo_path, media_type="image/jpeg")
@app.get("/sso/callback", tags=["experimental"]) @app.get("/sso/callback", tags=["experimental"])
@@ -5265,7 +5336,7 @@ async def auth_callback(request: Request):
        result = await microsoft_sso.verify_and_process(request)
    elif generic_client_id is not None:
        # make generic sso provider
-       from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
+       from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID

        generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
        generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@@ -5274,6 +5345,9 @@ async def auth_callback(request: Request):
        )
        generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
        generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
+       generic_include_client_id = (
+           os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
+       )
        if generic_client_secret is None:
            raise ProxyException(
                message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
@@ -5308,12 +5382,50 @@ async def auth_callback(request: Request):
        verbose_proxy_logger.debug(
            f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
        )
generic_user_id_attribute_name = os.getenv(
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
)
generic_user_display_name_attribute_name = os.getenv(
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
)
generic_user_email_attribute_name = os.getenv(
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
)
generic_user_role_attribute_name = os.getenv(
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
)
generic_user_first_name_attribute_name = os.getenv(
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
)
generic_user_last_name_attribute_name = os.getenv(
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
)
verbose_proxy_logger.debug(
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
)
        discovery = DiscoveryDocument(
            authorization_endpoint=generic_authorization_endpoint,
            token_endpoint=generic_token_endpoint,
            userinfo_endpoint=generic_userinfo_endpoint,
        )

-       SSOProvider = create_provider(name="oidc", discovery_document=discovery)
def response_convertor(response, client):
return OpenID(
id=response.get(generic_user_id_attribute_name),
display_name=response.get(generic_user_display_name_attribute_name),
email=response.get(generic_user_email_attribute_name),
first_name=response.get(generic_user_first_name_attribute_name),
last_name=response.get(generic_user_last_name_attribute_name),
)
SSOProvider = create_provider(
name="oidc",
discovery_document=discovery,
response_convertor=response_convertor,
)
        generic_sso = SSOProvider(
            client_id=generic_client_id,
            client_secret=generic_client_secret,
@@ -5322,43 +5434,36 @@ async def auth_callback(request: Request):
            scope=generic_scope,
        )
        verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
-       request_body = await request.body()
-       request_query_params = request.query_params
-       # get "code" from query params
-       code = request_query_params.get("code")
-       result = await generic_sso.verify_and_process(request)
+       result = await generic_sso.verify_and_process(
+           request, params={"include_client_id": generic_include_client_id}
+       )
        verbose_proxy_logger.debug(f"generic result: {result}")

    # User is Authe'd in - generate key for the UI to access Proxy
    user_email = getattr(result, "email", None)
    user_id = getattr(result, "id", None)

    # generic client id
    if generic_client_id is not None:
-       generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
-       generic_user_email_attribute_name = os.getenv(
-           "GENERIC_USER_EMAIL_ATTRIBUTE", "email"
-       )
-       generic_user_role_attribute_name = os.getenv(
-           "GENERIC_USER_ROLE_ATTRIBUTE", "role"
-       )
-       verbose_proxy_logger.debug(
-           f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
-       )
-       user_id = getattr(result, generic_user_id_attribute_name, None)
-       user_email = getattr(result, generic_user_email_attribute_name, None)
+       user_id = getattr(result, "id", None)
+       user_email = getattr(result, "email", None)
        user_role = getattr(result, generic_user_role_attribute_name, None)

    if user_id is None:
        user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")

-   # get user_info from litellm DB
    user_info = None
-   if prisma_client is not None:
-       user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
    user_id_models: List = []
-   if user_info is not None:
-       user_id_models = getattr(user_info, "models", [])
+   # User might not be already created on first generation of key
+   # But if it is, we want its models preferences
+   try:
+       if prisma_client is not None:
+           user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
+       if user_info is not None:
+           user_id_models = getattr(user_info, "models", [])
+   except Exception as e:
+       pass

    response = await generate_key_helper_fn(
        **{


@@ -318,7 +318,7 @@ def test_gemini_pro_vision():
# test_gemini_pro_vision()


-def gemini_pro_function_calling():
+def test_gemini_pro_function_calling():
    load_vertex_ai_credentials()
    tools = [
        {
@@ -345,12 +345,15 @@ def gemini_pro_function_calling():
        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
    )
    print(f"completion: {completion}")
+   assert completion.choices[0].message.content is None
+   assert len(completion.choices[0].message.tool_calls) == 1


# gemini_pro_function_calling()
-async def gemini_pro_async_function_calling():
+@pytest.mark.asyncio
+async def test_gemini_pro_async_function_calling():
    load_vertex_ai_credentials()
    tools = [
        {
@@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
        model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
    )
    print(f"completion: {completion}")
assert completion.choices[0].message.content is None
assert len(completion.choices[0].message.tool_calls) == 1
# raise Exception("it worked!")
# asyncio.run(gemini_pro_async_function_calling())


@@ -0,0 +1,63 @@
# What is this?
## This tests the banned keywords pre call hook for the proxy server
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
_ENTERPRISE_BannedKeywords,
)
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@pytest.mark.asyncio
async def test_banned_keywords_check():
"""
- Set some banned keywords as a litellm module value
    - Test that a call containing banned keywords raises an error
    - Test that a call without banned keywords passes
"""
litellm.banned_keywords_list = ["hello"]
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
_api_key = "sk-12345"
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()
    ## Case 1: banned keyword in the message content
try:
await banned_keywords_obj.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
call_type="completion",
data={"messages": [{"role": "user", "content": "Hello world"}]},
)
pytest.fail(f"Expected call to fail")
except Exception as e:
pass
    ## Case 2: no banned keywords in the message content
try:
await banned_keywords_obj.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
call_type="completion",
data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]},
)
except Exception as e:
pytest.fail(f"An error occurred - {str(e)}")


@@ -0,0 +1,63 @@
# What is this?
## This tests the blocked user pre call hook for the proxy server
import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
_ENTERPRISE_BlockedUserList,
)
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
@pytest.mark.asyncio
async def test_block_user_check():
"""
- Set a blocked user as a litellm module value
    - Test that a call with that user id raises an error
    - Test that a call without that user id passes
"""
litellm.blocked_user_list = ["user_id_1"]
blocked_user_obj = _ENTERPRISE_BlockedUserList()
_api_key = "sk-12345"
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
local_cache = DualCache()
## Case 1: blocked user id passed
try:
await blocked_user_obj.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
call_type="completion",
data={"user_id": "user_id_1"},
)
pytest.fail(f"Expected call to fail")
except Exception as e:
pass
## Case 2: normal user id passed
try:
await blocked_user_obj.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
call_type="completion",
data={"user_id": "user_id_2"},
)
except Exception as e:
pytest.fail(f"An error occurred - {str(e)}")


@@ -139,6 +139,56 @@ async def test_pre_call_hook_tpm_limits():
        assert e.status_code == 429
@pytest.mark.asyncio
async def test_pre_call_hook_user_tpm_limits():
"""
Test if error raised on hitting tpm limits
"""
# create user with tpm/rpm limits
_api_key = "sk-12345"
user_api_key_dict = UserAPIKeyAuth(
api_key=_api_key,
user_id="ishaan",
user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
)
res = dict(user_api_key_dict)
print("dict user", res)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
)
kwargs = {
"litellm_params": {
"metadata": {"user_api_key_user_id": "ishaan", "user_api_key": "gm"}
}
}
await parallel_request_handler.async_log_success_event(
kwargs=kwargs,
response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=10)),
start_time="",
end_time="",
)
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
try:
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
data={},
call_type="",
)
pytest.fail(f"Expected call to fail")
except Exception as e:
assert e.status_code == 429
@pytest.mark.asyncio
async def test_success_call_hook():
    """


@@ -7,10 +7,9 @@ sys.path.insert(0, os.path.abspath("../.."))
from litellm import completion
import litellm

-litellm.success_callback = ["promptlayer"]
-litellm.set_verbose = True
+import pytest

import time
# def test_promptlayer_logging():
#     try:

@@ -39,11 +38,16 @@ import time
# test_promptlayer_logging()
@pytest.mark.skip(
reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
)
def test_promptlayer_logging_with_metadata():
    try:
        # Redirect stdout
        old_stdout = sys.stdout
        sys.stdout = new_stdout = io.StringIO()
litellm.set_verbose = True
litellm.success_callback = ["promptlayer"]
        response = completion(
            model="gpt-3.5-turbo",
@@ -58,15 +62,43 @@ def test_promptlayer_logging_with_metadata():
        sys.stdout = old_stdout
        output = new_stdout.getvalue().strip()
        print(output)
-       if "LiteLLM: Prompt Layer Logging: success" not in output:
-           raise Exception("Required log message not found!")
+
+       assert "Prompt Layer Logging: success" in output
    except Exception as e:
-       print(e)
+       pytest.fail(f"Error occurred: {e}")

-# test_promptlayer_logging_with_metadata()

+@pytest.mark.skip(
+   reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
+)
def test_promptlayer_logging_with_metadata_tags():
try:
# Redirect stdout
litellm.set_verbose = True
litellm.success_callback = ["promptlayer"]
old_stdout = sys.stdout
sys.stdout = new_stdout = io.StringIO()
response = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm ai21"}],
temperature=0.2,
max_tokens=20,
metadata={"model": "ai21", "pl_tags": ["env:dev"]},
mock_response="this is a mock response",
)
# Restore stdout
time.sleep(1)
sys.stdout = old_stdout
output = new_stdout.getvalue().strip()
print(output)
assert "Prompt Layer Logging: success" in output
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# def test_chat_openai():
#     try:


@@ -393,6 +393,8 @@ def test_completion_palm_stream():
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
+   except litellm.Timeout as e:
+       pass
    except litellm.APIError as e:
        pass
    except Exception as e:


@@ -4277,8 +4277,8 @@ def get_optional_params(
            optional_params["stop_sequences"] = stop
        if max_tokens is not None:
            optional_params["max_output_tokens"] = max_tokens
-   elif custom_llm_provider == "vertex_ai" and model in (
-       litellm.vertex_chat_models
+   elif custom_llm_provider == "vertex_ai" and (
+       model in litellm.vertex_chat_models
        or model in litellm.vertex_code_chat_models
        or model in litellm.vertex_text_models
        or model in litellm.vertex_code_text_models
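The old condition relied on `or` inside the parentheses, which short-circuits to the first non-empty list, so membership in the later model lists was never checked; a small illustration with placeholder lists:

```python
# Why the old condition was wrong: `x in (a or b)` reduces to `x in a`
# whenever `a` is non-empty, so the later lists are never consulted.
chat_models = ["chat-bison"]
code_chat_models = ["codechat-bison"]

model = "codechat-bison"
print(model in (chat_models or code_chat_models))          # False - the old, buggy form
print(model in chat_models or model in code_chat_models)   # True  - the fixed form
```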
@@ -6827,6 +6827,14 @@ def exception_type(
                llm_provider="palm",
                response=original_exception.response,
            )
if "504 Deadline expired before operation could complete." in error_str:
exception_mapping_worked = True
raise Timeout(
message=f"PalmException - {original_exception.message}",
model=model,
llm_provider="palm",
request=original_exception.request,
)
if "400 Request payload size exceeds" in error_str: if "400 Request payload size exceeds" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise ContextWindowExceededError( raise ContextWindowExceededError(
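With the new mapping, a caller can distinguish PaLM deadline errors from other API errors; a hedged sketch of the calling side (the model name is a placeholder and a PaLM key is assumed to be configured):

```python
# Hedged sketch: the 504 "Deadline expired" error now surfaces as litellm.Timeout.
import litellm

try:
    litellm.completion(
        model="palm/chat-bison",
        messages=[{"role": "user", "content": "hello"}],
    )
except litellm.Timeout:
    # previously this would have fallen through to a generic exception
    print("palm request timed out - retry or fall back")
except litellm.APIError as e:
    print("other palm error:", e)
```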


@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "1.26.8"
+version = "1.26.10"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"

@@ -74,7 +74,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
-version = "1.26.8"
+version = "1.26.10"
version_files = [
    "pyproject.toml:^version"
]