mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge branch 'main' into litellm_fix_azure_function_calling_streaming
This commit is contained in:
commit
dd4439b6a8
23 changed files with 869 additions and 173 deletions
|
@ -238,9 +238,11 @@ chat_completion = client.chat.completions.create(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
cache={
|
extra_body = { # OpenAI python accepts extra args in extra_body
|
||||||
"no-cache": True # will not return a cached response
|
cache: {
|
||||||
}
|
"no-cache": True # will not return a cached response
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -264,9 +266,11 @@ chat_completion = client.chat.completions.create(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
cache={
|
extra_body = { # OpenAI python accepts extra args in extra_body
|
||||||
"ttl": 600 # caches response for 10 minutes
|
cache: {
|
||||||
}
|
"ttl": 600 # caches response for 10 minutes
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -288,13 +292,15 @@ chat_completion = client.chat.completions.create(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
cache={
|
extra_body = { # OpenAI python accepts extra args in extra_body
|
||||||
"s-maxage": 600 # only get responses cached within last 10 minutes
|
cache: {
|
||||||
}
|
"s-maxage": 600 # only get responses cached within last 10 minutes
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Supported `cache_params`
|
## Supported `cache_params` on proxy config.yaml
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
cache_params:
|
cache_params:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import Tabs from '@theme/Tabs';
|
import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# ✨ Enterprise Features - Content Moderation
|
# ✨ Enterprise Features - Content Moderation, Blocked Users
|
||||||
|
|
||||||
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
|
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ Features:
|
||||||
- [ ] Content Moderation with LlamaGuard
|
- [ ] Content Moderation with LlamaGuard
|
||||||
- [ ] Content Moderation with Google Text Moderations
|
- [ ] Content Moderation with Google Text Moderations
|
||||||
- [ ] Content Moderation with LLM Guard
|
- [ ] Content Moderation with LLM Guard
|
||||||
|
- [ ] Reject calls from Blocked User list
|
||||||
- [ ] Tracking Spend for Custom Tags
|
- [ ] Tracking Spend for Custom Tags
|
||||||
|
|
||||||
## Content Moderation with LlamaGuard
|
## Content Moderation with LlamaGuard
|
||||||
|
@ -132,6 +133,39 @@ Here are the category specific values:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Enable Blocked User Lists
|
||||||
|
If any call is made to proxy with this user id, it'll be rejected - use this if you want to let users opt-out of ai features
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
litellm_settings:
|
||||||
|
callbacks: ["blocked_user_check"]
|
||||||
|
blocked_user_id_list: ["user_id_1", "user_id_2", ...] # can also be a .txt filepath e.g. `/relative/path/blocked_list.txt`
|
||||||
|
```
|
||||||
|
|
||||||
|
### How to test
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data ' {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what llm are you"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"user_id": "user_id_1" # this is also an openai supported param
|
||||||
|
}
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
[Suggest a way to improve this](https://github.com/BerriAI/litellm/issues/new/choose)
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
## Tracking Spend for Custom Tags
|
## Tracking Spend for Custom Tags
|
||||||
|
|
||||||
Requirements:
|
Requirements:
|
||||||
|
|
|
@ -133,8 +133,12 @@ The following can be used to customize attribute names when interacting with the
|
||||||
```shell
|
```shell
|
||||||
GENERIC_USER_ID_ATTRIBUTE = "given_name"
|
GENERIC_USER_ID_ATTRIBUTE = "given_name"
|
||||||
GENERIC_USER_EMAIL_ATTRIBUTE = "family_name"
|
GENERIC_USER_EMAIL_ATTRIBUTE = "family_name"
|
||||||
|
GENERIC_USER_DISPLAY_NAME_ATTRIBUTE = "display_name"
|
||||||
|
GENERIC_USER_FIRST_NAME_ATTRIBUTE = "first_name"
|
||||||
|
GENERIC_USER_LAST_NAME_ATTRIBUTE = "last_name"
|
||||||
GENERIC_USER_ROLE_ATTRIBUTE = "given_role"
|
GENERIC_USER_ROLE_ATTRIBUTE = "given_role"
|
||||||
|
GENERIC_CLIENT_STATE = "some-state" # if the provider needs a state parameter
|
||||||
|
GENERIC_INCLUDE_CLIENT_ID = "false" # some providers enforce that the client_id is not in the body
|
||||||
GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not enough to retrieve basic user info like first_name and last_name located in profile scope
|
GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not enough to retrieve basic user info like first_name and last_name located in profile scope
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -148,7 +152,14 @@ GENERIC_SCOPE = "openid profile email" # default scope openid is sometimes not e
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
#### Step 3. Test flow
|
#### Step 3. Set `PROXY_BASE_URL` in your .env
|
||||||
|
|
||||||
|
Set this in your .env (so the proxy can set the correct redirect url)
|
||||||
|
```shell
|
||||||
|
PROXY_BASE_URL=https://litellm-api.up.railway.app/
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 4. Test flow
|
||||||
<Image img={require('../../img/litellm_ui_3.gif')} />
|
<Image img={require('../../img/litellm_ui_3.gif')} />
|
||||||
|
|
||||||
### Set Admin view w/ SSO
|
### Set Admin view w/ SSO
|
||||||
|
@ -183,7 +194,21 @@ We allow you to
|
||||||
- Customize the UI color scheme
|
- Customize the UI color scheme
|
||||||
<Image img={require('../../img/litellm_custom_ai.png')} />
|
<Image img={require('../../img/litellm_custom_ai.png')} />
|
||||||
|
|
||||||
#### Usage
|
#### Set Custom Logo
|
||||||
|
We allow you to pass a local image or a an http/https url of your image
|
||||||
|
|
||||||
|
Set `UI_LOGO_PATH` on your env. We recommend using a hosted image, it's a lot easier to set up and configure / debug
|
||||||
|
|
||||||
|
Exaple setting Hosted image
|
||||||
|
```shell
|
||||||
|
UI_LOGO_PATH="https://litellm-logo-aws-marketplace.s3.us-west-2.amazonaws.com/berriai-logo-github.png"
|
||||||
|
```
|
||||||
|
|
||||||
|
Exaple setting a local image (on your container)
|
||||||
|
```shell
|
||||||
|
UI_LOGO_PATH="ui_images/logo.jpg"
|
||||||
|
```
|
||||||
|
#### Set Custom Color Theme
|
||||||
- Navigate to [/enterprise/enterprise_ui](https://github.com/BerriAI/litellm/blob/main/enterprise/enterprise_ui/_enterprise_colors.json)
|
- Navigate to [/enterprise/enterprise_ui](https://github.com/BerriAI/litellm/blob/main/enterprise/enterprise_ui/_enterprise_colors.json)
|
||||||
- Inside the `enterprise_ui` directory, rename `_enterprise_colors.json` to `enterprise_colors.json`
|
- Inside the `enterprise_ui` directory, rename `_enterprise_colors.json` to `enterprise_colors.json`
|
||||||
- Set your companies custom color scheme in `enterprise_colors.json`
|
- Set your companies custom color scheme in `enterprise_colors.json`
|
||||||
|
@ -202,8 +227,6 @@ Set your colors to any of the following colors: https://www.tremor.so/docs/layou
|
||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- Set the path to your custom png/jpg logo as `UI_LOGO_PATH` in your .env
|
|
||||||
- Deploy LiteLLM Proxy Server
|
- Deploy LiteLLM Proxy Server
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -279,9 +279,9 @@ curl 'http://0.0.0.0:8000/key/generate' \
|
||||||
## Set Rate Limits
|
## Set Rate Limits
|
||||||
|
|
||||||
You can set:
|
You can set:
|
||||||
|
- tpm limits (tokens per minute)
|
||||||
|
- rpm limits (requests per minute)
|
||||||
- max parallel requests
|
- max parallel requests
|
||||||
- tpm limits
|
|
||||||
- rpm limits
|
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="per-user" label="Per User">
|
<TabItem value="per-user" label="Per User">
|
||||||
|
|
|
@ -18,6 +18,62 @@ const sidebars = {
|
||||||
// But you can create a sidebar manually
|
// But you can create a sidebar manually
|
||||||
tutorialSidebar: [
|
tutorialSidebar: [
|
||||||
{ type: "doc", id: "index" }, // NEW
|
{ type: "doc", id: "index" }, // NEW
|
||||||
|
{
|
||||||
|
type: "category",
|
||||||
|
label: "💥 OpenAI Proxy Server",
|
||||||
|
link: {
|
||||||
|
type: 'generated-index',
|
||||||
|
title: '💥 OpenAI Proxy Server',
|
||||||
|
description: `Proxy Server to call 100+ LLMs in a unified interface & track spend, set budgets per virtual key/user`,
|
||||||
|
slug: '/simple_proxy',
|
||||||
|
},
|
||||||
|
items: [
|
||||||
|
"proxy/quick_start",
|
||||||
|
"proxy/configs",
|
||||||
|
{
|
||||||
|
type: 'link',
|
||||||
|
label: '📖 All Endpoints',
|
||||||
|
href: 'https://litellm-api.up.railway.app/',
|
||||||
|
},
|
||||||
|
"proxy/enterprise",
|
||||||
|
"proxy/user_keys",
|
||||||
|
"proxy/virtual_keys",
|
||||||
|
"proxy/users",
|
||||||
|
"proxy/ui",
|
||||||
|
"proxy/model_management",
|
||||||
|
"proxy/health",
|
||||||
|
"proxy/debugging",
|
||||||
|
"proxy/pii_masking",
|
||||||
|
{
|
||||||
|
"type": "category",
|
||||||
|
"label": "🔥 Load Balancing",
|
||||||
|
"items": [
|
||||||
|
"proxy/load_balancing",
|
||||||
|
"proxy/reliability",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"proxy/caching",
|
||||||
|
{
|
||||||
|
"type": "category",
|
||||||
|
"label": "Logging, Alerting",
|
||||||
|
"items": [
|
||||||
|
"proxy/logging",
|
||||||
|
"proxy/alerting",
|
||||||
|
"proxy/streaming_logging",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "category",
|
||||||
|
"label": "Content Moderation",
|
||||||
|
"items": [
|
||||||
|
"proxy/call_hooks",
|
||||||
|
"proxy/rules",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"proxy/deploy",
|
||||||
|
"proxy/cli",
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
label: "Completion()",
|
label: "Completion()",
|
||||||
|
@ -92,62 +148,6 @@ const sidebars = {
|
||||||
"providers/petals",
|
"providers/petals",
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
type: "category",
|
|
||||||
label: "💥 OpenAI Proxy Server",
|
|
||||||
link: {
|
|
||||||
type: 'generated-index',
|
|
||||||
title: '💥 OpenAI Proxy Server',
|
|
||||||
description: `Proxy Server to call 100+ LLMs in a unified interface & track spend, set budgets per virtual key/user`,
|
|
||||||
slug: '/simple_proxy',
|
|
||||||
},
|
|
||||||
items: [
|
|
||||||
"proxy/quick_start",
|
|
||||||
"proxy/configs",
|
|
||||||
{
|
|
||||||
type: 'link',
|
|
||||||
label: '📖 All Endpoints',
|
|
||||||
href: 'https://litellm-api.up.railway.app/',
|
|
||||||
},
|
|
||||||
"proxy/enterprise",
|
|
||||||
"proxy/user_keys",
|
|
||||||
"proxy/virtual_keys",
|
|
||||||
"proxy/users",
|
|
||||||
"proxy/ui",
|
|
||||||
"proxy/model_management",
|
|
||||||
"proxy/health",
|
|
||||||
"proxy/debugging",
|
|
||||||
"proxy/pii_masking",
|
|
||||||
{
|
|
||||||
"type": "category",
|
|
||||||
"label": "🔥 Load Balancing",
|
|
||||||
"items": [
|
|
||||||
"proxy/load_balancing",
|
|
||||||
"proxy/reliability",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"proxy/caching",
|
|
||||||
{
|
|
||||||
"type": "category",
|
|
||||||
"label": "Logging, Alerting",
|
|
||||||
"items": [
|
|
||||||
"proxy/logging",
|
|
||||||
"proxy/alerting",
|
|
||||||
"proxy/streaming_logging",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "category",
|
|
||||||
"label": "Content Moderation",
|
|
||||||
"items": [
|
|
||||||
"proxy/call_hooks",
|
|
||||||
"proxy/rules",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"proxy/deploy",
|
|
||||||
"proxy/cli",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"proxy/custom_pricing",
|
"proxy/custom_pricing",
|
||||||
"routing",
|
"routing",
|
||||||
"rules",
|
"rules",
|
||||||
|
|
103
enterprise/enterprise_hooks/banned_keywords.py
Normal file
103
enterprise/enterprise_hooks/banned_keywords.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
# +------------------------------+
|
||||||
|
#
|
||||||
|
# Banned Keywords
|
||||||
|
#
|
||||||
|
# +------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
## Reject a call / response if it contains certain keywords
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, Literal
|
||||||
|
import litellm
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import json, traceback
|
||||||
|
|
||||||
|
|
||||||
|
class _ENTERPRISE_BannedKeywords(CustomLogger):
|
||||||
|
# Class variables or attributes
|
||||||
|
def __init__(self):
|
||||||
|
banned_keywords_list = litellm.banned_keywords_list
|
||||||
|
|
||||||
|
if banned_keywords_list is None:
|
||||||
|
raise Exception(
|
||||||
|
"`banned_keywords_list` can either be a list or filepath. None set."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(banned_keywords_list, list):
|
||||||
|
self.banned_keywords_list = banned_keywords_list
|
||||||
|
|
||||||
|
if isinstance(banned_keywords_list, str): # assume it's a filepath
|
||||||
|
try:
|
||||||
|
with open(banned_keywords_list, "r") as file:
|
||||||
|
data = file.read()
|
||||||
|
self.banned_keywords_list = data.split("\n")
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise Exception(
|
||||||
|
f"File not found. banned_keywords_list={banned_keywords_list}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"An error occurred: {str(e)}, banned_keywords_list={banned_keywords_list}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
|
||||||
|
if level == "INFO":
|
||||||
|
verbose_proxy_logger.info(print_statement)
|
||||||
|
elif level == "DEBUG":
|
||||||
|
verbose_proxy_logger.debug(print_statement)
|
||||||
|
|
||||||
|
if litellm.set_verbose is True:
|
||||||
|
print(print_statement) # noqa
|
||||||
|
|
||||||
|
def test_violation(self, test_str: str):
|
||||||
|
for word in self.banned_keywords_list:
|
||||||
|
if word in test_str.lower():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={"error": f"Keyword banned. Keyword={word}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_pre_call_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
cache: DualCache,
|
||||||
|
data: dict,
|
||||||
|
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
"""
|
||||||
|
- check if user id part of call
|
||||||
|
- check if user id part of blocked list
|
||||||
|
"""
|
||||||
|
self.print_verbose(f"Inside Banned Keyword List Pre-Call Hook")
|
||||||
|
if call_type == "completion" and "messages" in data:
|
||||||
|
for m in data["messages"]:
|
||||||
|
if "content" in m and isinstance(m["content"], str):
|
||||||
|
self.test_violation(test_str=m["content"])
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def async_post_call_success_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response,
|
||||||
|
):
|
||||||
|
if isinstance(response, litellm.ModelResponse) and isinstance(
|
||||||
|
response.choices[0], litellm.utils.Choices
|
||||||
|
):
|
||||||
|
for word in self.banned_keywords_list:
|
||||||
|
self.test_violation(test_str=response.choices[0].message.content)
|
||||||
|
|
||||||
|
async def async_post_call_streaming_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response: str,
|
||||||
|
):
|
||||||
|
self.test_violation(test_str=response)
|
80
enterprise/enterprise_hooks/blocked_user_list.py
Normal file
80
enterprise/enterprise_hooks/blocked_user_list.py
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
# +------------------------------+
|
||||||
|
#
|
||||||
|
# Blocked User List
|
||||||
|
#
|
||||||
|
# +------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
## This accepts a list of user id's for whom calls will be rejected
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, Literal
|
||||||
|
import litellm
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import json, traceback
|
||||||
|
|
||||||
|
|
||||||
|
class _ENTERPRISE_BlockedUserList(CustomLogger):
|
||||||
|
# Class variables or attributes
|
||||||
|
def __init__(self):
|
||||||
|
blocked_user_list = litellm.blocked_user_list
|
||||||
|
|
||||||
|
if blocked_user_list is None:
|
||||||
|
raise Exception(
|
||||||
|
"`blocked_user_list` can either be a list or filepath. None set."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(blocked_user_list, list):
|
||||||
|
self.blocked_user_list = blocked_user_list
|
||||||
|
|
||||||
|
if isinstance(blocked_user_list, str): # assume it's a filepath
|
||||||
|
try:
|
||||||
|
with open(blocked_user_list, "r") as file:
|
||||||
|
data = file.read()
|
||||||
|
self.blocked_user_list = data.split("\n")
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise Exception(
|
||||||
|
f"File not found. blocked_user_list={blocked_user_list}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"An error occurred: {str(e)}, blocked_user_list={blocked_user_list}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
|
||||||
|
if level == "INFO":
|
||||||
|
verbose_proxy_logger.info(print_statement)
|
||||||
|
elif level == "DEBUG":
|
||||||
|
verbose_proxy_logger.debug(print_statement)
|
||||||
|
|
||||||
|
if litellm.set_verbose is True:
|
||||||
|
print(print_statement) # noqa
|
||||||
|
|
||||||
|
async def async_pre_call_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
cache: DualCache,
|
||||||
|
data: dict,
|
||||||
|
call_type: str,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
"""
|
||||||
|
- check if user id part of call
|
||||||
|
- check if user id part of blocked list
|
||||||
|
"""
|
||||||
|
self.print_verbose(f"Inside Blocked User List Pre-Call Hook")
|
||||||
|
if "user_id" in data:
|
||||||
|
if data["user_id"] in self.blocked_user_list:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": f"User blocked from making LLM API Calls. User={data['user_id']}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
|
@ -60,6 +60,8 @@ llamaguard_model_name: Optional[str] = None
|
||||||
presidio_ad_hoc_recognizers: Optional[str] = None
|
presidio_ad_hoc_recognizers: Optional[str] = None
|
||||||
google_moderation_confidence_threshold: Optional[float] = None
|
google_moderation_confidence_threshold: Optional[float] = None
|
||||||
llamaguard_unsafe_content_categories: Optional[str] = None
|
llamaguard_unsafe_content_categories: Optional[str] = None
|
||||||
|
blocked_user_list: Optional[Union[str, List]] = None
|
||||||
|
banned_keywords_list: Optional[Union[str, List]] = None
|
||||||
##################
|
##################
|
||||||
logging: bool = True
|
logging: bool = True
|
||||||
caching: bool = (
|
caching: bool = (
|
||||||
|
|
|
@ -2,12 +2,11 @@
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests
|
||||||
import requests
|
from pydantic import BaseModel
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
class PromptLayerLogger:
|
class PromptLayerLogger:
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -25,16 +24,30 @@ class PromptLayerLogger:
|
||||||
for optional_param in kwargs["optional_params"]:
|
for optional_param in kwargs["optional_params"]:
|
||||||
new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
|
new_kwargs[optional_param] = kwargs["optional_params"][optional_param]
|
||||||
|
|
||||||
|
# Extract PromptLayer tags from metadata, if such exists
|
||||||
|
tags = []
|
||||||
|
metadata = {}
|
||||||
|
if "metadata" in kwargs["litellm_params"]:
|
||||||
|
if "pl_tags" in kwargs["litellm_params"]["metadata"]:
|
||||||
|
tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
|
||||||
|
|
||||||
|
# Remove "pl_tags" from metadata
|
||||||
|
metadata = {k:v for k, v in kwargs["litellm_params"]["metadata"].items() if k != "pl_tags"}
|
||||||
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
|
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# python-openai >= 1.0.0 returns Pydantic objects instead of jsons
|
||||||
|
if isinstance(response_obj, BaseModel):
|
||||||
|
response_obj = response_obj.model_dump()
|
||||||
|
|
||||||
request_response = requests.post(
|
request_response = requests.post(
|
||||||
"https://api.promptlayer.com/rest/track-request",
|
"https://api.promptlayer.com/rest/track-request",
|
||||||
json={
|
json={
|
||||||
"function_name": "openai.ChatCompletion.create",
|
"function_name": "openai.ChatCompletion.create",
|
||||||
"kwargs": new_kwargs,
|
"kwargs": new_kwargs,
|
||||||
"tags": ["hello", "world"],
|
"tags": tags,
|
||||||
"request_response": dict(response_obj),
|
"request_response": dict(response_obj),
|
||||||
"request_start_time": int(start_time.timestamp()),
|
"request_start_time": int(start_time.timestamp()),
|
||||||
"request_end_time": int(end_time.timestamp()),
|
"request_end_time": int(end_time.timestamp()),
|
||||||
|
@ -45,22 +58,23 @@ class PromptLayerLogger:
|
||||||
# "prompt_version":1,
|
# "prompt_version":1,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response_json = request_response.json()
|
||||||
|
if not request_response.json().get("success", False):
|
||||||
|
raise Exception("Promptlayer did not successfully log the response!")
|
||||||
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Prompt Layer Logging: success - final response object: {request_response.text}"
|
f"Prompt Layer Logging: success - final response object: {request_response.text}"
|
||||||
)
|
)
|
||||||
response_json = request_response.json()
|
|
||||||
if "success" not in request_response.json():
|
|
||||||
raise Exception("Promptlayer did not successfully log the response!")
|
|
||||||
|
|
||||||
if "request_id" in response_json:
|
if "request_id" in response_json:
|
||||||
print(kwargs["litellm_params"]["metadata"])
|
if metadata:
|
||||||
if kwargs["litellm_params"]["metadata"] is not None:
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"https://api.promptlayer.com/rest/track-metadata",
|
"https://api.promptlayer.com/rest/track-metadata",
|
||||||
json={
|
json={
|
||||||
"request_id": response_json["request_id"],
|
"request_id": response_json["request_id"],
|
||||||
"api_key": self.key,
|
"api_key": self.key,
|
||||||
"metadata": kwargs["litellm_params"]["metadata"],
|
"metadata": metadata,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
|
|
|
@ -559,8 +559,7 @@ def completion(
|
||||||
f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
|
f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
|
||||||
)
|
)
|
||||||
response = llm_model.predict(
|
response = llm_model.predict(
|
||||||
endpoint=endpoint_path,
|
endpoint=endpoint_path, instances=instances
|
||||||
instances=instances
|
|
||||||
).predictions
|
).predictions
|
||||||
|
|
||||||
completion_response = response[0]
|
completion_response = response[0]
|
||||||
|
@ -585,12 +584,8 @@ def completion(
|
||||||
"request_str": request_str,
|
"request_str": request_str,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
request_str += (
|
request_str += f"llm_model.predict(instances={instances})\n"
|
||||||
f"llm_model.predict(instances={instances})\n"
|
response = llm_model.predict(instances=instances).predictions
|
||||||
)
|
|
||||||
response = llm_model.predict(
|
|
||||||
instances=instances
|
|
||||||
).predictions
|
|
||||||
|
|
||||||
completion_response = response[0]
|
completion_response = response[0]
|
||||||
if (
|
if (
|
||||||
|
@ -614,7 +609,6 @@ def completion(
|
||||||
model_response["choices"][0]["message"]["content"] = str(
|
model_response["choices"][0]["message"]["content"] = str(
|
||||||
completion_response
|
completion_response
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"]["content"] = str(completion_response)
|
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = model
|
model_response["model"] = model
|
||||||
## CALCULATING USAGE
|
## CALCULATING USAGE
|
||||||
|
@ -766,6 +760,7 @@ async def async_completion(
|
||||||
Vertex AI Model Garden
|
Vertex AI Model Garden
|
||||||
"""
|
"""
|
||||||
from google.cloud import aiplatform
|
from google.cloud import aiplatform
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=prompt,
|
input=prompt,
|
||||||
|
@ -797,11 +792,9 @@ async def async_completion(
|
||||||
and "\nOutput:\n" in completion_response
|
and "\nOutput:\n" in completion_response
|
||||||
):
|
):
|
||||||
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
||||||
|
|
||||||
elif mode == "private":
|
elif mode == "private":
|
||||||
request_str += (
|
request_str += f"llm_model.predict_async(instances={instances})\n"
|
||||||
f"llm_model.predict_async(instances={instances})\n"
|
|
||||||
)
|
|
||||||
response_obj = await llm_model.predict_async(
|
response_obj = await llm_model.predict_async(
|
||||||
instances=instances,
|
instances=instances,
|
||||||
)
|
)
|
||||||
|
@ -826,7 +819,6 @@ async def async_completion(
|
||||||
model_response["choices"][0]["message"]["content"] = str(
|
model_response["choices"][0]["message"]["content"] = str(
|
||||||
completion_response
|
completion_response
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"]["content"] = str(completion_response)
|
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = model
|
model_response["model"] = model
|
||||||
## CALCULATING USAGE
|
## CALCULATING USAGE
|
||||||
|
@ -954,6 +946,7 @@ async def async_streaming(
|
||||||
response = llm_model.predict_streaming_async(prompt, **optional_params)
|
response = llm_model.predict_streaming_async(prompt, **optional_params)
|
||||||
elif mode == "custom":
|
elif mode == "custom":
|
||||||
from google.cloud import aiplatform
|
from google.cloud import aiplatform
|
||||||
|
|
||||||
stream = optional_params.pop("stream", None)
|
stream = optional_params.pop("stream", None)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -972,7 +965,9 @@ async def async_streaming(
|
||||||
endpoint_path = llm_model.endpoint_path(
|
endpoint_path = llm_model.endpoint_path(
|
||||||
project=vertex_project, location=vertex_location, endpoint=model
|
project=vertex_project, location=vertex_location, endpoint=model
|
||||||
)
|
)
|
||||||
request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
|
request_str += (
|
||||||
|
f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
|
||||||
|
)
|
||||||
response_obj = await llm_model.predict(
|
response_obj = await llm_model.predict(
|
||||||
endpoint=endpoint_path,
|
endpoint=endpoint_path,
|
||||||
instances=instances,
|
instances=instances,
|
||||||
|
|
|
@ -12,7 +12,6 @@ from typing import Any, Literal, Union
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import dotenv, traceback, random, asyncio, time, contextvars
|
import dotenv, traceback, random, asyncio, time, contextvars
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import litellm
|
import litellm
|
||||||
from ._logging import verbose_logger
|
from ._logging import verbose_logger
|
||||||
|
|
|
@ -424,6 +424,10 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
||||||
model_spend: Dict = {}
|
model_spend: Dict = {}
|
||||||
model_max_budget: Dict = {}
|
model_max_budget: Dict = {}
|
||||||
|
|
||||||
|
# hidden params used for parallel request limiting, not required to create a token
|
||||||
|
user_id_rate_limits: Optional[dict] = None
|
||||||
|
team_id_rate_limits: Optional[dict] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
BIN
litellm/proxy/cached_logo.jpg
Normal file
BIN
litellm/proxy/cached_logo.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 16 KiB |
|
@ -24,46 +24,21 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def async_pre_call_hook(
|
async def check_key_in_limits(
|
||||||
self,
|
self,
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
cache: DualCache,
|
cache: DualCache,
|
||||||
data: dict,
|
data: dict,
|
||||||
call_type: str,
|
call_type: str,
|
||||||
|
max_parallel_requests: int,
|
||||||
|
tpm_limit: int,
|
||||||
|
rpm_limit: int,
|
||||||
|
request_count_api_key: str,
|
||||||
):
|
):
|
||||||
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
|
||||||
api_key = user_api_key_dict.api_key
|
|
||||||
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
|
||||||
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
|
|
||||||
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
|
|
||||||
|
|
||||||
if api_key is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
max_parallel_requests == sys.maxsize
|
|
||||||
and tpm_limit == sys.maxsize
|
|
||||||
and rpm_limit == sys.maxsize
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.user_api_key_cache = cache # save the api key cache for updating the value
|
|
||||||
# ------------
|
|
||||||
# Setup values
|
|
||||||
# ------------
|
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
current_hour = datetime.now().strftime("%H")
|
|
||||||
current_minute = datetime.now().strftime("%M")
|
|
||||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
|
||||||
|
|
||||||
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
|
|
||||||
|
|
||||||
# CHECK IF REQUEST ALLOWED
|
|
||||||
current = cache.get_cache(
|
current = cache.get_cache(
|
||||||
key=request_count_api_key
|
key=request_count_api_key
|
||||||
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
||||||
self.print_verbose(f"current: {current}")
|
# print(f"current: {current}")
|
||||||
if current is None:
|
if current is None:
|
||||||
new_val = {
|
new_val = {
|
||||||
"current_requests": 1,
|
"current_requests": 1,
|
||||||
|
@ -88,10 +63,107 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
status_code=429, detail="Max parallel request limit reached."
|
status_code=429, detail="Max parallel request limit reached."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_pre_call_hook(
|
||||||
|
self,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
cache: DualCache,
|
||||||
|
data: dict,
|
||||||
|
call_type: str,
|
||||||
|
):
|
||||||
|
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
||||||
|
api_key = user_api_key_dict.api_key
|
||||||
|
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
||||||
|
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
|
||||||
|
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
|
||||||
|
|
||||||
|
if api_key is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.user_api_key_cache = cache # save the api key cache for updating the value
|
||||||
|
# ------------
|
||||||
|
# Setup values
|
||||||
|
# ------------
|
||||||
|
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
current_hour = datetime.now().strftime("%H")
|
||||||
|
current_minute = datetime.now().strftime("%M")
|
||||||
|
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||||
|
|
||||||
|
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
|
||||||
|
|
||||||
|
# CHECK IF REQUEST ALLOWED for key
|
||||||
|
current = cache.get_cache(
|
||||||
|
key=request_count_api_key
|
||||||
|
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
||||||
|
self.print_verbose(f"current: {current}")
|
||||||
|
if (
|
||||||
|
max_parallel_requests == sys.maxsize
|
||||||
|
and tpm_limit == sys.maxsize
|
||||||
|
and rpm_limit == sys.maxsize
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
elif current is None:
|
||||||
|
new_val = {
|
||||||
|
"current_requests": 1,
|
||||||
|
"current_tpm": 0,
|
||||||
|
"current_rpm": 0,
|
||||||
|
}
|
||||||
|
cache.set_cache(request_count_api_key, new_val)
|
||||||
|
elif (
|
||||||
|
int(current["current_requests"]) < max_parallel_requests
|
||||||
|
and current["current_tpm"] < tpm_limit
|
||||||
|
and current["current_rpm"] < rpm_limit
|
||||||
|
):
|
||||||
|
# Increase count for this token
|
||||||
|
new_val = {
|
||||||
|
"current_requests": current["current_requests"] + 1,
|
||||||
|
"current_tpm": current["current_tpm"],
|
||||||
|
"current_rpm": current["current_rpm"],
|
||||||
|
}
|
||||||
|
cache.set_cache(request_count_api_key, new_val)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429, detail="Max parallel request limit reached."
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if REQUEST ALLOWED for user_id
|
||||||
|
user_id = user_api_key_dict.user_id
|
||||||
|
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
|
||||||
|
|
||||||
|
# get user tpm/rpm limits
|
||||||
|
if _user_id_rate_limits is None or _user_id_rate_limits == {}:
|
||||||
|
return
|
||||||
|
user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
|
||||||
|
user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
|
||||||
|
if user_tpm_limit is None:
|
||||||
|
user_tpm_limit = sys.maxsize
|
||||||
|
if user_rpm_limit is None:
|
||||||
|
user_rpm_limit = sys.maxsize
|
||||||
|
|
||||||
|
# now do the same tpm/rpm checks
|
||||||
|
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||||
|
|
||||||
|
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||||
|
await self.check_key_in_limits(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=cache,
|
||||||
|
data=data,
|
||||||
|
call_type=call_type,
|
||||||
|
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
||||||
|
request_count_api_key=request_count_api_key,
|
||||||
|
tpm_limit=user_tpm_limit,
|
||||||
|
rpm_limit=user_rpm_limit,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
||||||
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
||||||
|
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
"user_api_key_user_id", None
|
||||||
|
)
|
||||||
|
|
||||||
if user_api_key is None:
|
if user_api_key is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -121,7 +193,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
}
|
}
|
||||||
|
|
||||||
# ------------
|
# ------------
|
||||||
# Update usage
|
# Update usage - API Key
|
||||||
# ------------
|
# ------------
|
||||||
|
|
||||||
new_val = {
|
new_val = {
|
||||||
|
@ -136,6 +208,41 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
self.user_api_key_cache.set_cache(
|
self.user_api_key_cache.set_cache(
|
||||||
request_count_api_key, new_val, ttl=60
|
request_count_api_key, new_val, ttl=60
|
||||||
) # store in cache for 1 min.
|
) # store in cache for 1 min.
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Update usage - User
|
||||||
|
# ------------
|
||||||
|
if user_api_key_user_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
if isinstance(response_obj, ModelResponse):
|
||||||
|
total_tokens = response_obj.usage.total_tokens
|
||||||
|
|
||||||
|
request_count_api_key = (
|
||||||
|
f"{user_api_key_user_id}::{precise_minute}::request_count"
|
||||||
|
)
|
||||||
|
|
||||||
|
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
|
||||||
|
"current_requests": 1,
|
||||||
|
"current_tpm": total_tokens,
|
||||||
|
"current_rpm": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
new_val = {
|
||||||
|
"current_requests": max(current["current_requests"] - 1, 0),
|
||||||
|
"current_tpm": current["current_tpm"] + total_tokens,
|
||||||
|
"current_rpm": current["current_rpm"] + 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.print_verbose(
|
||||||
|
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||||
|
)
|
||||||
|
self.user_api_key_cache.set_cache(
|
||||||
|
request_count_api_key, new_val, ttl=60
|
||||||
|
) # store in cache for 1 min.
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.print_verbose(e) # noqa
|
self.print_verbose(e) # noqa
|
||||||
|
|
||||||
|
|
|
@ -1479,6 +1479,26 @@ class ProxyConfig:
|
||||||
|
|
||||||
llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
|
llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
|
||||||
imported_list.append(llm_guard_moderation_obj)
|
imported_list.append(llm_guard_moderation_obj)
|
||||||
|
elif (
|
||||||
|
isinstance(callback, str)
|
||||||
|
and callback == "blocked_user_check"
|
||||||
|
):
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
|
||||||
|
_ENTERPRISE_BlockedUserList,
|
||||||
|
)
|
||||||
|
|
||||||
|
blocked_user_list = _ENTERPRISE_BlockedUserList()
|
||||||
|
imported_list.append(blocked_user_list)
|
||||||
|
elif (
|
||||||
|
isinstance(callback, str)
|
||||||
|
and callback == "banned_keywords"
|
||||||
|
):
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
|
||||||
|
_ENTERPRISE_BannedKeywords,
|
||||||
|
)
|
||||||
|
|
||||||
|
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
|
||||||
|
imported_list.append(banned_keywords_obj)
|
||||||
else:
|
else:
|
||||||
imported_list.append(
|
imported_list.append(
|
||||||
get_instance_fn(
|
get_instance_fn(
|
||||||
|
@ -4368,7 +4388,20 @@ async def update_team(
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
add new members to the team
|
You can now add / delete users from a team via /team/update
|
||||||
|
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:8000/team/update' \
|
||||||
|
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
|
||||||
|
--data-raw '{
|
||||||
|
"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849",
|
||||||
|
"members_with_roles": [{"role": "admin", "user_id": "5c4a0aa3-a1e1-43dc-bd87-3c2da8382a3a"}, {"role": "user", "user_id": "krrish247652@berri.ai"}]
|
||||||
|
}'
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
global prisma_client
|
global prisma_client
|
||||||
|
|
||||||
|
@ -4449,6 +4482,18 @@ async def delete_team(
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
delete team and associated team keys
|
delete team and associated team keys
|
||||||
|
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:8000/team/delete' \
|
||||||
|
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
|
||||||
|
--data-raw '{
|
||||||
|
"team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
global prisma_client
|
global prisma_client
|
||||||
|
|
||||||
|
@ -5097,7 +5142,15 @@ async def google_login(request: Request):
|
||||||
scope=generic_scope,
|
scope=generic_scope,
|
||||||
)
|
)
|
||||||
with generic_sso:
|
with generic_sso:
|
||||||
return await generic_sso.get_login_redirect()
|
# TODO: state should be a random string and added to the user session with cookie
|
||||||
|
# or a cryptographicly signed state that we can verify stateless
|
||||||
|
# For simplification we are using a static state, this is not perfect but some
|
||||||
|
# SSO providers do not allow stateless verification
|
||||||
|
redirect_params = {}
|
||||||
|
state = os.getenv("GENERIC_CLIENT_STATE", None)
|
||||||
|
if state:
|
||||||
|
redirect_params["state"] = state
|
||||||
|
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||||
elif ui_username is not None:
|
elif ui_username is not None:
|
||||||
# No Google, Microsoft SSO
|
# No Google, Microsoft SSO
|
||||||
# Use UI Credentials set in .env
|
# Use UI Credentials set in .env
|
||||||
|
@ -5203,7 +5256,25 @@ def get_image():
|
||||||
|
|
||||||
logo_path = os.getenv("UI_LOGO_PATH", default_logo)
|
logo_path = os.getenv("UI_LOGO_PATH", default_logo)
|
||||||
verbose_proxy_logger.debug(f"Reading logo from {logo_path}")
|
verbose_proxy_logger.debug(f"Reading logo from {logo_path}")
|
||||||
return FileResponse(path=logo_path)
|
|
||||||
|
# Check if the logo path is an HTTP/HTTPS URL
|
||||||
|
if logo_path.startswith(("http://", "https://")):
|
||||||
|
# Download the image and cache it
|
||||||
|
response = requests.get(logo_path)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# Save the image to a local file
|
||||||
|
cache_path = os.path.join(current_dir, "cached_logo.jpg")
|
||||||
|
with open(cache_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
# Return the cached image as a FileResponse
|
||||||
|
return FileResponse(cache_path, media_type="image/jpeg")
|
||||||
|
else:
|
||||||
|
# Handle the case when the image cannot be downloaded
|
||||||
|
return FileResponse(default_logo, media_type="image/jpeg")
|
||||||
|
else:
|
||||||
|
# Return the local image file if the logo path is not an HTTP/HTTPS URL
|
||||||
|
return FileResponse(logo_path, media_type="image/jpeg")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/sso/callback", tags=["experimental"])
|
@app.get("/sso/callback", tags=["experimental"])
|
||||||
|
@ -5265,7 +5336,7 @@ async def auth_callback(request: Request):
|
||||||
result = await microsoft_sso.verify_and_process(request)
|
result = await microsoft_sso.verify_and_process(request)
|
||||||
elif generic_client_id is not None:
|
elif generic_client_id is not None:
|
||||||
# make generic sso provider
|
# make generic sso provider
|
||||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||||
|
@ -5274,6 +5345,9 @@ async def auth_callback(request: Request):
|
||||||
)
|
)
|
||||||
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None)
|
||||||
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None)
|
||||||
|
generic_include_client_id = (
|
||||||
|
os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true"
|
||||||
|
)
|
||||||
if generic_client_secret is None:
|
if generic_client_secret is None:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
message="GENERIC_CLIENT_SECRET not set. Set it in .env file",
|
||||||
|
@ -5308,12 +5382,50 @@ async def auth_callback(request: Request):
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
generic_user_id_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_ID_ATTRIBUTE", "preferred_username"
|
||||||
|
)
|
||||||
|
generic_user_display_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub"
|
||||||
|
)
|
||||||
|
generic_user_email_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
||||||
|
)
|
||||||
|
generic_user_role_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
||||||
|
)
|
||||||
|
generic_user_first_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name"
|
||||||
|
)
|
||||||
|
generic_user_last_name_attribute_name = os.getenv(
|
||||||
|
"GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name"
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
||||||
|
)
|
||||||
|
|
||||||
discovery = DiscoveryDocument(
|
discovery = DiscoveryDocument(
|
||||||
authorization_endpoint=generic_authorization_endpoint,
|
authorization_endpoint=generic_authorization_endpoint,
|
||||||
token_endpoint=generic_token_endpoint,
|
token_endpoint=generic_token_endpoint,
|
||||||
userinfo_endpoint=generic_userinfo_endpoint,
|
userinfo_endpoint=generic_userinfo_endpoint,
|
||||||
)
|
)
|
||||||
SSOProvider = create_provider(name="oidc", discovery_document=discovery)
|
|
||||||
|
def response_convertor(response, client):
|
||||||
|
return OpenID(
|
||||||
|
id=response.get(generic_user_id_attribute_name),
|
||||||
|
display_name=response.get(generic_user_display_name_attribute_name),
|
||||||
|
email=response.get(generic_user_email_attribute_name),
|
||||||
|
first_name=response.get(generic_user_first_name_attribute_name),
|
||||||
|
last_name=response.get(generic_user_last_name_attribute_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
SSOProvider = create_provider(
|
||||||
|
name="oidc",
|
||||||
|
discovery_document=discovery,
|
||||||
|
response_convertor=response_convertor,
|
||||||
|
)
|
||||||
generic_sso = SSOProvider(
|
generic_sso = SSOProvider(
|
||||||
client_id=generic_client_id,
|
client_id=generic_client_id,
|
||||||
client_secret=generic_client_secret,
|
client_secret=generic_client_secret,
|
||||||
|
@ -5322,43 +5434,36 @@ async def auth_callback(request: Request):
|
||||||
scope=generic_scope,
|
scope=generic_scope,
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
verbose_proxy_logger.debug(f"calling generic_sso.verify_and_process")
|
||||||
request_body = await request.body()
|
result = await generic_sso.verify_and_process(
|
||||||
request_query_params = request.query_params
|
request, params={"include_client_id": generic_include_client_id}
|
||||||
# get "code" from query params
|
)
|
||||||
code = request_query_params.get("code")
|
|
||||||
result = await generic_sso.verify_and_process(request)
|
|
||||||
verbose_proxy_logger.debug(f"generic result: {result}")
|
verbose_proxy_logger.debug(f"generic result: {result}")
|
||||||
|
|
||||||
# User is Authe'd in - generate key for the UI to access Proxy
|
# User is Authe'd in - generate key for the UI to access Proxy
|
||||||
user_email = getattr(result, "email", None)
|
user_email = getattr(result, "email", None)
|
||||||
user_id = getattr(result, "id", None)
|
user_id = getattr(result, "id", None)
|
||||||
|
|
||||||
# generic client id
|
# generic client id
|
||||||
if generic_client_id is not None:
|
if generic_client_id is not None:
|
||||||
generic_user_id_attribute_name = os.getenv("GENERIC_USER_ID_ATTRIBUTE", "email")
|
user_id = getattr(result, "id", None)
|
||||||
generic_user_email_attribute_name = os.getenv(
|
user_email = getattr(result, "email", None)
|
||||||
"GENERIC_USER_EMAIL_ATTRIBUTE", "email"
|
|
||||||
)
|
|
||||||
generic_user_role_attribute_name = os.getenv(
|
|
||||||
"GENERIC_USER_ROLE_ATTRIBUTE", "role"
|
|
||||||
)
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}\n generic_user_role_attribute_name: {generic_user_role_attribute_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id = getattr(result, generic_user_id_attribute_name, None)
|
|
||||||
user_email = getattr(result, generic_user_email_attribute_name, None)
|
|
||||||
user_role = getattr(result, generic_user_role_attribute_name, None)
|
user_role = getattr(result, generic_user_role_attribute_name, None)
|
||||||
|
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
user_id = getattr(result, "first_name", "") + getattr(result, "last_name", "")
|
||||||
# get user_info from litellm DB
|
|
||||||
user_info = None
|
user_info = None
|
||||||
if prisma_client is not None:
|
|
||||||
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
|
||||||
user_id_models: List = []
|
user_id_models: List = []
|
||||||
if user_info is not None:
|
|
||||||
user_id_models = getattr(user_info, "models", [])
|
# User might not be already created on first generation of key
|
||||||
|
# But if it is, we want its models preferences
|
||||||
|
try:
|
||||||
|
if prisma_client is not None:
|
||||||
|
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
||||||
|
if user_info is not None:
|
||||||
|
user_id_models = getattr(user_info, "models", [])
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
response = await generate_key_helper_fn(
|
response = await generate_key_helper_fn(
|
||||||
**{
|
**{
|
||||||
|
|
|
@ -318,7 +318,7 @@ def test_gemini_pro_vision():
|
||||||
# test_gemini_pro_vision()
|
# test_gemini_pro_vision()
|
||||||
|
|
||||||
|
|
||||||
def gemini_pro_function_calling():
|
def test_gemini_pro_function_calling():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
@ -345,12 +345,15 @@ def gemini_pro_function_calling():
|
||||||
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
||||||
)
|
)
|
||||||
print(f"completion: {completion}")
|
print(f"completion: {completion}")
|
||||||
|
assert completion.choices[0].message.content is None
|
||||||
|
assert len(completion.choices[0].message.tool_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
# gemini_pro_function_calling()
|
# gemini_pro_function_calling()
|
||||||
|
|
||||||
|
|
||||||
async def gemini_pro_async_function_calling():
|
@pytest.mark.asyncio
|
||||||
|
async def test_gemini_pro_async_function_calling():
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
@ -377,6 +380,9 @@ async def gemini_pro_async_function_calling():
|
||||||
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
|
||||||
)
|
)
|
||||||
print(f"completion: {completion}")
|
print(f"completion: {completion}")
|
||||||
|
assert completion.choices[0].message.content is None
|
||||||
|
assert len(completion.choices[0].message.tool_calls) == 1
|
||||||
|
# raise Exception("it worked!")
|
||||||
|
|
||||||
|
|
||||||
# asyncio.run(gemini_pro_async_function_calling())
|
# asyncio.run(gemini_pro_async_function_calling())
|
||||||
|
|
63
litellm/tests/test_banned_keyword_list.py
Normal file
63
litellm/tests/test_banned_keyword_list.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
# What is this?
|
||||||
|
## This tests the blocked user pre call hook for the proxy server
|
||||||
|
|
||||||
|
|
||||||
|
import sys, os, asyncio, time, random
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.banned_keywords import (
|
||||||
|
_ENTERPRISE_BannedKeywords,
|
||||||
|
)
|
||||||
|
from litellm import Router, mock_completion
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_banned_keywords_check():
|
||||||
|
"""
|
||||||
|
- Set some banned keywords as a litellm module value
|
||||||
|
- Test to see if a call with banned keywords is made, an error is raised
|
||||||
|
- Test to see if a call without banned keywords is made it passes
|
||||||
|
"""
|
||||||
|
litellm.banned_keywords_list = ["hello"]
|
||||||
|
|
||||||
|
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
|
||||||
|
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||||
|
local_cache = DualCache()
|
||||||
|
|
||||||
|
## Case 1: blocked user id passed
|
||||||
|
try:
|
||||||
|
await banned_keywords_obj.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
call_type="completion",
|
||||||
|
data={"messages": [{"role": "user", "content": "Hello world"}]},
|
||||||
|
)
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
## Case 2: normal user id passed
|
||||||
|
try:
|
||||||
|
await banned_keywords_obj.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
call_type="completion",
|
||||||
|
data={"messages": [{"role": "user", "content": "Hey, how's it going?"}]},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An error occurred - {str(e)}")
|
63
litellm/tests/test_blocked_user_list.py
Normal file
63
litellm/tests/test_blocked_user_list.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
# What is this?
|
||||||
|
## This tests the blocked user pre call hook for the proxy server
|
||||||
|
|
||||||
|
|
||||||
|
import sys, os, asyncio, time, random
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy.enterprise.enterprise_hooks.blocked_user_list import (
|
||||||
|
_ENTERPRISE_BlockedUserList,
|
||||||
|
)
|
||||||
|
from litellm import Router, mock_completion
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_block_user_check():
|
||||||
|
"""
|
||||||
|
- Set a blocked user as a litellm module value
|
||||||
|
- Test to see if a call with that user id is made, an error is raised
|
||||||
|
- Test to see if a call without that user is passes
|
||||||
|
"""
|
||||||
|
litellm.blocked_user_list = ["user_id_1"]
|
||||||
|
|
||||||
|
blocked_user_obj = _ENTERPRISE_BlockedUserList()
|
||||||
|
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key)
|
||||||
|
local_cache = DualCache()
|
||||||
|
|
||||||
|
## Case 1: blocked user id passed
|
||||||
|
try:
|
||||||
|
await blocked_user_obj.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
call_type="completion",
|
||||||
|
data={"user_id": "user_id_1"},
|
||||||
|
)
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
## Case 2: normal user id passed
|
||||||
|
try:
|
||||||
|
await blocked_user_obj.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
call_type="completion",
|
||||||
|
data={"user_id": "user_id_2"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An error occurred - {str(e)}")
|
|
@ -139,6 +139,56 @@ async def test_pre_call_hook_tpm_limits():
|
||||||
assert e.status_code == 429
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_call_hook_user_tpm_limits():
|
||||||
|
"""
|
||||||
|
Test if error raised on hitting tpm limits
|
||||||
|
"""
|
||||||
|
# create user with tpm/rpm limits
|
||||||
|
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
api_key=_api_key,
|
||||||
|
user_id="ishaan",
|
||||||
|
user_id_rate_limits={"tpm_limit": 9, "rpm_limit": 10},
|
||||||
|
)
|
||||||
|
res = dict(user_api_key_dict)
|
||||||
|
print("dict user", res)
|
||||||
|
local_cache = DualCache()
|
||||||
|
parallel_request_handler = MaxParallelRequestsHandler()
|
||||||
|
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"litellm_params": {
|
||||||
|
"metadata": {"user_api_key_user_id": "ishaan", "user_api_key": "gm"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await parallel_request_handler.async_log_success_event(
|
||||||
|
kwargs=kwargs,
|
||||||
|
response_obj=litellm.ModelResponse(usage=litellm.Usage(total_tokens=10)),
|
||||||
|
start_time="",
|
||||||
|
end_time="",
|
||||||
|
)
|
||||||
|
|
||||||
|
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
data={},
|
||||||
|
call_type="",
|
||||||
|
)
|
||||||
|
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_success_call_hook():
|
async def test_success_call_hook():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -7,10 +7,9 @@ sys.path.insert(0, os.path.abspath("../.."))
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
litellm.success_callback = ["promptlayer"]
|
import pytest
|
||||||
litellm.set_verbose = True
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
# def test_promptlayer_logging():
|
# def test_promptlayer_logging():
|
||||||
# try:
|
# try:
|
||||||
|
@ -39,11 +38,16 @@ import time
|
||||||
# test_promptlayer_logging()
|
# test_promptlayer_logging()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
|
||||||
|
)
|
||||||
def test_promptlayer_logging_with_metadata():
|
def test_promptlayer_logging_with_metadata():
|
||||||
try:
|
try:
|
||||||
# Redirect stdout
|
# Redirect stdout
|
||||||
old_stdout = sys.stdout
|
old_stdout = sys.stdout
|
||||||
sys.stdout = new_stdout = io.StringIO()
|
sys.stdout = new_stdout = io.StringIO()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
litellm.success_callback = ["promptlayer"]
|
||||||
|
|
||||||
response = completion(
|
response = completion(
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
|
@ -58,15 +62,43 @@ def test_promptlayer_logging_with_metadata():
|
||||||
sys.stdout = old_stdout
|
sys.stdout = old_stdout
|
||||||
output = new_stdout.getvalue().strip()
|
output = new_stdout.getvalue().strip()
|
||||||
print(output)
|
print(output)
|
||||||
if "LiteLLM: Prompt Layer Logging: success" not in output:
|
|
||||||
raise Exception("Required log message not found!")
|
assert "Prompt Layer Logging: success" in output
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
# test_promptlayer_logging_with_metadata()
|
@pytest.mark.skip(
|
||||||
|
reason="this works locally but fails on ci/cd since ci/cd is not reading the stdout correctly"
|
||||||
|
)
|
||||||
|
def test_promptlayer_logging_with_metadata_tags():
|
||||||
|
try:
|
||||||
|
# Redirect stdout
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
litellm.success_callback = ["promptlayer"]
|
||||||
|
old_stdout = sys.stdout
|
||||||
|
sys.stdout = new_stdout = io.StringIO()
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hi 👋 - i'm ai21"}],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=20,
|
||||||
|
metadata={"model": "ai21", "pl_tags": ["env:dev"]},
|
||||||
|
mock_response="this is a mock response",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Restore stdout
|
||||||
|
time.sleep(1)
|
||||||
|
sys.stdout = old_stdout
|
||||||
|
output = new_stdout.getvalue().strip()
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
assert "Prompt Layer Logging: success" in output
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
# def test_chat_openai():
|
# def test_chat_openai():
|
||||||
# try:
|
# try:
|
||||||
|
|
|
@ -393,6 +393,8 @@ def test_completion_palm_stream():
|
||||||
if complete_response.strip() == "":
|
if complete_response.strip() == "":
|
||||||
raise Exception("Empty response received")
|
raise Exception("Empty response received")
|
||||||
print(f"completion_response: {complete_response}")
|
print(f"completion_response: {complete_response}")
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
except litellm.APIError as e:
|
except litellm.APIError as e:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -4277,8 +4277,8 @@ def get_optional_params(
|
||||||
optional_params["stop_sequences"] = stop
|
optional_params["stop_sequences"] = stop
|
||||||
if max_tokens is not None:
|
if max_tokens is not None:
|
||||||
optional_params["max_output_tokens"] = max_tokens
|
optional_params["max_output_tokens"] = max_tokens
|
||||||
elif custom_llm_provider == "vertex_ai" and model in (
|
elif custom_llm_provider == "vertex_ai" and (
|
||||||
litellm.vertex_chat_models
|
model in litellm.vertex_chat_models
|
||||||
or model in litellm.vertex_code_chat_models
|
or model in litellm.vertex_code_chat_models
|
||||||
or model in litellm.vertex_text_models
|
or model in litellm.vertex_text_models
|
||||||
or model in litellm.vertex_code_text_models
|
or model in litellm.vertex_code_text_models
|
||||||
|
@ -6827,6 +6827,14 @@ def exception_type(
|
||||||
llm_provider="palm",
|
llm_provider="palm",
|
||||||
response=original_exception.response,
|
response=original_exception.response,
|
||||||
)
|
)
|
||||||
|
if "504 Deadline expired before operation could complete." in error_str:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise Timeout(
|
||||||
|
message=f"PalmException - {original_exception.message}",
|
||||||
|
model=model,
|
||||||
|
llm_provider="palm",
|
||||||
|
request=original_exception.request,
|
||||||
|
)
|
||||||
if "400 Request payload size exceeds" in error_str:
|
if "400 Request payload size exceeds" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise ContextWindowExceededError(
|
raise ContextWindowExceededError(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.26.8"
|
version = "1.26.10"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
@ -74,7 +74,7 @@ requires = ["poetry-core", "wheel"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.commitizen]
|
[tool.commitizen]
|
||||||
version = "1.26.8"
|
version = "1.26.10"
|
||||||
version_files = [
|
version_files = [
|
||||||
"pyproject.toml:^version"
|
"pyproject.toml:^version"
|
||||||
]
|
]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue