Merge branch 'BerriAI:main' into main

Hannes Burrichter 2024-05-21 13:51:55 +02:00 committed by GitHub
commit b89b3d8c44
102 changed files with 8852 additions and 6557 deletions

View file

@ -42,7 +42,7 @@ jobs:
pip install lunary==0.2.5
pip install "langfuse==2.27.1"
pip install numpydoc
pip install traceloop-sdk==0.0.69
pip install traceloop-sdk==0.18.2
pip install openai
pip install prisma
pip install "httpx==0.24.1"

View file

@ -34,7 +34,7 @@ LiteLLM manages:
[**Jump to OpenAI Proxy Docs**](https://github.com/BerriAI/litellm?tab=readme-ov-file#openai-proxy---docs) <br>
[**Jump to Supported LLM Providers**](https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-providers-docs)
🚨 **Stable Release:** Use docker images with: `main-stable` tag. These run through 12 hr load tests (1k req./min).
🚨 **Stable Release:** Use docker images with the `-stable` tag. These have undergone 12-hour load tests before being published.
Support for more providers. Missing a provider or LLM Platform, raise a [feature request](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+).

View file

@ -151,3 +151,19 @@ response = image_generation(
)
print(f"response: {response}")
```
## VertexAI - Image Generation Models
### Usage
Use this for image generation models on VertexAI
```python
response = litellm.image_generation(
    prompt="An olympic size swimming pool",
    model="vertex_ai/imagegeneration@006",
    vertex_ai_project="adroit-crow-413218",
    vertex_ai_location="us-central1",
)
print(f"response: {response}")
```
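The return value follows the OpenAI-style images response (`ImageResponse` / `ImageObject`, both exported in this diff's `litellm/__init__.py` changes). A minimal sketch for reading the generated image, assuming the model returns base64 data:
```python
# Sketch: read the first generated image from the response.
# Vertex AI image models typically return base64-encoded data (assumption).
import base64

image = response.data[0]
if image.b64_json is not None:
    with open("pool.png", "wb") as f:
        f.write(base64.b64decode(image.b64_json))
else:
    print(image.url)
```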

View file

@ -0,0 +1,173 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Lago - Usage Based Billing
[Lago](https://www.getlago.com/) offers a self-hosted and cloud-based metering and usage-based billing solution.
<Image img={require('../../img/lago.jpeg')} />
## Quick Start
Use just 1 line of code to instantly log your responses **across all providers** with Lago
Get your Lago [API Key](https://docs.getlago.com/guide/self-hosted/docker#find-your-api-key)
```python
litellm.callbacks = ["lago"] # logs cost + usage of successful calls to lago
```
<Tabs>
<TabItem value="sdk" label="SDK">
```python
# pip install lago
import litellm
import os
os.environ["LAGO_API_BASE"] = "" # http://0.0.0.0:3000
os.environ["LAGO_API_KEY"] = ""
os.environ["LAGO_API_EVENT_CODE"] = "" # The billable metric's code - https://docs.getlago.com/guide/events/ingesting-usage#define-a-billable-metric
# LLM API Keys
os.environ['OPENAI_API_KEY']=""
# set lago as a callback, litellm will send the data to lago
litellm.success_callback = ["lago"]
# openai call
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "user", "content": "Hi 👋 - i'm openai"}
    ],
    user="your_customer_id" # 👈 SET YOUR CUSTOMER ID HERE
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add to config.yaml
```yaml
model_list:
  - litellm_params:
      api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
      api_key: my-fake-key
      model: openai/my-fake-model
    model_name: fake-openai-endpoint

litellm_settings:
  callbacks: ["lago"] # 👈 KEY CHANGE
```
2. Start Proxy
```
litellm --config /path/to/config.yaml
```
3. Test it!
<Tabs>
<TabItem value="curl" label="Curl">
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "model": "fake-openai-endpoint",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ],
    "user": "your-customer-id" # 👈 SET YOUR CUSTOMER ID
}'
```
</TabItem>
<TabItem value="openai_python" label="OpenAI Python SDK">
```python
import openai
client = openai.OpenAI(
    api_key="anything",
    base_url="http://0.0.0.0:4000"
)

# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": "this is a test request, write a short poem"
        }
    ],
    user="my_customer_id" # 👈 whatever your customer id is
)
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
import os

os.environ["OPENAI_API_KEY"] = "anything"

chat = ChatOpenAI(
    openai_api_base="http://0.0.0.0:4000",
    model="gpt-3.5-turbo",
    temperature=0.1,
    extra_body={
        "user": "my_customer_id" # 👈 whatever your customer id is
    }
)

messages = [
    SystemMessage(
        content="You are a helpful assistant that im using to make a test request to."
    ),
    HumanMessage(
        content="test from litellm. tell me why it's amazing in 1 sentence"
    ),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
<Image img={require('../../img/lago_2.png')} />
## Advanced - Lago Logging object
This is what LiteLLM will log to Lago:
```
{
    "event": {
        "transaction_id": "<generated_unique_id>",
        "external_customer_id": <litellm_end_user_id>, # passed via `user` param in /chat/completion call - https://platform.openai.com/docs/api-reference/chat/create
        "code": os.getenv("LAGO_API_EVENT_CODE"),
        "properties": {
            "input_tokens": <number>,
            "output_tokens": <number>,
            "model": <string>,
            "response_cost": <number>, # 👈 LITELLM CALCULATED RESPONSE COST - https://github.com/BerriAI/litellm/blob/d43f75150a65f91f60dc2c0c9462ce3ffc713c1f/litellm/utils.py#L1473
        }
    }
}
```
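To sanity-check your Lago setup, you can send the same event shape yourself. A minimal sketch mirroring what the `lago` callback does internally (endpoint path and auth header taken from `litellm/integrations/lago.py` in this diff; values are illustrative):
```python
import json
import os
import uuid

import httpx

event = {
    "event": {
        "transaction_id": str(uuid.uuid4()),
        "external_customer_id": "your_customer_id",  # the `user` param from your /chat/completions call
        "code": os.environ["LAGO_API_EVENT_CODE"],
        "properties": {
            "model": "gpt-3.5-turbo",
            "input_tokens": 10,
            "output_tokens": 20,
            "total_tokens": 30,
            "response_cost": 0.0004,  # illustrative value
        },
    }
}

response = httpx.post(
    os.environ["LAGO_API_BASE"].rstrip("/") + "/api/v1/events",
    headers={
        "Content-Type": "application/json",
        "Authorization": "Bearer {}".format(os.environ["LAGO_API_KEY"]),
    },
    data=json.dumps(event),
)
response.raise_for_status()
```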

View file

@ -71,6 +71,23 @@ response = litellm.completion(
)
print(response)
```
### Make LiteLLM Proxy use Custom `LANGSMITH_BASE_URL`
If you're using a custom LangSmith instance, you can set the
`LANGSMITH_BASE_URL` environment variable to point to your instance.
For example, you can make LiteLLM Proxy log to a local LangSmith instance with
this config:
```yaml
litellm_settings:
  success_callback: ["langsmith"]

environment_variables:
  LANGSMITH_BASE_URL: "http://localhost:1984"
  LANGSMITH_PROJECT: "litellm-proxy"
```
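The same environment variable is read by the SDK's Langsmith logger (see the `litellm/integrations/langsmith.py` change later in this diff), so a minimal SDK sketch looks like this (the `LANGSMITH_API_KEY` variable name is assumed):
```python
import os
import litellm

os.environ["LANGSMITH_BASE_URL"] = "http://localhost:1984"  # your self-hosted LangSmith
os.environ["LANGSMITH_API_KEY"] = ""                         # assumed env var name
os.environ["LANGSMITH_PROJECT"] = "litellm-proxy"
os.environ["OPENAI_API_KEY"] = ""

litellm.success_callback = ["langsmith"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - testing langsmith logging"}],
)
print(response)
```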
## Support & Talk to Founders
- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)

View file

@ -20,7 +20,7 @@ Use just 2 lines of code, to instantly log your responses **across all providers
Get your OpenMeter API Key from https://openmeter.cloud/meters
```python
litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls to openmeter
litellm.callbacks = ["openmeter"] # logs cost + usage of successful calls to openmeter
```
@ -28,7 +28,7 @@ litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls
<TabItem value="sdk" label="SDK">
```python
# pip install langfuse
# pip install openmeter
import litellm
import os
@ -39,8 +39,8 @@ os.environ["OPENMETER_API_KEY"] = ""
# LLM API Keys
os.environ['OPENAI_API_KEY']=""
# set langfuse as a callback, litellm will send the data to langfuse
litellm.success_callback = ["openmeter"]
# set openmeter as a callback, litellm will send the data to openmeter
litellm.callbacks = ["openmeter"]
# openai call
response = litellm.completion(
@ -64,7 +64,7 @@ model_list:
model_name: fake-openai-endpoint
litellm_settings:
success_callback: ["openmeter"] # 👈 KEY CHANGE
callbacks: ["openmeter"] # 👈 KEY CHANGE
```
2. Start Proxy

View file

@ -223,6 +223,32 @@ assert isinstance(
```
### Setting `anthropic-beta` Header in Requests
Pass the `extra_headers` param to litellm; all headers will be forwarded to the Anthropic API
```python
response = completion(
    model="anthropic/claude-3-opus-20240229",
    messages=messages,
    tools=tools,
    extra_headers={"anthropic-beta": "tools-2024-05-16"}, # forwarded as-is to the Anthropic API
)
```
### Forcing Anthropic Tool Use
If you want Claude to use a specific tool to answer the user's question, you can do so by specifying the tool in the `tool_choice` field like so:
```python
response = completion(
    model="anthropic/claude-3-opus-20240229",
    messages=messages,
    tools=tools,
    tool_choice={"type": "tool", "name": "get_weather"},
)
```
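For reference, a minimal `tools` definition matching the `get_weather` name used above (standard OpenAI-format tool schema, shown here as an illustrative assumption):
```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and state"},
                },
                "required": ["location"],
            },
        },
    }
]
```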
### Parallel Function Calling

View file

@ -102,12 +102,18 @@ Ollama supported models: https://github.com/ollama/ollama
| Model Name | Function Call |
|----------------------|-----------------------------------------------------------------------------------|
| Mistral | `completion(model='ollama/mistral', messages, api_base="http://localhost:11434", stream=True)` |
| Mistral-7B-Instruct-v0.1 | `completion(model='ollama/mistral-7B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
| Mistral-7B-Instruct-v0.2 | `completion(model='ollama/mistral-7B-Instruct-v0.2', messages, api_base="http://localhost:11434", stream=False)` |
| Mixtral-8x7B-Instruct-v0.1 | `completion(model='ollama/mixtral-8x7B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
| Mixtral-8x22B-Instruct-v0.1 | `completion(model='ollama/mixtral-8x22B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
| Llama2 7B | `completion(model='ollama/llama2', messages, api_base="http://localhost:11434", stream=True)` |
| Llama2 13B | `completion(model='ollama/llama2:13b', messages, api_base="http://localhost:11434", stream=True)` |
| Llama2 70B | `completion(model='ollama/llama2:70b', messages, api_base="http://localhost:11434", stream=True)` |
| Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` |
| Code Llama | `completion(model='ollama/codellama', messages, api_base="http://localhost:11434", stream=True)` |
| Meta LLaMa3 8B | `completion(model='ollama/llama3', messages, api_base="http://localhost:11434", stream=False)` |
| Meta LLaMa3 70B | `completion(model='ollama/llama3:70b', messages, api_base="http://localhost:11434", stream=False)` |
| Orca Mini | `completion(model='ollama/orca-mini', messages, api_base="http://localhost:11434", stream=True)` |
| Vicuna | `completion(model='ollama/vicuna', messages, api_base="http://localhost:11434", stream=True)` |
| Nous-Hermes | `completion(model='ollama/nous-hermes', messages, api_base="http://localhost:11434", stream=True)` |

View file

@ -188,6 +188,7 @@ These also support the `OPENAI_API_BASE` environment variable, which can be used
## OpenAI Vision Models
| Model Name | Function Call |
|-----------------------|-----------------------------------------------------------------|
| gpt-4o | `response = completion(model="gpt-4o", messages=messages)` |
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
| gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` |

View file

@ -508,6 +508,31 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
## Image Generation Models
Usage
```python
response = await litellm.aimage_generation(
    prompt="An olympic size swimming pool",
    model="vertex_ai/imagegeneration@006",
    vertex_ai_project="adroit-crow-413218",
    vertex_ai_location="us-central1",
)
```
**Generating multiple images**
Use the `n` parameter to specify how many images you want generated:
```python
response = await litellm.aimage_generation(
    prompt="An olympic size swimming pool",
    model="vertex_ai/imagegeneration@006",
    vertex_ai_project="adroit-crow-413218",
    vertex_ai_location="us-central1",
    n=1,
)
```
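`aimage_generation` is a coroutine, so outside an async context you would wrap it (or use the synchronous `litellm.image_generation` shown in the image generation docs); a minimal sketch:
```python
import asyncio
import litellm

async def main():
    # Generate one image via Vertex AI and print the full response object.
    response = await litellm.aimage_generation(
        prompt="An olympic size swimming pool",
        model="vertex_ai/imagegeneration@006",
        vertex_ai_project="adroit-crow-413218",
        vertex_ai_location="us-central1",
        n=1,
    )
    print(response)

asyncio.run(main())
```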
## Extra

View file

@ -1,4 +1,4 @@
# 🚨 Alerting
# 🚨 Alerting / Webhooks
Get alerts for:
@ -61,10 +61,38 @@ curl -X GET 'http://localhost:4000/health/services?service=slack' \
-H 'Authorization: Bearer sk-1234'
```
## Advanced - Opting into specific alert types
## Extras
Set `alert_types` if you want to opt into only specific alert types:
### Using Discord Webhooks
```yaml
general_settings:
  alerting: ["slack"]
  alert_types: ["spend_reports"]
```
All Possible Alert Types
```python
alert_types: Optional[
    List[
        Literal[
            "llm_exceptions",
            "llm_too_slow",
            "llm_requests_hanging",
            "budget_alerts",
            "db_exceptions",
            "daily_reports",
            "spend_reports",
            "cooldown_deployment",
            "new_model_added",
        ]
    ]
]
```
## Advanced - Using Discord Webhooks
Discord provides a slack compatible webhook url that you can use for alerting
@ -96,3 +124,80 @@ environment_variables:
```
That's it ! You're ready to go !
## Advanced - [BETA] Webhooks for Budget Alerts
**Note**: This is a beta feature, so the spec might change.
Set a webhook to get notified for budget alerts.
1. Setup config.yaml
Add a webhook URL to your environment. For testing, you can use a link from [here](https://webhook.site/)
```bash
export WEBHOOK_URL="https://webhook.site/6ab090e8-c55f-4a23-b075-3209f5c57906"
```
Add 'webhook' to config.yaml
```yaml
general_settings:
alerting: ["webhook"] # 👈 KEY CHANGE
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl -X GET --location 'http://0.0.0.0:4000/health/services?service=webhook' \
--header 'Authorization: Bearer sk-1234'
```
**Expected Response**
```bash
{
    "spend": 1, # the spend for the 'event_group'
    "max_budget": 0, # the 'max_budget' set for the 'event_group'
    "token": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
    "user_id": "default_user_id",
    "team_id": null,
    "user_email": null,
    "key_alias": null,
    "projected_exceeded_date": null,
    "projected_spend": null,
    "event": "budget_crossed", # Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
    "event_group": "user",
    "event_message": "User Budget: Budget Crossed"
}
```
**API Spec for Webhook Event**
- `spend` *float*: The current spend amount for the 'event_group'.
- `max_budget` *float*: The maximum allowed budget for the 'event_group'.
- `token` *str*: A hashed value of the key, used for authentication or identification purposes.
- `user_id` *str or null*: The ID of the user associated with the event (optional).
- `team_id` *str or null*: The ID of the team associated with the event (optional).
- `user_email` *str or null*: The email of the user associated with the event (optional).
- `key_alias` *str or null*: An alias for the key associated with the event (optional).
- `projected_exceeded_date` *str or null*: The date when the budget is projected to be exceeded, returned when 'soft_budget' is set for key (optional).
- `projected_spend` *float or null*: The projected spend amount, returned when 'soft_budget' is set for key (optional).
- `event` *Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]*: The type of event that triggered the webhook. Possible values are:
* "budget_crossed": Indicates that the spend has exceeded the max budget.
* "threshold_crossed": Indicates that spend has crossed a threshold (currently sent when 85% and 95% of budget is reached).
* "projected_limit_exceeded": For "key" only - Indicates that the projected spend is expected to exceed the soft budget threshold.
- `event_group` *Literal["user", "key", "team", "proxy"]*: The group associated with the event. Possible values are:
* "user": The event is related to a specific user.
* "key": The event is related to a specific key.
* "team": The event is related to a team.
* "proxy": The event is related to a proxy.
- `event_message` *str*: A human-readable description of the event.
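A minimal sketch of a receiver for this event, assuming a FastAPI app (not part of LiteLLM; field names follow the spec above):
```python
from typing import Literal, Optional

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class BudgetWebhookEvent(BaseModel):
    # Fields mirror the webhook event spec listed above.
    spend: float
    max_budget: float
    token: str
    user_id: Optional[str] = None
    team_id: Optional[str] = None
    user_email: Optional[str] = None
    key_alias: Optional[str] = None
    projected_exceeded_date: Optional[str] = None
    projected_spend: Optional[float] = None
    event: Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
    event_group: Literal["user", "key", "team", "proxy"]
    event_message: str

@app.post("/litellm-budget-alerts")
async def handle_alert(alert: BudgetWebhookEvent):
    # Example: act only on hard budget violations
    if alert.event == "budget_crossed":
        print(f"{alert.event_group} exceeded budget ${alert.max_budget}: {alert.event_message}")
    return {"status": "ok"}
```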

View file

@ -0,0 +1,319 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 💵 Billing
Bill internal teams, external customers for their usage
**🚨 Requirements**
- [Setup Lago](https://docs.getlago.com/guide/self-hosted/docker#run-the-app), for usage-based billing. We recommend following [their Stripe tutorial](https://docs.getlago.com/templates/per-transaction/stripe#step-1-create-billable-metrics-for-transaction)
Steps:
- Connect the proxy to Lago
- Set the id you want to bill for (customers, internal users, teams)
- Start!
## Quick Start
Bill internal teams for their usage
### 1. Connect proxy to Lago
Set 'lago' as a callback on your proxy config.yaml
```yaml
model_list:
  - model_name: fake-openai-endpoint
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/

litellm_settings:
  callbacks: ["lago"] # 👈 KEY CHANGE

general_settings:
  master_key: sk-1234
```
Add your Lago keys to the environment
```bash
export LAGO_API_BASE="http://localhost:3000" # self-host - https://docs.getlago.com/guide/self-hosted/docker#run-the-app
export LAGO_API_KEY="3e29d607-de54-49aa-a019-ecf585729070" # Get key - https://docs.getlago.com/guide/self-hosted/docker#find-your-api-key
export LAGO_API_EVENT_CODE="openai_tokens" # name of lago billing code
export LAGO_API_CHARGE_BY="team_id" # 👈 Charges 'team_id' attached to proxy key
```
Start proxy
```bash
litellm --config /path/to/config.yaml
```
### 2. Create Key for Internal Team
```bash
curl 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data-raw '{"team_id": "my-unique-id"}' # 👈 Internal Team's ID
```
Response Object:
```bash
{
"key": "sk-tXL0wt5-lOOVK9sfY2UacA",
}
```
### 3. Start billing!
<Tabs>
<TabItem value="curl" label="Curl">
```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-tXL0wt5-lOOVK9sfY2UacA' \ # 👈 Team's Key
--data ' {
"model": "fake-openai-endpoint",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
}
'
```
</TabItem>
<TabItem value="openai_python" label="OpenAI Python SDK">
```python
import openai
client = openai.OpenAI(
api_key="sk-tXL0wt5-lOOVK9sfY2UacA", # 👈 Team's Key
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
import os
os.environ["OPENAI_API_KEY"] = "sk-tXL0wt5-lOOVK9sfY2UacA" # 👈 Team's Key
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000",
model = "gpt-3.5-turbo",
temperature=0.1,
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
**See Results on Lago**
<Image img={require('../../img/lago_2.png')} style={{ width: '500px', height: 'auto' }} />
## Advanced - Lago Logging object
This is what LiteLLM will log to Lago:
```
{
"event": {
"transaction_id": "<generated_unique_id>",
"external_customer_id": <selected_id>, # either 'end_user_id', 'user_id', or 'team_id'. Default 'end_user_id'.
"code": os.getenv("LAGO_API_EVENT_CODE"),
"properties": {
"input_tokens": <number>,
"output_tokens": <number>,
"model": <string>,
"response_cost": <number>, # 👈 LITELLM CALCULATED RESPONSE COST - https://github.com/BerriAI/litellm/blob/d43f75150a65f91f60dc2c0c9462ce3ffc713c1f/litellm/utils.py#L1473
}
}
}
```
## Advanced - Bill Customers, Internal Users
For:
- Customers (id passed via 'user' param in /chat/completion call) = 'end_user_id'
- Internal Users (id set when [creating keys](https://docs.litellm.ai/docs/proxy/virtual_keys#advanced---spend-tracking)) = 'user_id'
- Teams (id set when [creating keys](https://docs.litellm.ai/docs/proxy/virtual_keys#advanced---spend-tracking)) = 'team_id'
<Tabs>
<TabItem value="customers" label="Customer Billing">
1. Set 'LAGO_API_CHARGE_BY' to 'end_user_id'
```bash
export LAGO_API_CHARGE_BY="end_user_id"
```
2. Test it!
<Tabs>
<TabItem value="curl" label="Curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
],
"user": "my_customer_id" # 👈 whatever your customer id is
}
'
```
</TabItem>
<TabItem value="openai_sdk" label="OpenAI Python SDK">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
], user="my_customer_id") # 👈 whatever your customer id is
print(response)
```
</TabItem>
<TabItem value="langchain" label="Langchain">
```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
import os
os.environ["OPENAI_API_KEY"] = "anything"
chat = ChatOpenAI(
openai_api_base="http://0.0.0.0:4000",
model = "gpt-3.5-turbo",
temperature=0.1,
extra_body={
"user": "my_customer_id" # 👈 whatever your customer id is
}
)
messages = [
SystemMessage(
content="You are a helpful assistant that im using to make a test request to."
),
HumanMessage(
content="test from litellm. tell me why it's amazing in 1 sentence"
),
]
response = chat(messages)
print(response)
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="users" label="Internal User Billing">
1. Set 'LAGO_API_CHARGE_BY' to 'user_id'
```bash
export LAGO_API_CHARGE_BY="user_id"
```
2. Create a key for that user
```bash
curl 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{"user_id": "my-unique-id"}' # 👈 Internal User's id
```
Response Object:
```bash
{
"key": "sk-tXL0wt5-lOOVK9sfY2UacA",
}
```
3. Make API Calls with that Key
```python
import openai
client = openai.OpenAI(
api_key="sk-tXL0wt5-lOOVK9sfY2UacA", # 👈 Generated key
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])
print(response)
```
</TabItem>
</Tabs>

View file

@ -25,26 +25,45 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
def __init__(self):
pass
#### ASYNC ####
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_pre_api_call(self, model, messages, kwargs):
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
pass
#### CALL HOOKS - proxy only ####
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
]) -> Optional[Union[dict, str, Exception]]:
data["model"] = "my-new-model"
return data
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
pass
async def async_post_call_success_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response,
):
pass
async def async_moderation_hook( # call made in parallel to llm api call
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
pass
async def async_post_call_streaming_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: str,
):
pass
proxy_handler_instance = MyCustomHandler()
```
@ -191,3 +210,99 @@ general_settings:
**Result**
<Image img={require('../../img/end_user_enforcement.png')}/>
## Advanced - Return rejected message as response
For chat completions and text completion calls, you can return a rejected message as a user response.
Do this by returning a string. LiteLLM takes care of returning the response in the correct format, depending on the endpoint and whether the call is streaming or non-streaming.
For non-chat/text completion endpoints, this response is returned as a 400 status code exception.
### 1. Create Custom Handler
```python
from litellm.integrations.custom_logger import CustomLogger
import litellm
from litellm.utils import get_formatted_prompt
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Optional, Union

# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger):
    def __init__(self):
        pass

    #### CALL HOOKS - proxy only ####

    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ]) -> Optional[Union[dict, str, Exception]]:
        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)

        if "Hello world" in formatted_prompt:
            return "This is an invalid response"

        return data

proxy_handler_instance = MyCustomHandler()
```
### 2. Update config.yaml
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo

litellm_settings:
  callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
```
### 3. Test it!
```shell
$ litellm --config /path/to/config.yaml
```
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "Hello world"
        }
    ]
}'
```
**Expected Response**
```
{
    "id": "chatcmpl-d00bbede-2d90-4618-bf7b-11a1c23cf360",
    "choices": [
        {
            "finish_reason": "stop",
            "index": 0,
            "message": {
                "content": "This is an invalid response.", # 👈 REJECTED RESPONSE
                "role": "assistant"
            }
        }
    ],
    "created": 1716234198,
    "model": null,
    "object": "chat.completion",
    "system_fingerprint": null,
    "usage": {}
}
```

View file

@ -5,6 +5,8 @@
- debug (prints info logs)
- detailed debug (prints debug logs)
The proxy also supports json logs. [See here](#json-logs)
## `debug`
**via cli**
@ -32,3 +34,19 @@ $ litellm --detailed_debug
```python
os.environ["LITELLM_LOG"] = "DEBUG"
```
## JSON LOGS
Set `JSON_LOGS="True"` in your env:
```bash
export JSON_LOGS="True"
```
Start proxy
```bash
$ litellm
```
The proxy will now write all logs in JSON format.
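Under the hood (see the `litellm/_logging.py` change later in this diff), each record is serialized to a one-field JSON object; a minimal sketch of the same behavior:
```python
import json
import logging

# Mirrors the JsonFormatter added in litellm/_logging.py:
# every log record becomes {"message": "<record message>"}.
class JsonFormatter(logging.Formatter):
    def format(self, record):
        return json.dumps({"message": record.getMessage()})

handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())

logger = logging.getLogger("LiteLLM Proxy")
logger.addHandler(handler)
logger.warning("Request received")  # prints: {"message": "Request received"}
```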

View file

@ -1,7 +1,8 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# ✨ Enterprise Features - Content Mod, SSO
# ✨ Enterprise Features - Content Mod, SSO, Custom Swagger
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
@ -20,6 +21,7 @@ Features:
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
- ✅ Don't log/store specific requests to Langfuse, Sentry, etc. (eg confidential LLM requests)
- ✅ Tracking Spend for Custom Tags
- ✅ Custom Branding + Routes on Swagger Docs
@ -527,3 +529,38 @@ curl -X GET "http://0.0.0.0:4000/spend/tags" \
<!-- ## Tracking Spend per Key
## Tracking Spend per User -->
## Swagger Docs - Custom Routes + Branding
:::info
Requires a LiteLLM Enterprise key to use. Request one [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
Set LiteLLM Key in your environment
```bash
LITELLM_LICENSE=""
```
### Customize Title + Description
In your environment, set:
```bash
DOCS_TITLE="TotalGPT"
DOCS_DESCRIPTION="Sample Company Description"
```
### Customize Routes
Hide admin routes from users.
In your environment, set:
```bash
DOCS_FILTERED="True" # only shows openai routes to user
```
<Image img={require('../../img/custom_swagger.png')} style={{ width: '900px', height: 'auto' }} />

Binary file not shown (new image, 223 KiB).

Binary file not shown (new image, 344 KiB).

Binary file not shown (new image, 219 KiB).

View file

@ -41,6 +41,7 @@ const sidebars = {
"proxy/reliability",
"proxy/cost_tracking",
"proxy/users",
"proxy/billing",
"proxy/user_keys",
"proxy/enterprise",
"proxy/virtual_keys",
@ -175,6 +176,7 @@ const sidebars = {
"observability/custom_callback",
"observability/langfuse_integration",
"observability/sentry",
"observability/lago",
"observability/openmeter",
"observability/promptlayer_integration",
"observability/wandb_integration",

View file

@ -27,8 +27,8 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_custom_logger_compatible_callbacks: list = ["openmeter"]
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter"]
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
_langfuse_default_tags: Optional[
List[
Literal[
@ -724,6 +724,9 @@ from .utils import (
get_supported_openai_params,
get_api_base,
get_first_chars_messages,
ModelResponse,
ImageResponse,
ImageObject,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig

View file

@ -1,18 +1,32 @@
import logging
import logging, os, json
from logging import Formatter
set_verbose = False
json_logs = False
json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
class JsonFormatter(Formatter):
    def __init__(self):
        super(JsonFormatter, self).__init__()

    def format(self, record):
        json_record = {}
        json_record["message"] = record.getMessage()
        return json.dumps(json_record)

# Create a formatter and set it for the handler
if json_logs:
    handler.setFormatter(JsonFormatter())
else:
    formatter = logging.Formatter(
        "\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    handler.setFormatter(formatter)
verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")

View file

@ -15,11 +15,19 @@ from typing import Optional
class AuthenticationError(openai.AuthenticationError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 401
self.message = message
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__(
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
@ -27,11 +35,19 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
# raise when invalid models passed, example gpt-8
class NotFoundError(openai.NotFoundError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 404
self.message = message
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
super().__init__(
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
@ -39,12 +55,18 @@ class NotFoundError(openai.NotFoundError): # type: ignore
class BadRequestError(openai.BadRequestError): # type: ignore
def __init__(
self, message, model, llm_provider, response: Optional[httpx.Response] = None
self,
message,
model,
llm_provider,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(
@ -57,18 +79,28 @@ class BadRequestError(openai.BadRequestError): # type: ignore
class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 422
self.message = message
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
super().__init__(
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
class Timeout(openai.APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider):
def __init__(
self, message, model, llm_provider, litellm_debug_info: Optional[str] = None
):
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__(
request=request
@ -77,6 +109,7 @@ class Timeout(openai.APITimeoutError): # type: ignore
self.message = message
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
# custom function to convert to str
def __str__(self):
@ -84,22 +117,38 @@ class Timeout(openai.APITimeoutError): # type: ignore
class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 403
self.message = message
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__(
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
class RateLimitError(openai.RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 429
self.message = message
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__(
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
@ -107,11 +156,45 @@ class RateLimitError(openai.RateLimitError): # type: ignore
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
class ContextWindowExceededError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
super().__init__(
message=self.message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=response,
) # Call the base class constructor with the parameters it needs
# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
class RejectedRequestError(BadRequestError): # type: ignore
def __init__(
self,
message,
model,
llm_provider,
request_data: dict,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
self.request_data = request_data
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request)
super().__init__(
message=self.message,
model=self.model, # type: ignore
@ -122,11 +205,19 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
class ContentPolicyViolationError(BadRequestError): # type: ignore
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
def __init__(self, message, model, llm_provider, response: httpx.Response):
def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
super().__init__(
message=self.message,
model=self.model, # type: ignore
@ -136,11 +227,19 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
class ServiceUnavailableError(openai.APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model, response: httpx.Response):
def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 503
self.message = message
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__(
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
@ -149,33 +248,51 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(openai.APIError): # type: ignore
def __init__(
self, status_code, message, llm_provider, model, request: httpx.Request
self,
status_code,
message,
llm_provider,
model,
request: httpx.Request,
litellm_debug_info: Optional[str] = None,
):
self.status_code = status_code
self.message = message
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
super().__init__(self.message, request=request, body=None) # type: ignore
# raised if an invalid request (not get, delete, put, post) is made
class APIConnectionError(openai.APIConnectionError): # type: ignore
def __init__(self, message, llm_provider, model, request: httpx.Request):
def __init__(
self,
message,
llm_provider,
model,
request: httpx.Request,
litellm_debug_info: Optional[str] = None,
):
self.message = message
self.llm_provider = llm_provider
self.model = model
self.status_code = 500
self.litellm_debug_info = litellm_debug_info
super().__init__(message=self.message, request=request)
# raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
def __init__(self, message, llm_provider, model):
def __init__(
self, message, llm_provider, model, litellm_debug_info: Optional[str] = None
):
self.message = message
self.llm_provider = llm_provider
self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request)
self.litellm_debug_info = litellm_debug_info
super().__init__(response=response, body=None, message=message)

View file

@ -4,7 +4,6 @@ import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union, Optional
import traceback
@ -64,8 +63,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal["completion", "embeddings", "image_generation"],
):
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
) -> Optional[
Union[Exception, str, dict]
]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
pass
async def async_post_call_failure_hook(

View file

@ -0,0 +1,179 @@
# What is this?
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
import dotenv, os, json
import litellm
import traceback, httpx
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
import uuid
from typing import Optional, Literal
def get_utc_datetime():
import datetime as dt
from datetime import datetime
if hasattr(dt, "UTC"):
return datetime.now(dt.UTC) # type: ignore
else:
return datetime.utcnow() # type: ignore
class LagoLogger(CustomLogger):
def __init__(self) -> None:
super().__init__()
self.validate_environment()
self.async_http_handler = AsyncHTTPHandler()
self.sync_http_handler = HTTPHandler()
def validate_environment(self):
"""
Expects
LAGO_API_BASE,
LAGO_API_KEY,
LAGO_API_EVENT_CODE,
Optional:
LAGO_API_CHARGE_BY
in the environment
"""
missing_keys = []
if os.getenv("LAGO_API_KEY", None) is None:
missing_keys.append("LAGO_API_KEY")
if os.getenv("LAGO_API_BASE", None) is None:
missing_keys.append("LAGO_API_BASE")
if os.getenv("LAGO_API_EVENT_CODE", None) is None:
missing_keys.append("LAGO_API_EVENT_CODE")
if len(missing_keys) > 0:
raise Exception("Missing keys={} in environment.".format(missing_keys))
def _common_logic(self, kwargs: dict, response_obj) -> dict:
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
dt = get_utc_datetime().isoformat()
cost = kwargs.get("response_cost", None)
model = kwargs.get("model")
usage = {}
if (
isinstance(response_obj, litellm.ModelResponse)
or isinstance(response_obj, litellm.EmbeddingResponse)
) and hasattr(response_obj, "usage"):
usage = {
"prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
"completion_tokens": response_obj["usage"].get("completion_tokens", 0),
"total_tokens": response_obj["usage"].get("total_tokens"),
}
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
org_id = litellm_params["metadata"].get("user_api_key_org_id", None)
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
external_customer_id: Optional[str] = None
if os.getenv("LAGO_API_CHARGE_BY", None) is not None and isinstance(
os.environ["LAGO_API_CHARGE_BY"], str
):
if os.environ["LAGO_API_CHARGE_BY"] in [
"end_user_id",
"user_id",
"team_id",
]:
charge_by = os.environ["LAGO_API_CHARGE_BY"] # type: ignore
else:
raise Exception("invalid LAGO_API_CHARGE_BY set")
if charge_by == "end_user_id":
external_customer_id = end_user_id
elif charge_by == "team_id":
external_customer_id = team_id
elif charge_by == "user_id":
external_customer_id = user_id
if external_customer_id is None:
raise Exception("External Customer ID is not set")
return {
"event": {
"transaction_id": str(uuid.uuid4()),
"external_customer_id": external_customer_id,
"code": os.getenv("LAGO_API_EVENT_CODE"),
"properties": {"model": model, "response_cost": cost, **usage},
}
}
def log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = os.getenv("LAGO_API_BASE")
assert _url is not None and isinstance(
_url, str
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = os.getenv("LAGO_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
try:
response = self.sync_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
_url = os.getenv("LAGO_API_BASE")
assert _url is not None and isinstance(
_url, str
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(
_url
)
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = os.getenv("LAGO_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
except Exception as e:
raise e
response: Optional[httpx.Response] = None
try:
response = await self.async_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
if response is not None and hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e

View file

@ -93,6 +93,7 @@ class LangFuseLogger:
)
litellm_params = kwargs.get("litellm_params", {})
litellm_call_id = kwargs.get("litellm_call_id", None)
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
@ -161,6 +162,7 @@ class LangFuseLogger:
response_obj,
level,
print_verbose,
litellm_call_id,
)
elif response_obj is not None:
self._log_langfuse_v1(
@ -255,6 +257,7 @@ class LangFuseLogger:
response_obj,
level,
print_verbose,
litellm_call_id,
) -> tuple:
import langfuse
@ -318,7 +321,7 @@ class LangFuseLogger:
session_id = clean_metadata.pop("session_id", None)
trace_name = clean_metadata.pop("trace_name", None)
trace_id = clean_metadata.pop("trace_id", None)
trace_id = clean_metadata.pop("trace_id", litellm_call_id)
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
debug = clean_metadata.pop("debug_langfuse", None)
@ -351,9 +354,13 @@ class LangFuseLogger:
# Special keys that are found in the function arguments and not the metadata
if "input" in update_trace_keys:
trace_params["input"] = input if not mask_input else "redacted-by-litellm"
trace_params["input"] = (
input if not mask_input else "redacted-by-litellm"
)
if "output" in update_trace_keys:
trace_params["output"] = output if not mask_output else "redacted-by-litellm"
trace_params["output"] = (
output if not mask_output else "redacted-by-litellm"
)
else: # don't overwrite an existing trace
trace_params = {
"id": trace_id,
@ -375,7 +382,9 @@ class LangFuseLogger:
if level == "ERROR":
trace_params["status_message"] = output
else:
trace_params["output"] = output if not mask_output else "redacted-by-litellm"
trace_params["output"] = (
output if not mask_output else "redacted-by-litellm"
)
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
if "metadata" in trace_params:

View file

@ -44,6 +44,8 @@ class LangsmithLogger:
print_verbose(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
langsmith_base_url = os.getenv("LANGSMITH_BASE_URL", "https://api.smith.langchain.com")
try:
print_verbose(
f"Langsmith Logging - Enters logging function for model {kwargs}"
@ -86,8 +88,12 @@ class LangsmithLogger:
"end_time": end_time,
}
url = f"{langsmith_base_url}/runs"
print_verbose(
f"Langsmith Logging - About to send data to {url} ..."
)
response = requests.post(
"https://api.smith.langchain.com/runs",
url=url,
json=data,
headers={"x-api-key": self.langsmith_api_key},
)

View file

@ -1,7 +1,7 @@
#### What this does ####
# Class for sending Slack Alerts #
import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy._types import UserAPIKeyAuth, CallInfo
from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm, threading
from typing import List, Literal, Any, Union, Optional, Dict
@ -36,6 +36,13 @@ class SlackAlertingArgs(LiteLLMBase):
os.getenv("SLACK_DAILY_REPORT_FREQUENCY", default_daily_report_frequency)
)
report_check_interval: int = 5 * 60 # 5 minutes
budget_alert_ttl: int = 24 * 60 * 60 # 24 hours
class WebhookEvent(CallInfo):
event: Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
event_group: Literal["user", "key", "team", "proxy"]
event_message: str # human-readable description of event
class DeploymentMetrics(LiteLLMBase):
@ -87,6 +94,9 @@ class SlackAlerting(CustomLogger):
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"cooldown_deployment",
"new_model_added",
]
] = [
"llm_exceptions",
@ -95,6 +105,9 @@ class SlackAlerting(CustomLogger):
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"cooldown_deployment",
"new_model_added",
],
alert_to_webhook_url: Optional[
Dict
@ -158,13 +171,28 @@ class SlackAlerting(CustomLogger):
) -> Optional[str]:
"""
Returns langfuse trace url
- check:
-> existing_trace_id
-> trace_id
-> litellm_call_id
"""
# do nothing for now
if request_data is not None:
trace_id = None
if (
request_data is not None
and request_data.get("metadata", {}).get("trace_id", None) is not None
request_data.get("metadata", {}).get("existing_trace_id", None)
is not None
):
trace_id = request_data["metadata"]["existing_trace_id"]
elif request_data.get("metadata", {}).get("trace_id", None) is not None:
trace_id = request_data["metadata"]["trace_id"]
elif request_data.get("litellm_logging_obj", None) is not None and hasattr(
request_data["litellm_logging_obj"], "model_call_details"
):
trace_id = request_data["litellm_logging_obj"].model_call_details[
"litellm_call_id"
]
if litellm.utils.langFuseLogger is not None:
base_url = litellm.utils.langFuseLogger.Langfuse.base_url
return f"{base_url}/trace/{trace_id}"
@ -549,126 +577,130 @@ class SlackAlerting(CustomLogger):
alert_type="llm_requests_hanging",
)
async def failed_tracking_alert(self, error_message: str):
"""Raise alert when tracking failed for specific model"""
_cache: DualCache = self.internal_usage_cache
message = "Failed Tracking Cost for" + error_message
_cache_key = "budget_alerts:failed_tracking:{}".format(message)
result = await _cache.async_get_cache(key=_cache_key)
if result is None:
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
await _cache.async_set_cache(
key=_cache_key,
value="SENT",
ttl=self.alerting_args.budget_alert_ttl,
)
async def budget_alerts(
self,
type: Literal[
"token_budget",
"user_budget",
"user_and_proxy_budget",
"failed_budgets",
"failed_tracking",
"team_budget",
"proxy_budget",
"projected_limit_exceeded",
],
user_max_budget: float,
user_current_spend: float,
user_info=None,
error_message="",
user_info: CallInfo,
):
## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
# - Alert once within 24hr period
# - Cache this information
# - Don't re-alert, if alert already sent
_cache: DualCache = self.internal_usage_cache
if self.alerting is None or self.alert_types is None:
# do nothing if alerting is not switched on
return
if "budget_alerts" not in self.alert_types:
return
_id: str = "default_id" # used for caching
if type == "user_and_proxy_budget":
user_info = dict(user_info)
user_id = user_info["user_id"]
_id = user_id
max_budget = user_info["max_budget"]
spend = user_info["spend"]
user_email = user_info["user_email"]
user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
user_info_json = user_info.model_dump(exclude_none=True)
for k, v in user_info_json.items():
user_info_str = "\n{}: {}\n".format(k, v)
event: Optional[
Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
] = None
event_group: Optional[Literal["user", "team", "key", "proxy"]] = None
event_message: str = ""
webhook_event: Optional[WebhookEvent] = None
if type == "proxy_budget":
event_group = "proxy"
event_message += "Proxy Budget: "
elif type == "user_budget":
event_group = "user"
event_message += "User Budget: "
_id = user_info.user_id or _id
elif type == "team_budget":
event_group = "team"
event_message += "Team Budget: "
_id = user_info.team_id or _id
elif type == "token_budget":
token_info = dict(user_info)
token = token_info["token"]
_id = token
spend = token_info["spend"]
max_budget = token_info["max_budget"]
user_id = token_info["user_id"]
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
elif type == "failed_tracking":
user_id = str(user_info)
_id = user_id
user_info = f"\nUser ID: {user_id}\n Error {error_message}"
message = "Failed Tracking Cost for" + user_info
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
return
elif type == "projected_limit_exceeded" and user_info is not None:
"""
Input variables:
user_info = {
"key_alias": key_alias,
"projected_spend": projected_spend,
"projected_exceeded_date": projected_exceeded_date,
}
user_max_budget=soft_limit,
user_current_spend=new_spend
"""
message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}"""
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
return
else:
user_info = str(user_info)
event_group = "key"
event_message += "Key Budget: "
_id = user_info.token
elif type == "projected_limit_exceeded":
event_group = "key"
event_message += "Key Budget: Projected Limit Exceeded"
event = "projected_limit_exceeded"
_id = user_info.token
# percent of max_budget left to spend
if user_max_budget > 0:
percent_left = (user_max_budget - user_current_spend) / user_max_budget
if user_info.max_budget > 0:
percent_left = (
user_info.max_budget - user_info.spend
) / user_info.max_budget
else:
percent_left = 0
verbose_proxy_logger.debug(
f"Budget Alerts: Percent left: {percent_left} for {user_info}"
)
## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
# - Alert once within 28d period
# - Cache this information
# - Don't re-alert, if alert already sent
_cache: DualCache = self.internal_usage_cache
# check if crossed budget
if user_current_spend >= user_max_budget:
verbose_proxy_logger.debug("Budget Crossed for %s", user_info)
message = "Budget Crossed for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.send_alert(
message=message, level="High", alert_type="budget_alerts"
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
if user_info.spend >= user_info.max_budget:
event = "budget_crossed"
event_message += "Budget Crossed"
elif percent_left <= 0.05:
event = "threshold_crossed"
event_message += "5% Threshold Crossed"
elif percent_left <= 0.15:
event = "threshold_crossed"
event_message += "15% Threshold Crossed"
# check if 5% of max budget is left
if percent_left <= 0.05:
message = "5% budget left for" + user_info
cache_key = "alerting:{}".format(_id)
result = await _cache.async_get_cache(key=cache_key)
if event is not None and event_group is not None:
_cache_key = "budget_alerts:{}:{}".format(event, _id)
result = await _cache.async_get_cache(key=_cache_key)
if result is None:
webhook_event = WebhookEvent(
event=event,
event_group=event_group,
event_message=event_message,
**user_info_json,
)
await self.send_alert(
message=message, level="Medium", alert_type="budget_alerts"
message=event_message + "\n\n" + user_info_str,
level="High",
alert_type="budget_alerts",
user_info=webhook_event,
)
await _cache.async_set_cache(
key=_cache_key,
value="SENT",
ttl=self.alerting_args.budget_alert_ttl,
)
await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
return
return
# check if 15% of max budget is left
if percent_left <= 0.15:
message = "15% budget left for" + user_info
result = await _cache.async_get_cache(key=message)
if result is None:
await self.send_alert(
message=message, level="Low", alert_type="budget_alerts"
)
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
return
return
async def model_added_alert(self, model_name: str, litellm_model_name: str):
async def model_added_alert(
self, model_name: str, litellm_model_name: str, passed_model_info: Any
):
base_model_from_user = getattr(passed_model_info, "base_model", None)
model_info = {}
base_model = ""
if base_model_from_user is not None:
model_info = litellm.model_cost.get(base_model_from_user, {})
base_model = f"Base Model: `{base_model_from_user}`\n"
else:
model_info = litellm.model_cost.get(litellm_model_name, {})
model_info_str = ""
for k, v in model_info.items():
@ -681,6 +713,7 @@ class SlackAlerting(CustomLogger):
message = f"""
*🚅 New Model Added*
Model Name: `{model_name}`
{base_model}
Usage OpenAI Python SDK:
```
@ -715,6 +748,34 @@ Model Info:
async def model_removed_alert(self, model_name: str):
pass
async def send_webhook_alert(self, webhook_event: WebhookEvent) -> bool:
"""
Sends structured alert to webhook, if set.
Currently only implemented for budget alerts
Returns -> True if sent, False if not.
"""
webhook_url = os.getenv("WEBHOOK_URL", None)
if webhook_url is None:
raise Exception("Missing webhook_url from environment")
payload = webhook_event.model_dump_json()
headers = {"Content-type": "application/json"}
response = await self.async_http_handler.post(
url=webhook_url,
headers=headers,
data=payload,
)
if response.status_code == 200:
return True
else:
print("Error sending webhook alert. Error=", response.text) # noqa
return False
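`send_webhook_alert` above simply POSTs the serialized `WebhookEvent` as JSON to `WEBHOOK_URL`. A minimal receiver sketch, assuming FastAPI and only the fields visible in this diff (`event`, `event_group`, `event_message`); the route path is an arbitrary choice:

```python
# Hypothetical receiver for the budget-alert webhook payload posted above.
# Field names beyond event/event_group/event_message are assumptions.
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/litellm-webhook")  # point WEBHOOK_URL at this route
async def litellm_webhook(request: Request):
    payload = await request.json()
    # e.g. {"event": "budget_crossed", "event_group": "key", "event_message": "...", ...}
    print(payload.get("event"), payload.get("event_group"), payload.get("event_message"))
    return {"status": "ok"}
```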
async def send_alert(
self,
message: str,
@ -726,9 +787,11 @@ Model Info:
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"new_model_added",
"cooldown_deployment",
],
user_info: Optional[WebhookEvent] = None,
**kwargs,
):
"""
@ -748,6 +811,19 @@ Model Info:
if self.alerting is None:
return
if (
"webhook" in self.alerting
and alert_type == "budget_alerts"
and user_info is not None
):
await self.send_webhook_alert(webhook_event=user_info)
if "slack" not in self.alerting:
return
if alert_type not in self.alert_types:
return
from datetime import datetime
import json
@ -795,6 +871,7 @@ Model Info:
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""Log deployment latency"""
try:
if "daily_reports" in self.alert_types:
model_id = (
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
@ -806,7 +883,10 @@ Model Info:
if isinstance(response_obj, litellm.ModelResponse):
completion_tokens = response_obj.usage.completion_tokens
final_value = float(response_s.total_seconds() / completion_tokens)
if completion_tokens is not None and completion_tokens > 0:
final_value = float(
response_s.total_seconds() / completion_tokens
)
await self.async_update_daily_reports(
DeploymentMetrics(
@ -816,6 +896,12 @@ Model Info:
updated_at=litellm.utils.get_utc_datetime(),
)
)
except Exception as e:
verbose_proxy_logger.error(
"[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: ",
e,
)
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""Log failure + deployment latency"""
@ -942,7 +1028,7 @@ Model Info:
await self.send_alert(
message=_weekly_spend_message,
level="Low",
alert_type="daily_reports",
alert_type="spend_reports",
)
except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e)
@ -993,7 +1079,7 @@ Model Info:
await self.send_alert(
message=_spend_message,
level="Low",
alert_type="daily_reports",
alert_type="spend_reports",
)
except Exception as e:
verbose_proxy_logger.error("Error sending weekly spend report", e)

View file

@ -93,6 +93,7 @@ class AnthropicConfig:
"max_tokens",
"tools",
"tool_choice",
"extra_headers",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
@ -504,7 +505,9 @@ class AnthropicChatCompletion(BaseLLM):
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
if "anthropic-beta" not in headers:
# default to v1 of "anthropic-beta"
headers["anthropic-beta"] = "tools-2024-05-16"
anthropic_tools = []
for tool in optional_params["tools"]:

View file

@ -21,7 +21,7 @@ class BaseLLM:
messages: list,
print_verbose,
encoding,
) -> litellm.utils.ModelResponse:
) -> Union[litellm.utils.ModelResponse, litellm.utils.CustomStreamWrapper]:
"""
Helper function to process the response across sync + async completion calls
"""

View file

@ -1,6 +1,6 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
## V0 - just covers cohere command-r support
## V1 - covers cohere + anthropic claude-3 support
import os, types
import json
@ -29,12 +29,20 @@ from litellm.utils import (
get_secret,
Logging,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
import litellm, uuid
from .prompt_templates.factory import (
prompt_factory,
custom_prompt,
cohere_message_pt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
contains_tag,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import *
@ -280,7 +288,8 @@ class BedrockLLM(BaseLLM):
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
) -> Union[ModelResponse, CustomStreamWrapper]:
provider = model.split(".")[0]
## LOGGING
logging_obj.post_call(
input=messages,
@ -297,26 +306,210 @@ class BedrockLLM(BaseLLM):
raise BedrockError(message=response.text, status_code=422)
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
if provider == "cohere":
if "text" in completion_response:
outputText = completion_response["text"] # type: ignore
elif "generations" in completion_response:
outputText = completion_response["generations"][0]["text"]
model_response["finish_reason"] = map_finish_reason(
completion_response["generations"][0]["finish_reason"]
)
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
json_schemas: dict = {}
_is_function_call = False
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
for tool in optional_params["tools"]:
json_schemas[tool["function"]["name"]] = tool[
"function"
].get("parameters", None)
outputText = completion_response.get("content")[0].get("text", None)
if outputText is not None and contains_tag(
"invoke", outputText
): # OUTPUT PARSE FUNCTION CALL
function_name = extract_between_tags("tool_name", outputText)[0]
function_arguments_str = extract_between_tags(
"invoke", outputText
)[0].strip()
function_arguments_str = (
f"<invoke>{function_arguments_str}</invoke>"
)
function_arguments = parse_xml_params(
function_arguments_str,
json_schema=json_schemas.get(
function_name, None
), # check if we have a json schema for this function name)
)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
outputText # allow user to access raw anthropic tool calling response
)
if (
_is_function_call == True
and stream is not None
and stream == True
):
print_verbose(
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
)
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = getattr(
model_response.choices[0], "finish_reason", "stop"
)
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(
f"type of streaming_choice: {type(streaming_choice)}"
)
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[
0
].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(
model_response.choices[0].message, "content", None
),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return litellm.CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
model_response["finish_reason"] = map_finish_reason(
completion_response.get("stop_reason", "")
)
_usage = litellm.Usage(
prompt_tokens=completion_response["usage"]["input_tokens"],
completion_tokens=completion_response["usage"]["output_tokens"],
total_tokens=completion_response["usage"]["input_tokens"]
+ completion_response["usage"]["output_tokens"],
)
setattr(model_response, "usage", _usage)
else:
outputText = completion_response["completion"]
model_response["finish_reason"] = completion_response["stop_reason"]
elif provider == "ai21":
outputText = (
completion_response.get("completions")[0].get("data").get("text")
)
elif provider == "meta":
outputText = completion_response["generation"]
elif provider == "mistral":
outputText = completion_response["outputs"][0]["text"]
model_response["finish_reason"] = completion_response["outputs"][0][
"stop_reason"
]
else: # amazon titan
outputText = completion_response.get("results")[0].get("outputText")
except Exception as e:
raise BedrockError(message=response.text, status_code=422)
raise BedrockError(
message="Error processing={}, Received error={}".format(
response.text, str(e)
),
status_code=422,
)
try:
if (
len(outputText) > 0
and hasattr(model_response.choices[0], "message")
and getattr(model_response.choices[0].message, "tool_calls", None)
is None
):
model_response["choices"][0]["message"]["content"] = outputText
elif (
hasattr(model_response.choices[0], "message")
and getattr(model_response.choices[0].message, "tool_calls", None)
is not None
):
pass
else:
raise Exception()
except:
raise BedrockError(
message=json.dumps(outputText), status_code=response.status_code
)
if stream and provider == "ai21":
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
mri = ModelResponseIterator(model_response=streaming_model_response)
return CustomStreamWrapper(
completion_stream=mri,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
## CALCULATING USAGE - bedrock returns usage in the headers
bedrock_input_tokens = response.headers.get(
"x-amzn-bedrock-input-token-count", None
)
bedrock_output_tokens = response.headers.get(
"x-amzn-bedrock-output-token-count", None
)
prompt_tokens = int(
response.headers.get(
"x-amzn-bedrock-input-token-count",
len(encoding.encode("".join(m.get("content", "") for m in messages))),
)
bedrock_input_tokens or litellm.token_counter(messages=messages)
)
completion_tokens = int(
response.headers.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response.choices[0].message.content, # type: ignore
disallowed_special=(),
)
),
bedrock_output_tokens
or litellm.token_counter(
text=model_response.choices[0].message.content, # type: ignore
count_response_tokens=True,
)
)
@ -359,6 +552,7 @@ class BedrockLLM(BaseLLM):
## SETUP ##
stream = optional_params.pop("stream", None)
provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@ -414,19 +608,18 @@ class BedrockLLM(BaseLLM):
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if stream is not None and stream == True:
if (stream is not None and stream == True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
else:
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
provider = model.split(".")[0]
prompt, chat_history = self.convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params)
json_schemas: dict = {}
if provider == "cohere":
if model.startswith("cohere.command-r"):
## LOAD CONFIG
@ -453,8 +646,114 @@ class BedrockLLM(BaseLLM):
True # cohere requires stream = True in inference params
)
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
# Separate system prompt from rest of message
system_prompt_idx: list[int] = []
system_messages: list[str] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
system_messages.append(message["content"])
system_prompt_idx.append(idx)
if len(system_prompt_idx) > 0:
inference_params["system"] = "\n".join(system_messages)
messages = [
i for j, i in enumerate(messages) if j not in system_prompt_idx
]
# Format rest of message according to anthropic guidelines
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic_xml"
) # type: ignore
## LOAD CONFIG
config = litellm.AmazonAnthropicClaude3Config.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
## Handle Tool Calling
if "tools" in inference_params:
_is_function_call = True
for tool in inference_params["tools"]:
json_schemas[tool["function"]["name"]] = tool["function"].get(
"parameters", None
)
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=inference_params["tools"]
)
inference_params["system"] = (
inference_params.get("system", "\n")
+ tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
inference_params.pop("tools")
data = json.dumps({"messages": messages, **inference_params})
else:
raise Exception("UNSUPPORTED PROVIDER")
## LOAD CONFIG
config = litellm.AmazonAnthropicConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "ai21":
## LOAD CONFIG
config = litellm.AmazonAI21Config.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > ai21_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "mistral":
## LOAD CONFIG
config = litellm.AmazonMistralConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > mistral_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "amazon": # amazon titan
## LOAD CONFIG
config = litellm.AmazonTitanConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps(
{
"inputText": prompt,
"textGenerationConfig": inference_params,
}
)
elif provider == "meta":
## LOAD CONFIG
config = litellm.AmazonLlamaConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > llama_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
else:
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": inference_params,
},
)
raise Exception(
"Bedrock HTTPX: Unsupported provider={}, model={}".format(
provider, model
)
)
## COMPLETION CALL
@ -482,7 +781,7 @@ class BedrockLLM(BaseLLM):
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream:
if stream == True and provider != "ai21":
return self.async_streaming(
model=model,
messages=messages,
@ -511,7 +810,7 @@ class BedrockLLM(BaseLLM):
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
@ -528,7 +827,7 @@ class BedrockLLM(BaseLLM):
self.client = HTTPHandler(**_params) # type: ignore
else:
self.client = client
if stream is not None and stream == True:
if (stream is not None and stream == True) and provider != "ai21":
response = self.client.post(
url=prepped.url,
headers=prepped.headers, # type: ignore
@ -541,7 +840,7 @@ class BedrockLLM(BaseLLM):
status_code=response.status_code, message=response.text
)
decoder = AWSEventStreamDecoder()
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
@ -550,15 +849,24 @@ class BedrockLLM(BaseLLM):
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=streaming_response,
additional_args={"complete_input_dict": data},
)
return streaming_response
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
try:
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
@ -591,7 +899,7 @@ class BedrockLLM(BaseLLM):
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
@ -602,12 +910,20 @@ class BedrockLLM(BaseLLM):
else:
self.client = client # type: ignore
try:
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
stream=stream if isinstance(stream, bool) else False,
logging_obj=logging_obj,
api_key="",
data=data,
@ -650,7 +966,7 @@ class BedrockLLM(BaseLLM):
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.text)
decoder = AWSEventStreamDecoder()
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
@ -659,6 +975,15 @@ class BedrockLLM(BaseLLM):
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=streaming_response,
additional_args={"complete_input_dict": data},
)
return streaming_response
def embedding(self, *args, **kwargs):
@ -676,11 +1001,70 @@ def get_response_stream_shape():
class AWSEventStreamDecoder:
def __init__(self) -> None:
def __init__(self, model: str) -> None:
from botocore.parsers import EventStreamJSONParser
self.model = model
self.parser = EventStreamJSONParser()
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
is_finished = False
finish_reason = ""
if "outputText" in chunk_data:
text = chunk_data["outputText"]
# ai21 mapping
if "ai21" in self.model: # fake ai21 streaming
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
elif "completion" in chunk_data: # not claude-3
text = chunk_data["completion"] # bedrock.anthropic
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
elif "delta" in chunk_data:
if chunk_data["delta"].get("text", None) is not None:
text = chunk_data["delta"]["text"]
stop_reason = chunk_data["delta"].get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data:
if (
len(chunk_data["outputs"]) == 1
and chunk_data["outputs"][0].get("text", None) is not None
):
text = chunk_data["outputs"][0]["text"]
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
######## bedrock.cohere mappings ###############
# meta mapping
elif "generation" in chunk_data:
text = chunk_data["generation"] # bedrock.meta
# cohere mapping
elif "text" in chunk_data:
text = chunk_data["text"] # bedrock.cohere
# cohere mapping for finish reason
elif "finish_reason" in chunk_data:
finish_reason = chunk_data["finish_reason"]
is_finished = True
elif chunk_data.get("completionReason", None):
is_finished = True
finish_reason = chunk_data["completionReason"]
return GenericStreamingChunk(
**{
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
)
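For reference, a rough sketch of how `_chunk_parser` maps provider-specific stream payloads onto `GenericStreamingChunk` (payloads are hand-written examples, not captured Bedrock responses; assumes botocore is installed):

```python
# Assumed illustration of the mapping above.
decoder = AWSEventStreamDecoder(model="anthropic.claude-3-sonnet-20240229-v1:0")

# claude-3 style content delta -> text chunk, not finished yet
decoder._chunk_parser({"delta": {"text": "Hello"}})
# -> GenericStreamingChunk(text="Hello", is_finished=False, finish_reason="")

# claude-3 style stop delta -> finished, finish_reason carried through
decoder._chunk_parser({"delta": {"stop_reason": "end_turn"}})
# -> GenericStreamingChunk(text="", is_finished=True, finish_reason="end_turn")
```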
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
@ -693,12 +1077,7 @@ class AWSEventStreamDecoder:
if message:
# sse_event = ServerSentEvent(data=message, event="completion")
_data = json.loads(message)
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
text=_data.get("text", ""),
is_finished=_data.get("is_finished", False),
finish_reason=_data.get("finish_reason", ""),
)
yield streaming_chunk
yield self._chunk_parser(chunk_data=_data)
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
@ -713,12 +1092,7 @@ class AWSEventStreamDecoder:
message = self._parse_message_from_event(event)
if message:
_data = json.loads(message)
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
text=_data.get("text", ""),
is_finished=_data.get("is_finished", False),
finish_reason=_data.get("finish_reason", ""),
)
yield streaming_chunk
yield self._chunk_parser(chunk_data=_data)
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()

View file

@ -260,7 +260,7 @@ def completion(
message_obj = Message(content=item.content.parts[0].text)
else:
message_obj = Message(content=None)
choice_obj = Choices(index=idx + 1, message=message_obj)
choice_obj = Choices(index=idx, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
@ -352,7 +352,7 @@ async def async_completion(
message_obj = Message(content=item.content.parts[0].text)
else:
message_obj = Message(content=None)
choice_obj = Choices(index=idx + 1, message=message_obj)
choice_obj = Choices(index=idx, message=message_obj)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:

View file

@ -96,7 +96,7 @@ class MistralConfig:
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@ -211,7 +211,7 @@ class OpenAIConfig:
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@ -234,6 +234,47 @@ class OpenAIConfig:
and v is not None
}
def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"user",
"function_call",
"functions",
"max_retries",
"extra_headers",
] # works across all models
model_specific_params = []
if (
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
): # gpt-3.5-turbo-16k and gpt-4 do not support 'response_format'
model_specific_params.append("response_format")
return base_params + model_specific_params
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
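A small sketch of how the two methods above interact; the parameter values are arbitrary:

```python
# Illustrative only: unsupported params are silently dropped by map_openai_params.
config = OpenAIConfig()

non_default_params = {"temperature": 0.2, "max_tokens": 100, "not_an_openai_param": True}
optional_params = config.map_openai_params(
    non_default_params=non_default_params,
    optional_params={},
    model="gpt-3.5-turbo",  # example model name
)
# optional_params == {"temperature": 0.2, "max_tokens": 100}
```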
class OpenAITextCompletionConfig:
"""
@ -294,7 +335,7 @@ class OpenAITextCompletionConfig:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
locals_ = locals()
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

View file

@ -12,6 +12,7 @@ from typing import (
Sequence,
)
import litellm
import litellm.types
from litellm.types.completion import (
ChatCompletionUserMessageParam,
ChatCompletionSystemMessageParam,
@ -20,9 +21,12 @@ from litellm.types.completion import (
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
import litellm.types.llms.vertex_ai
def default_pt(messages):
return " ".join(message["content"] for message in messages)
@ -841,6 +845,175 @@ def anthropic_messages_pt_xml(messages: list):
# ------------------------------------------------------------------------------
def infer_protocol_value(
value: Any,
) -> Literal[
"string_value",
"number_value",
"bool_value",
"struct_value",
"list_value",
"null_value",
"unknown",
]:
if value is None:
return "null_value"
if isinstance(value, bool):
return "bool_value"
if isinstance(value, int) or isinstance(value, float):
return "number_value"
if isinstance(value, str):
return "string_value"
if isinstance(value, dict):
return "struct_value"
if isinstance(value, list):
return "list_value"
return "unknown"
def convert_to_gemini_tool_call_invoke(
tool_calls: list,
) -> List[litellm.types.llms.vertex_ai.PartType]:
"""
OpenAI tool invokes:
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
}
}
]
},
"""
"""
Gemini tool call invokes: - https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#submit-api-output
content {
role: "model"
parts [
{
function_call {
name: "get_current_weather"
args {
fields {
key: "unit"
value {
string_value: "fahrenheit"
}
}
fields {
key: "predicted_temperature"
value {
number_value: 45
}
}
fields {
key: "location"
value {
string_value: "Boston, MA"
}
}
}
},
{
function_call {
name: "get_current_weather"
args {
fields {
key: "location"
value {
string_value: "San Francisco"
}
}
}
}
}
]
}
"""
"""
- json.load the arguments
- iterate through arguments -> create a FunctionCallArgs for each field
"""
try:
_parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
for tool in tool_calls:
if "function" in tool:
name = tool["function"].get("name", "")
arguments = tool["function"].get("arguments", "")
arguments_dict = json.loads(arguments)
for k, v in arguments_dict.items():
inferred_protocol_value = infer_protocol_value(value=v)
_field = litellm.types.llms.vertex_ai.Field(
key=k, value={inferred_protocol_value: v}
)
_fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
fields=_field
)
function_call = litellm.types.llms.vertex_ai.FunctionCall(
name=name,
args=_fields,
)
_parts_list.append(
litellm.types.llms.vertex_ai.PartType(function_call=function_call)
)
return _parts_list
except Exception as e:
raise Exception(
"Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
tool_calls, str(e)
)
)
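A hedged usage sketch for the converter above, reusing the OpenAI-format tool call from the docstring; the returned `PartType` objects come from `litellm.types.llms.vertex_ai` and are only described in the trailing comment:

```python
# Illustrative only: convert an OpenAI assistant tool_calls list to Gemini parts.
tool_calls = [
    {
        "id": "call_abc123",
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "arguments": "{\"location\": \"Boston, MA\"}",
        },
    }
]
parts = convert_to_gemini_tool_call_invoke(tool_calls)
# -> [PartType(function_call=FunctionCall(name="get_current_weather",
#              args=FunctionCallArgs(fields=Field(key="location",
#                                                 value={"string_value": "Boston, MA"}))))]
```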
def convert_to_gemini_tool_call_result(
message: dict,
) -> litellm.types.llms.vertex_ai.PartType:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"content": "function result goes here",
},
OpenAI message with a function call result looks like:
{
"role": "function",
"name": "get_current_weather",
"content": "function result goes here",
}
"""
content = message.get("content", "")
name = message.get("name", "")
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
inferred_content_value = infer_protocol_value(value=content)
_field = litellm.types.llms.vertex_ai.Field(
key="content", value={inferred_content_value: content}
)
_function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
_function_response = litellm.types.llms.vertex_ai.FunctionResponse(
name=name, response=_function_call_args
)
_part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
return _part
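And the matching sketch for tool results, again using the message shape from the docstring:

```python
# Illustrative only: convert an OpenAI tool-result message to a Gemini function_response part.
tool_result_message = {
    "tool_call_id": "tool_1",
    "role": "tool",
    "name": "get_current_weather",
    "content": "27 degrees celsius and clear in Boston, MA",
}
part = convert_to_gemini_tool_call_result(tool_result_message)
# -> PartType(function_response=FunctionResponse(name="get_current_weather",
#             response=FunctionCallArgs(fields=Field(key="content",
#                                                    value={"string_value": "..."}))))
```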
def convert_to_anthropic_tool_result(message: dict) -> dict:
"""
OpenAI message with a tool result looks like:
@ -1328,6 +1501,7 @@ def _gemini_vision_convert_messages(messages: list):
# Case 1: Image from URL
image = _load_image_from_url(img)
processed_images.append(image)
else:
try:
from PIL import Image
@ -1335,7 +1509,21 @@ def _gemini_vision_convert_messages(messages: list):
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
# Case 2: Image filepath (e.g. temp.jpeg) given
if "base64" in img:
# Case 2: Base64 image data
import base64
import io
# Extract the base64 image data
base64_data = img.split("base64,")[1]
# Decode the base64 image data
image_data = base64.b64decode(base64_data)
# Load the image from the decoded data
image = Image.open(io.BytesIO(image_data))
else:
# Case 3: Image filepath (e.g. temp.jpeg) given
image = Image.open(img)
processed_images.append(image)
content = [prompt] + processed_images

View file

@ -2,11 +2,12 @@ import os, types
import json
import requests # type: ignore
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
import litellm
from typing import Callable, Optional, Union, Tuple, Any
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm, asyncio
import httpx # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
class ReplicateError(Exception):
@ -145,6 +146,65 @@ def start_prediction(
)
async def async_start_prediction(
version_id,
input_data,
api_token,
api_base,
logging_obj,
print_verbose,
http_handler: AsyncHTTPHandler,
) -> str:
base_url = api_base
if "deployments" in version_id:
print_verbose("\nLiteLLM: Request to custom replicate deployment")
version_id = version_id.replace("deployments/", "")
base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
print_verbose(f"Deployment base URL: {base_url}\n")
else: # assume it's a model
base_url = f"https://api.replicate.com/v1/models/{version_id}"
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
initial_prediction_data = {
"input": input_data,
}
if ":" in version_id and len(version_id) > 64:
model_parts = version_id.split(":")
if (
len(model_parts) > 1 and len(model_parts[1]) == 64
): ## checks if the model name has a 64-character version hash - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
initial_prediction_data["version"] = model_parts[1]
## LOGGING
logging_obj.pre_call(
input=input_data["prompt"],
api_key="",
additional_args={
"complete_input_dict": initial_prediction_data,
"headers": headers,
"api_base": base_url,
},
)
response = await http_handler.post(
url="{}/predictions".format(base_url),
data=json.dumps(initial_prediction_data),
headers=headers,
)
if response.status_code == 201:
response_data = response.json()
return response_data.get("urls", {}).get("get")
else:
raise ReplicateError(
response.status_code, f"Failed to start prediction {response.text}"
)
# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
output_string = ""
@ -178,6 +238,40 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
return output_string, logs
async def async_handle_prediction_response(
prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
) -> Tuple[str, Any]:
output_string = ""
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
status = ""
logs = ""
while True and (status not in ["succeeded", "failed", "canceled"]):
print_verbose(f"replicate: polling endpoint: {prediction_url}")
await asyncio.sleep(0.5)
response = await http_handler.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
if "output" in response_data:
output_string = "".join(response_data["output"])
print_verbose(f"Non-streamed output:{output_string}")
status = response_data.get("status", None)
logs = response_data.get("logs", "")
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(
status_code=400,
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
)
else:
# this can fail temporarily; it does not mean the replicate request failed - the request only fails when status == "failed"
print_verbose("Replicate: Failed to fetch prediction status and output.")
return output_string, logs
# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
previous_output = ""
@ -214,6 +308,45 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
)
# Function to handle prediction response (streaming)
async def async_handle_prediction_response_streaming(
prediction_url, api_token, print_verbose
):
http_handler = AsyncHTTPHandler(concurrent_limit=1)
previous_output = ""
output_string = ""
headers = {
"Authorization": f"Token {api_token}",
"Content-Type": "application/json",
}
status = ""
while True and (status not in ["succeeded", "failed", "canceled"]):
await asyncio.sleep(0.5) # prevent being rate limited by replicate
print_verbose(f"replicate: polling endpoint: {prediction_url}")
response = await http_handler.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
status = response_data["status"]
if "output" in response_data:
output_string = "".join(response_data["output"])
new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status}
previous_output = output_string
status = response_data["status"]
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(
status_code=400, message=f"Error: {replicate_error}"
)
else:
# this can fail temporarily; it does not mean the replicate request failed - the request only fails when status == "failed"
print_verbose(
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
)
# Function to extract version ID from model string
def model_to_version_id(model):
if ":" in model:
@ -222,6 +355,39 @@ def model_to_version_id(model):
return model
def process_response(
model_response: ModelResponse,
result: str,
model: str,
encoding: Any,
prompt: str,
) -> ModelResponse:
if len(result) == 0: # edge case, where result from replicate is empty
result = " "
## Building RESPONSE OBJECT
if len(result) > 1:
model_response["choices"][0]["message"]["content"] = result
# Calculate usage
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", ""),
disallowed_special=(),
)
)
model_response["model"] = "replicate/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
# Main function for prediction completion
def completion(
model: str,
@ -229,14 +395,15 @@ def completion(
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
logging_obj,
api_key,
encoding,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
):
acompletion=None,
) -> Union[ModelResponse, CustomStreamWrapper]:
# Start a prediction and get the prediction URL
version_id = model_to_version_id(model)
## Load Config
@ -274,6 +441,12 @@ def completion(
else:
prompt = prompt_factory(model=model, messages=messages)
if prompt is None or not isinstance(prompt, str):
raise ReplicateError(
status_code=400,
message="LiteLLM Error - prompt is not a string - {}".format(prompt),
)
# If system prompt is supported, and a system prompt is provided, use it
if system_prompt is not None:
input_data = {
@ -285,6 +458,20 @@ def completion(
else:
input_data = {"prompt": prompt, **optional_params}
if acompletion is not None and acompletion == True:
return async_completion(
model_response=model_response,
model=model,
prompt=prompt,
encoding=encoding,
optional_params=optional_params,
version_id=version_id,
input_data=input_data,
api_key=api_key,
api_base=api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
) # type: ignore
## COMPLETION CALL
## Replicate completion calls have 2 steps
## Step1: Start Prediction: gets a prediction url
@ -293,6 +480,7 @@ def completion(
model_response["created"] = int(
time.time()
) # for pricing this must remain right before calling api
prediction_url = start_prediction(
version_id,
input_data,
@ -306,9 +494,10 @@ def completion(
# Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True:
print_verbose("streaming request")
return handle_prediction_response_streaming(
_response = handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
else:
result, logs = handle_prediction_response(
prediction_url, api_key, print_verbose
@ -328,29 +517,56 @@ def completion(
print_verbose(f"raw model_response: {result}")
if len(result) == 0: # edge case, where result from replicate is empty
result = " "
return process_response(
model_response=model_response,
result=result,
model=model,
encoding=encoding,
prompt=prompt,
)
## Building RESPONSE OBJECT
if len(result) > 1:
model_response["choices"][0]["message"]["content"] = result
# Calculate usage
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", ""),
disallowed_special=(),
async def async_completion(
model_response: ModelResponse,
model: str,
prompt: str,
encoding,
optional_params: dict,
version_id,
input_data,
api_key,
api_base,
logging_obj,
print_verbose,
) -> Union[ModelResponse, CustomStreamWrapper]:
http_handler = AsyncHTTPHandler(concurrent_limit=1)
prediction_url = await async_start_prediction(
version_id,
input_data,
api_key,
api_base,
logging_obj=logging_obj,
print_verbose=print_verbose,
http_handler=http_handler,
)
if "stream" in optional_params and optional_params["stream"] == True:
_response = async_handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)
model_response["model"] = "replicate/" + model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
result, logs = await async_handle_prediction_response(
prediction_url, api_key, print_verbose, http_handler=http_handler
)
return process_response(
model_response=model_response,
result=result,
model=model,
encoding=encoding,
prompt=prompt,
)
setattr(model_response, "usage", usage)
return model_response
# # Example usage:

View file

@ -3,10 +3,15 @@ import json
from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional, Union, List
from typing import Callable, Optional, Union, List, Literal
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
from litellm.types.llms.vertex_ai import *
from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_result,
convert_to_gemini_tool_call_invoke,
)
class VertexAIError(Exception):
@ -283,6 +288,125 @@ def _load_image_from_url(image_url: str):
return Image.from_bytes(data=image_bytes)
def _convert_gemini_role(role: str) -> Literal["user", "model"]:
if role == "user":
return "user"
else:
return "model"
def _process_gemini_image(image_url: str) -> PartType:
try:
if "gs://" in image_url:
# Case 1: Images with Cloud Storage URIs
# The supported MIME types for images include image/png and image/jpeg.
part_mime = "image/png" if "png" in image_url else "image/jpeg"
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
return PartType(file_data=_file_data)
elif "https:/" in image_url:
# Case 2: Images with direct links
image = _load_image_from_url(image_url)
_blob = BlobType(data=image.data, mime_type=image._mime_type)
return PartType(inline_data=_blob)
elif ".mp4" in image_url and "gs://" in image_url:
# Case 3: Videos with Cloud Storage URIs
part_mime = "video/mp4"
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
return PartType(file_data=_file_data)
elif "base64" in image_url:
# Case 4: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
else:
mime_type = "image/jpeg"
decoded_img = base64.b64decode(img_without_base_64)
_blob = BlobType(data=decoded_img, mime_type=mime_type)
return PartType(inline_data=_blob)
raise Exception("Invalid image received - {}".format(image_url))
except Exception as e:
raise e
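A quick sketch of the dispatch above with made-up inputs; return values are only described in the comments:

```python
# Illustrative only: which branch of _process_gemini_image handles which input.
_process_gemini_image("gs://my-bucket/clip.mp4")      # -> PartType(file_data=...), mime "video/mp4"
_process_gemini_image("gs://my-bucket/cat.png")       # -> PartType(file_data=...), mime "image/png"
_process_gemini_image("https://example.com/cat.jpg")  # -> PartType(inline_data=...) via _load_image_from_url
_process_gemini_image("data:image/jpeg;base64,/9j/4AAQ...")  # -> PartType(inline_data=...), mime "image/jpeg"
```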
def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
"""
Converts given messages from OpenAI format to Gemini format
- Parts must be iterable
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
- Please ensure that function response turn comes immediately after a function call turn
"""
user_message_types = {"user", "system"}
contents: List[ContentType] = []
msg_i = 0
while msg_i < len(messages):
user_content: List[PartType] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
if isinstance(messages[msg_i]["content"], list):
_parts: List[PartType] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
_part = PartType(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part) # type: ignore
user_content.extend(_parts)
else:
_part = PartType(text=messages[msg_i]["content"])
user_content.append(_part)
msg_i += 1
if user_content:
contents.append(ContentType(role="user", parts=user_content))
assistant_content = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(PartType(text=assistant_text))
if messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke conversion
assistant_content.extend(
convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
)
msg_i += 1
if assistant_content:
contents.append(ContentType(role="model", parts=assistant_content))
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
_part = convert_to_gemini_tool_call_result(messages[msg_i])
contents.append(ContentType(parts=[_part])) # type: ignore
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
return contents
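A hedged example of the conversion rules described in the docstring (consecutive user/system turns merged, assistant turns mapped to role "model"; the commented output is an approximation of the resulting structures):

```python
# Illustrative only: OpenAI-style history converted to Gemini contents.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
    {"role": "user", "content": "What's the weather in Boston?"},
]
contents = _gemini_convert_messages_with_history(messages=messages)
# -> [
#   ContentType(role="user",  parts=[PartType(text="You are a helpful assistant."), PartType(text="Hi!")]),
#   ContentType(role="model", parts=[PartType(text="Hello, how can I help?")]),
#   ContentType(role="user",  parts=[PartType(text="What's the weather in Boston?")]),
# ]
```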
def _gemini_vision_convert_messages(messages: list):
"""
Converts given messages for GPT-4 Vision to Gemini format.
@ -396,10 +520,10 @@ def completion(
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
@ -556,6 +680,7 @@ def completion(
"model_response": model_response,
"encoding": encoding,
"messages": messages,
"request_str": request_str,
"print_verbose": print_verbose,
"client_options": client_options,
"instances": instances,
@ -574,11 +699,9 @@ def completion(
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
content = _gemini_convert_messages_with_history(messages=messages)
stream = optional_params.pop("stream", False)
if stream == True:
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
logging_obj.pre_call(
input=prompt,
@ -589,7 +712,7 @@ def completion(
},
)
model_response = llm_model.generate_content(
_model_response = llm_model.generate_content(
contents=content,
generation_config=optional_params,
safety_settings=safety_settings,
@ -597,7 +720,7 @@ def completion(
tools=tools,
)
return model_response
return _model_response
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
@ -850,12 +973,12 @@ async def async_completion(
mode: str,
prompt: str,
model: str,
messages: list,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
request_str: str,
print_verbose: Callable,
logging_obj,
encoding=None,
messages=None,
print_verbose=None,
client_options=None,
instances=None,
vertex_project=None,
@ -875,8 +998,7 @@ async def async_completion(
tools = optional_params.pop("tools", None)
stream = optional_params.pop("stream", False)
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
content = _gemini_convert_messages_with_history(messages=messages)
request_str += f"response = llm_model.generate_content({content})\n"
## LOGGING
@ -1076,11 +1198,11 @@ async def async_streaming(
prompt: str,
model: str,
model_response: ModelResponse,
logging_obj=None,
request_str=None,
messages: list,
print_verbose: Callable,
logging_obj,
request_str: str,
encoding=None,
messages=None,
print_verbose=None,
client_options=None,
instances=None,
vertex_project=None,
@ -1097,8 +1219,8 @@ async def async_streaming(
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
prompt, images = _gemini_vision_convert_messages(messages=messages)
content = [prompt] + images
content = _gemini_convert_messages_with_history(messages=messages)
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
logging_obj.pre_call(
input=prompt,

View file

@ -0,0 +1,224 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional, Union, List, Any, Tuple
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self._credentials: Optional[Any] = None
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
def load_auth(self) -> Tuple[Any, str]:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.credentials import Credentials # type: ignore[import-untyped]
import google.auth as google_auth
credentials, project_id = google_auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
credentials.refresh(Request())
if not project_id:
raise ValueError("Could not resolve project_id")
if not isinstance(project_id, str):
raise TypeError(
f"Expected project_id to be a str but got {type(project_id)}"
)
return credentials, project_id
def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
credentials.refresh(Request())
def _prepare_request(self, request: httpx.Request) -> None:
access_token = self._ensure_access_token()
if request.headers.get("Authorization"):
# already authenticated, nothing for us to do
return
request.headers["Authorization"] = f"Bearer {access_token}"
def _ensure_access_token(self) -> str:
if self.access_token is not None:
return self.access_token
if not self._credentials:
self._credentials, project_id = self.load_auth()
if not self.project_id:
self.project_id = project_id
else:
self.refresh_auth(self._credentials)
if not self._credentials.token:
raise RuntimeError("Could not resolve API token from the environment")
assert isinstance(self._credentials.token, str)
return self._credentials.token
def image_generation(
self,
prompt: str,
vertex_project: str,
vertex_location: str,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
model_response=None,
aimg_generation=False,
):
if aimg_generation == True:
response = self.aimage_generation(
prompt=prompt,
vertex_project=vertex_project,
vertex_location=vertex_location,
model=model,
client=client,
optional_params=optional_params,
timeout=timeout,
logging_obj=logging_obj,
model_response=model_response,
)
return response
async def aimage_generation(
self,
prompt: str,
vertex_project: str,
vertex_location: str,
model_response: litellm.ImageResponse,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
):
response = None
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
else:
self.async_handler = client # type: ignore
# make POST request to
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
"""
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d {
"instances": [
{
"prompt": "a cat"
}
],
"parameters": {
"sampleCount": 1
}
} \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header = self._ensure_access_token()
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
request_data = {
"instances": [{"prompt": prompt}],
"parameters": optional_params,
}
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
response = await self.async_handler.post(
url=url,
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
"""
Vertex AI Image generation response example:
{
"predictions": [
{
"bytesBase64Encoded": "BASE64_IMG_BYTES",
"mimeType": "image/png"
},
{
"mimeType": "image/png",
"bytesBase64Encoded": "BASE64_IMG_BYTES"
}
]
}
"""
_json_response = response.json()
_predictions = _json_response["predictions"]
_response_data: List[litellm.ImageObject] = []
for _prediction in _predictions:
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
_response_data.append(image_object)
model_response.data = _response_data
return model_response

View file

@ -79,6 +79,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
@ -118,6 +119,7 @@ huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################
@ -320,12 +322,13 @@ async def acompletion(
or custom_llm_provider == "huggingface"
or custom_llm_provider == "ollama"
or custom_llm_provider == "ollama_chat"
or custom_llm_provider == "replicate"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase"
or (custom_llm_provider == "bedrock" and "cohere" in model)
or custom_llm_provider == "bedrock"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@ -367,6 +370,8 @@ async def acompletion(
async def _async_streaming(response, model, custom_llm_provider, args):
try:
print_verbose(f"received response in _async_streaming: {response}")
if asyncio.iscoroutine(response):
response = await response
async for line in response:
print_verbose(f"line in async streaming: {line}")
yield line
@ -552,7 +557,7 @@ def completion(
model_info = kwargs.get("model_info", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
fallbacks = kwargs.get("fallbacks", None)
headers = kwargs.get("headers", None)
headers = kwargs.get("headers", None) or extra_headers
num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
@ -674,20 +679,6 @@ def completion(
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
try:
if base_url is not None:
api_base = base_url
@ -727,6 +718,16 @@ def completion(
"aws_region_name", None
) # support region-based pricing for bedrock
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(
custom_llm_provider
):
timeout = timeout.read or 600 # default 10 min timeout
elif not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model(
@ -1192,7 +1193,7 @@ def completion(
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = replicate.completion(
model_response = replicate.completion( # type: ignore
model=model,
messages=messages,
api_base=api_base,
@ -1205,12 +1206,10 @@ def completion(
api_key=replicate_key,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
acompletion=acompletion,
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
if optional_params.get("stream", False) or acompletion == True:
if optional_params.get("stream", False) == True:
## LOGGING
logging.post_call(
input=messages,
@ -1984,23 +1983,9 @@ def completion(
# boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
if "cohere" in model:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
)
else:
if (
"aws_bedrock_client" in optional_params
): # use old bedrock flow for aws_bedrock_client users.
response = bedrock.completion(
model=model,
messages=messages,
@ -2036,7 +2021,22 @@ def completion(
custom_llm_provider="bedrock",
logging_obj=logging,
)
else:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
@ -3856,6 +3856,36 @@ def image_generation(
model_response=model_response,
aimg_generation=aimg_generation,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
model_response = vertex_chat_completion.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
aimg_generation=aimg_generation,
)
return model_response
except Exception as e:
## Map to OpenAI Exception

View file

@ -234,6 +234,24 @@
"litellm_provider": "openai",
"mode": "chat"
},
"ft:davinci-002": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000002,
"litellm_provider": "text-completion-openai",
"mode": "completion"
},
"ft:babbage-002": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000004,
"litellm_provider": "text-completion-openai",
"mode": "completion"
},
"text-embedding-3-large": {
"max_tokens": 8191,
"max_input_tokens": 8191,
@ -1385,6 +1403,24 @@
"mode": "completion",
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-1.5-flash-latest": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-pro": {
"max_tokens": 8192,
"max_input_tokens": 32760,
@ -1744,6 +1780,30 @@
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/openai/gpt-4o": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-4o-2024-05-13": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-4-vision-preview": {
"max_tokens": 130000,
"input_cost_per_token": 0.00001,
@ -2943,6 +3003,24 @@
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/llama3": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/llama3:70b": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mistral": {
"max_tokens": 8192,
"max_input_tokens": 8192,
@ -2952,6 +3030,42 @@
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/mistral-7B-Instruct-v0.1": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mistral-7B-Instruct-v0.2": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 32768,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mixtral-8x7B-Instruct-v0.1": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 32768,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mixtral-8x22B-Instruct-v0.1": {
"max_tokens": 65536,
"max_input_tokens": 65536,
"max_output_tokens": 65536,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/codellama": {
"max_tokens": 4096,
"max_input_tokens": 4096,

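
As a sanity check on the pricing entries added above, a hedged sketch of the per-token arithmetic they imply; the prices are copied from the `openrouter/openai/gpt-4o` block in this diff, and the helper is illustrative rather than litellm's internal cost tracker:

```python
# Estimate request cost from a cost-map entry: prompt and completion tokens are
# each multiplied by their per-token price.
PRICES = {
    "openrouter/openai/gpt-4o": {
        "input_cost_per_token": 0.000005,
        "output_cost_per_token": 0.000015,
    }
}

def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    entry = PRICES[model]
    return (
        prompt_tokens * entry["input_cost_per_token"]
        + completion_tokens * entry["output_cost_per_token"]
    )

# 1,000 prompt tokens + 500 completion tokens -> 0.005 + 0.0075 = $0.0125
print(estimate_cost("openrouter/openai/gpt-4o", 1000, 500))
```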
File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[7926,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-6a39771cacf75ea6.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"obp5wqVSVDMiDTC414cR8\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid 
rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[4858,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-f20fdea77aed85ba.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"l-0LDfSCdaUCAbcLIx_QC\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid 
rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>

View file

@ -1,7 +1,7 @@
2:I[77831,[],""]
3:I[7926,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-6a39771cacf75ea6.js"],""]
3:I[4858,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-f20fdea77aed85ba.js"],""]
4:I[5613,[],""]
5:I[31778,[],""]
0:["obp5wqVSVDMiDTC414cR8",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
0:["l-0LDfSCdaUCAbcLIx_QC",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
1:null

litellm/proxy/_logging.py Normal file
View file

@ -0,0 +1,20 @@
import json
import logging
from logging import Formatter
class JsonFormatter(Formatter):
def __init__(self):
super(JsonFormatter, self).__init__()
def format(self, record):
json_record = {}
json_record["message"] = record.getMessage()
return json.dumps(json_record)
logger = logging.root
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logger.handlers = [handler]
logger.setLevel(logging.DEBUG)
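
A self-contained sketch of the JSON lines the formatter above produces; the logger name here is arbitrary and only used for the demo:

```python
import json
import logging
from logging import Formatter

class JsonFormatter(Formatter):
    def format(self, record):
        # same shape as litellm/proxy/_logging.py above: {"message": "..."}
        return json.dumps({"message": record.getMessage()})

demo = logging.getLogger("json-demo")
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
demo.addHandler(handler)
demo.setLevel(logging.DEBUG)
demo.info("proxy started")  # prints: {"message": "proxy started"}
```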

View file

@ -1,45 +1,20 @@
model_list:
- litellm_params:
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: 2023-07-01-preview
model: azure/azure-embedding-model
model_info:
base_model: text-embedding-ada-002
mode: embedding
model_name: text-embedding-ada-002
- model_name: gpt-3.5-turbo-012
- model_name: gpt-3.5-turbo-fake-model
litellm_params:
model: gpt-3.5-turbo
model: openai/my-fake-model
api_base: http://0.0.0.0:8080
api_key: ""
- model_name: gpt-3.5-turbo-0125-preview
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-35-turbo
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: os.environ/AZURE_EUROPE_API_KEY
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
input_cost_per_token: 0.0
output_cost_per_token: 0.0
- model_name: bert-classifier
litellm_params:
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
router_settings:
redis_host: redis
# redis_password: <your redis password>
redis_port: 6379
enable_pre_call_checks: true
litellm_settings:
set_verbose: True
fallbacks: [{"gpt-3.5-turbo-012": ["gpt-3.5-turbo-0125-preview"]}]
# service_callback: ["prometheus_system"]
# success_callback: ["prometheus"]
# failure_callback: ["prometheus"]
general_settings:
enable_jwt_auth: True
disable_reset_budget: True
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
alerting: ["slack"]

View file

@ -1,37 +1,11 @@
from pydantic import ConfigDict, BaseModel, Field, root_validator, Json, VERSION
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime
import uuid
import json
import uuid, json, sys, os
from litellm.types.router import UpdateRouterConfig
try:
from pydantic import model_validator # type: ignore
except ImportError:
from pydantic import root_validator # pydantic v1
def model_validator(mode): # type: ignore
pre = mode == "before"
return root_validator(pre=pre)
# Function to get Pydantic version
def is_pydantic_v2() -> int:
return int(VERSION.split(".")[0])
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
# Version-specific configuration
if is_pydantic_v2() >= 2:
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
else:
from pydantic import Extra
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
return model_config
def hash_token(token: str):
import hashlib
@ -61,7 +35,8 @@ class LiteLLMBase(BaseModel):
# if using pydantic v1
return self.__fields_set__
model_config = get_model_config()
class Config:
protected_namespaces = ()
class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
@ -77,8 +52,18 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
class LiteLLMRoutes(enum.Enum):
openai_route_names: List = [
"chat_completion",
"completion",
"embeddings",
"image_generation",
"audio_transcriptions",
"moderations",
"model_list", # OpenAI /v1/models route
]
openai_routes: List = [
# chat completions
"/engines/{model}/chat/completions",
"/openai/deployments/{model}/chat/completions",
"/chat/completions",
"/v1/chat/completions",
@ -102,11 +87,8 @@ class LiteLLMRoutes(enum.Enum):
# models
"/models",
"/v1/models",
]
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes: List = [
"/global/spend/reset",
# token counter
"/utils/token_counter",
]
info_routes: List = [
@ -119,6 +101,11 @@ class LiteLLMRoutes(enum.Enum):
"/v2/key/info",
]
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes: List = [
"/global/spend/reset",
]
sso_only_routes: List = [
"/key/generate",
"/key/update",
@ -227,13 +214,19 @@ class LiteLLM_JWTAuth(LiteLLMBase):
"global_spend_tracking_routes",
"info_routes",
]
team_jwt_scope: str = "litellm_team"
team_id_jwt_field: str = "client_id"
team_id_jwt_field: Optional[str] = None
team_allowed_routes: List[
Literal["openai_routes", "info_routes", "management_routes"]
] = ["openai_routes", "info_routes"]
team_id_default: Optional[str] = Field(
default=None,
description="If no team_id given, default permissions/spend-tracking to this team.s",
)
org_id_jwt_field: Optional[str] = None
user_id_jwt_field: Optional[str] = None
user_id_upsert: bool = Field(
default=False, description="If user doesn't exist, upsert them into the db."
)
end_user_id_jwt_field: Optional[str] = None
public_key_ttl: float = 600
@ -258,8 +251,12 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
llm_api_name: Optional[str] = None
llm_api_system_prompt: Optional[str] = None
llm_api_fail_call_string: Optional[str] = None
reject_as_response: Optional[bool] = Field(
default=False,
description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.",
)
@model_validator(mode="before")
@root_validator(pre=True)
def check_llm_api_params(cls, values):
llm_api_check = values.get("llm_api_check")
if llm_api_check is True:
@ -317,7 +314,8 @@ class ProxyChatCompletionRequest(LiteLLMBase):
deployment_id: Optional[str] = None
request_timeout: Optional[int] = None
model_config = get_model_config()
class Config:
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
class ModelInfoDelete(LiteLLMBase):
@ -344,9 +342,11 @@ class ModelInfo(LiteLLMBase):
]
]
model_config = get_model_config()
class Config:
extra = Extra.allow # Allow extra fields
protected_namespaces = ()
@model_validator(mode="before")
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("id") is None:
values.update({"id": str(uuid.uuid4())})
@ -372,9 +372,10 @@ class ModelParams(LiteLLMBase):
litellm_params: dict
model_info: ModelInfo
model_config = get_model_config()
class Config:
protected_namespaces = ()
@model_validator(mode="before")
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("model_info") is None:
values.update({"model_info": ModelInfo()})
@ -410,7 +411,8 @@ class GenerateKeyRequest(GenerateRequestBase):
{}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_config = get_model_config()
class Config:
protected_namespaces = ()
class GenerateKeyResponse(GenerateKeyRequest):
@ -420,7 +422,7 @@ class GenerateKeyResponse(GenerateKeyRequest):
user_id: Optional[str] = None
token_id: Optional[str] = None
@model_validator(mode="before")
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("token") is not None:
values.update({"key": values.get("token")})
@ -460,7 +462,8 @@ class LiteLLM_ModelTable(LiteLLMBase):
created_by: str
updated_by: str
model_config = get_model_config()
class Config:
protected_namespaces = ()
class NewUserRequest(GenerateKeyRequest):
@ -488,7 +491,7 @@ class UpdateUserRequest(GenerateRequestBase):
user_role: Optional[str] = None
max_budget: Optional[float] = None
@model_validator(mode="before")
@root_validator(pre=True)
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -508,7 +511,7 @@ class NewEndUserRequest(LiteLLMBase):
None # if no equivalent model in allowed region - default all requests to this model
)
@model_validator(mode="before")
@root_validator(pre=True)
def check_user_info(cls, values):
if values.get("max_budget") is not None and values.get("budget_id") is not None:
raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
@ -521,7 +524,7 @@ class Member(LiteLLMBase):
user_id: Optional[str] = None
user_email: Optional[str] = None
@model_validator(mode="before")
@root_validator(pre=True)
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -546,7 +549,8 @@ class TeamBase(LiteLLMBase):
class NewTeamRequest(TeamBase):
model_aliases: Optional[dict] = None
model_config = get_model_config()
class Config:
protected_namespaces = ()
class GlobalEndUsersSpend(LiteLLMBase):
@ -565,7 +569,7 @@ class TeamMemberDeleteRequest(LiteLLMBase):
user_id: Optional[str] = None
user_email: Optional[str] = None
@model_validator(mode="before")
@root_validator(pre=True)
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -599,9 +603,10 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
model_config = get_model_config()
class Config:
protected_namespaces = ()
@model_validator(mode="before")
@root_validator(pre=True)
def set_model_info(cls, values):
dict_fields = [
"metadata",
@ -637,7 +642,8 @@ class LiteLLM_BudgetTable(LiteLLMBase):
model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None
model_config = get_model_config()
class Config:
protected_namespaces = ()
class NewOrganizationRequest(LiteLLM_BudgetTable):
@ -687,7 +693,8 @@ class KeyManagementSettings(LiteLLMBase):
class TeamDefaultSettings(LiteLLMBase):
team_id: str
model_config = get_model_config()
class Config:
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
class DynamoDBArgs(LiteLLMBase):
@ -711,6 +718,25 @@ class DynamoDBArgs(LiteLLMBase):
assume_role_aws_session_name: Optional[str] = None
class ConfigFieldUpdate(LiteLLMBase):
field_name: str
field_value: Any
config_type: Literal["general_settings"]
class ConfigFieldDelete(LiteLLMBase):
config_type: Literal["general_settings"]
field_name: str
class ConfigList(LiteLLMBase):
field_name: str
field_type: str
field_description: str
field_value: Any
stored_in_db: Optional[bool]
class ConfigGeneralSettings(LiteLLMBase):
"""
Documents all the fields supported by `general_settings` in config.yaml
@ -758,7 +784,11 @@ class ConfigGeneralSettings(LiteLLMBase):
description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
)
max_parallel_requests: Optional[int] = Field(
None, description="maximum parallel requests for each api key"
None,
description="maximum parallel requests for each api key",
)
global_max_parallel_requests: Optional[int] = Field(
None, description="global max parallel requests to allow for a proxy instance."
)
infer_model_from_keys: Optional[bool] = Field(
None,
@ -828,7 +858,8 @@ class ConfigYAML(LiteLLMBase):
description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5",
)
model_config = get_model_config()
class Config:
protected_namespaces = ()
class LiteLLM_VerificationToken(LiteLLMBase):
@ -862,7 +893,8 @@ class LiteLLM_VerificationToken(LiteLLMBase):
user_id_rate_limits: Optional[dict] = None
team_id_rate_limits: Optional[dict] = None
model_config = get_model_config()
class Config:
protected_namespaces = ()
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
@ -892,7 +924,7 @@ class UserAPIKeyAuth(
user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
allowed_model_region: Optional[Literal["eu"]] = None
@model_validator(mode="before")
@root_validator(pre=True)
def check_api_key(cls, values):
if values.get("api_key") is not None:
values.update({"token": hash_token(values.get("api_key"))})
@ -919,7 +951,7 @@ class LiteLLM_UserTable(LiteLLMBase):
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
@model_validator(mode="before")
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
@ -927,7 +959,8 @@ class LiteLLM_UserTable(LiteLLMBase):
values.update({"models": []})
return values
model_config = get_model_config()
class Config:
protected_namespaces = ()
class LiteLLM_EndUserTable(LiteLLMBase):
@ -939,13 +972,14 @@ class LiteLLM_EndUserTable(LiteLLMBase):
default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
@model_validator(mode="before")
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
return values
model_config = get_model_config()
class Config:
protected_namespaces = ()
class LiteLLM_SpendLogs(LiteLLMBase):
@ -983,3 +1017,30 @@ class LiteLLM_ErrorLogs(LiteLLMBase):
class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase):
response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None
class TokenCountRequest(LiteLLMBase):
model: str
prompt: Optional[str] = None
messages: Optional[List[dict]] = None
class TokenCountResponse(LiteLLMBase):
total_tokens: int
request_model: str
model_used: str
tokenizer_type: str
class CallInfo(LiteLLMBase):
"""Used for slack budget alerting"""
spend: float
max_budget: float
token: str = Field(description="Hashed value of that key")
user_id: Optional[str] = None
team_id: Optional[str] = None
user_email: Optional[str] = None
key_alias: Optional[str] = None
projected_exceeded_date: Optional[str] = None
projected_spend: Optional[float] = None
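
A hedged sketch of building this alert payload; field values are placeholders, and the import path assumes the `litellm.proxy._types` module shown in this diff:

```python
from litellm.proxy._types import CallInfo

alert = CallInfo(
    spend=42.5,
    max_budget=50.0,
    token="hashed-key-value",          # placeholder for the hashed virtual key
    user_id="user-123",
    user_email="dev@example.com",
    projected_spend=61.0,              # optional, used for projected-limit alerts
    projected_exceeded_date="2024-06-01",
)
print(alert.dict())
```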

View file

@ -26,7 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
def common_checks(
request_body: dict,
team_object: LiteLLM_TeamTable,
team_object: Optional[LiteLLM_TeamTable],
user_object: Optional[LiteLLM_UserTable],
end_user_object: Optional[LiteLLM_EndUserTable],
global_proxy_spend: Optional[float],
@ -45,13 +45,14 @@ def common_checks(
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
"""
_model = request_body.get("model", None)
if team_object.blocked == True:
if team_object is not None and team_object.blocked == True:
raise Exception(
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
)
# 2. If user can call model
if (
_model is not None
and team_object is not None
and len(team_object.models) > 0
and _model not in team_object.models
):
@ -65,7 +66,8 @@ def common_checks(
)
# 3. If team is in budget
if (
team_object.max_budget is not None
team_object is not None
and team_object.max_budget is not None
and team_object.spend is not None
and team_object.spend > team_object.max_budget
):
@ -239,6 +241,7 @@ async def get_user_object(
user_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
user_id_upsert: bool,
) -> Optional[LiteLLM_UserTable]:
"""
- Check if user id in proxy User Table
@ -252,7 +255,7 @@ async def get_user_object(
return None
# check if in cache
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return LiteLLM_UserTable(**cached_user_obj)
@ -260,16 +263,27 @@ async def get_user_object(
return cached_user_obj
# else, check db
try:
response = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}
)
if response is None:
if user_id_upsert:
response = await prisma_client.db.litellm_usertable.create(
data={"user_id": user_id}
)
else:
raise Exception
return LiteLLM_UserTable(**response.dict())
except Exception as e: # if end-user not in db
raise Exception(
_response = LiteLLM_UserTable(**dict(response))
# save the user object to cache
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
return _response
except Exception as e: # if user not in db
raise ValueError(
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
)
@ -290,7 +304,7 @@ async def get_team_object(
)
# check if in cache
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id)
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
if cached_team_obj is not None:
if isinstance(cached_team_obj, dict):
return LiteLLM_TeamTable(**cached_team_obj)
@ -305,7 +319,11 @@ async def get_team_object(
if response is None:
raise Exception
return LiteLLM_TeamTable(**response.dict())
_response = LiteLLM_TeamTable(**response.dict())
# save the team object to cache
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
return _response
except Exception as e:
raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."

View file

@ -55,12 +55,9 @@ class JWTHandler:
return True
return False
def is_team(self, scopes: list) -> bool:
if self.litellm_jwtauth.team_jwt_scope in scopes:
return True
return False
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
def get_end_user_id(
self, token: dict, default_value: Optional[str]
) -> Optional[str]:
try:
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
@ -70,13 +67,36 @@ class JWTHandler:
user_id = default_value
return user_id
def is_required_team_id(self) -> bool:
"""
Returns:
- True: if 'team_id_jwt_field' is set
- False: if not
"""
if self.litellm_jwtauth.team_id_jwt_field is None:
return False
return True
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.team_id_jwt_field is not None:
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
elif self.litellm_jwtauth.team_id_default is not None:
team_id = self.litellm_jwtauth.team_id_default
else:
team_id = None
except KeyError:
team_id = default_value
return team_id
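
A hedged, stand-alone sketch of the precedence `get_team_id` implements above; the claim name and default value are placeholders:

```python
from typing import Optional

def resolve_team_id(
    token: dict,
    team_id_jwt_field: Optional[str],
    team_id_default: Optional[str],
) -> Optional[str]:
    # configured JWT claim first, then the configured default, else None
    try:
        if team_id_jwt_field is not None:
            return token[team_id_jwt_field]
        if team_id_default is not None:
            return team_id_default
        return None
    except KeyError:
        return None

print(resolve_team_id({"client_id": "team-a"}, "client_id", None))  # -> team-a
print(resolve_team_id({}, None, "default-team"))                    # -> default-team
```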
def is_upsert_user_id(self) -> bool:
"""
Returns:
- True: if 'user_id_upsert' is set
- False: if not
"""
return self.litellm_jwtauth.user_id_upsert
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try:
if self.litellm_jwtauth.user_id_jwt_field is not None:
@ -207,12 +227,14 @@ class JWTHandler:
raise Exception(f"Validation fails: {str(e)}")
elif public_key is not None and isinstance(public_key, str):
try:
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
cert = x509.load_pem_x509_certificate(
public_key.encode(), default_backend()
)
# Extract public key
key = cert.public_key().public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo
serialization.PublicFormat.SubjectPublicKeyInfo,
)
# decode the token using the public key
@ -221,7 +243,7 @@ class JWTHandler:
key,
algorithms=algorithms,
audience=audience,
options=decode_options
options=decode_options,
)
return payload

View file

@ -0,0 +1,42 @@
# What is this?
## If litellm license in env, checks if it's valid
import os
from litellm.llms.custom_httpx.http_handler import HTTPHandler
class LicenseCheck:
"""
- Check if license in env
- Returns if license is valid
"""
base_url = "https://license.litellm.ai"
def __init__(self) -> None:
self.license_str = os.getenv("LITELLM_LICENSE", None)
self.http_handler = HTTPHandler()
def _verify(self, license_str: str) -> bool:
url = "{}/verify_license/{}".format(self.base_url, license_str)
try: # don't impact user, if call fails
response = self.http_handler.get(url=url)
response.raise_for_status()
response_json = response.json()
premium = response_json["verify"]
assert isinstance(premium, bool)
return premium
except Exception as e:
return False
def is_premium(self) -> bool:
if self.license_str is None:
return False
elif self._verify(license_str=self.license_str):
return True
return False
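
An illustrative use of `LicenseCheck`; the import path is an assumption, so adjust it to wherever this class lives in your checkout:

```python
import os
os.environ["LITELLM_LICENSE"] = "my-license-key"  # placeholder value, read in __init__

from litellm.proxy.auth.litellm_license import LicenseCheck  # assumed module path

checker = LicenseCheck()
print(checker.is_premium())  # True only if the key verifies against license.litellm.ai
```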

View file

@ -79,6 +79,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
max_parallel_requests = user_api_key_dict.max_parallel_requests
if max_parallel_requests is None:
max_parallel_requests = sys.maxsize
global_max_parallel_requests = data.get("metadata", {}).get(
"global_max_parallel_requests", None
)
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
if tpm_limit is None:
tpm_limit = sys.maxsize
@ -91,6 +94,24 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
current_global_requests = await cache.async_get_cache(
key=_key, local_only=True
)
# check if below limit
if current_global_requests is None:
current_global_requests = 1
# if above -> raise error
if current_global_requests >= global_max_parallel_requests:
raise HTTPException(
status_code=429, detail="Max parallel request limit reached."
)
# if below -> increment
else:
await cache.async_increment_cache(key=_key, value=1, local_only=True)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
@ -207,6 +228,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None
)
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
@ -222,6 +246,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
# decrement
await self.user_api_key_cache.async_increment_cache(
key=_key, value=-1, local_only=True
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
@ -336,6 +368,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None
)
user_api_key = (
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
)
@ -347,17 +382,26 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
return
## decrement call count if call failed
if (
hasattr(kwargs["exception"], "status_code")
and kwargs["exception"].status_code == 429
and "Max parallel request limit reached" in str(kwargs["exception"])
):
if "Max parallel request limit reached" in str(kwargs["exception"]):
pass # ignore failed calls due to max limit being reached
else:
# ------------
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
current_global_requests = (
await self.user_api_key_cache.async_get_cache(
key=_key, local_only=True
)
)
# decrement
await self.user_api_key_cache.async_increment_cache(
key=_key, value=-1, local_only=True
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
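
To summarize the bookkeeping above, a self-contained sketch of the global parallel-request pattern (increment before the call, reject at the cap, decrement afterwards); the cache class below is a stand-in for litellm's DualCache:

```python
import asyncio

class SimpleCache:
    def __init__(self):
        self._store = {}
    async def async_get_cache(self, key):
        return self._store.get(key)
    async def async_increment_cache(self, key, value):
        self._store[key] = self._store.get(key, 0) + value
        return self._store[key]

async def guarded_call(cache: SimpleCache, limit: int, work):
    key = "global_max_parallel_requests"
    current = await cache.async_get_cache(key) or 0
    if current >= limit:                        # reject once the global cap is hit
        raise RuntimeError("Max parallel request limit reached.")
    await cache.async_increment_cache(key, 1)   # count this in-flight request
    try:
        return await work()
    finally:
        await cache.async_increment_cache(key, -1)  # release on success or failure

async def main():
    cache = SimpleCache()
    async def work():
        await asyncio.sleep(0.1)
        return "ok"
    print(await guarded_call(cache, limit=2, work=work))

asyncio.run(main())
```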

View file

@ -146,6 +146,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
try:
assert call_type in [
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
@ -192,6 +193,15 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
return data
except HTTPException as e:
if (
e.status_code == 400
and isinstance(e.detail, dict)
and "error" in e.detail
and self.prompt_injection_params is not None
and self.prompt_injection_params.reject_as_response
):
return e.detail["error"]
raise e
except Exception as e:
traceback.print_exc()

View file

@ -17,6 +17,7 @@ if litellm_mode == "DEV":
from importlib import resources
import shutil
telemetry = None
@ -505,6 +506,7 @@ def run_server(
port = random.randint(1024, 49152)
from litellm.proxy.proxy_server import app
import litellm
if run_gunicorn == False:
if ssl_certfile_path is not None and ssl_keyfile_path is not None:
@ -518,6 +520,14 @@ def run_server(
ssl_keyfile=ssl_keyfile_path,
ssl_certfile=ssl_certfile_path,
) # run uvicorn
else:
print(f"litellm.json_logs: {litellm.json_logs}")
if litellm.json_logs:
from litellm.proxy._logging import logger
uvicorn.run(
app, host=host, port=port, log_config=None
) # run uvicorn w/ json
else:
uvicorn.run(app, host=host, port=port) # run uvicorn
elif run_gunicorn == True:

File diff suppressed because it is too large

View file

@ -11,6 +11,7 @@ from litellm.proxy._types import (
LiteLLM_EndUserTable,
LiteLLM_TeamTable,
Member,
CallInfo,
)
from litellm.caching import DualCache, RedisCache
from litellm.router import Deployment, ModelInfo, LiteLLM_Params
@ -18,8 +19,18 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm.exceptions import RejectedRequestError
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import ModelResponse, EmbeddingResponse, ImageResponse
from litellm import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
TranscriptionResponse,
TextCompletionResponse,
CustomStreamWrapper,
TextCompletionStreamWrapper,
)
from litellm.utils import ModelResponseIterator
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
@ -32,6 +43,7 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting
from typing_extensions import overload
def print_verbose(print_statement):
@ -74,6 +86,9 @@ class ProxyLogging:
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"cooldown_deployment",
"new_model_added",
]
] = [
"llm_exceptions",
@ -82,6 +97,9 @@ class ProxyLogging:
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"cooldown_deployment",
"new_model_added",
]
self.slack_alerting_instance = SlackAlerting(
alerting_threshold=self.alerting_threshold,
@ -104,6 +122,9 @@ class ProxyLogging:
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"cooldown_deployment",
"new_model_added",
]
]
] = None,
@ -122,7 +143,13 @@ class ProxyLogging:
alerting_args=alerting_args,
)
if "daily_reports" in self.alert_types:
if (
self.alerting is not None
and "slack" in self.alerting
and "daily_reports" in self.alert_types
):
# NOTE: ENSURE we only add callbacks when alerting is on
# We should NOT add callbacks when alerting is off
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
if redis_cache is not None:
@ -140,6 +167,8 @@ class ProxyLogging:
self.slack_alerting_instance.response_taking_too_long_callback
)
for callback in litellm.callbacks:
if isinstance(callback, str):
callback = litellm.utils._init_custom_logger_compatible_class(callback)
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
if callback not in litellm.success_callback:
@ -165,18 +194,20 @@ class ProxyLogging:
)
litellm.utils.set_callbacks(callback_list=callback_list)
# The actual implementation of the function
async def pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
):
) -> dict:
"""
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@ -203,8 +234,25 @@ class ProxyLogging:
call_type=call_type,
)
if response is not None:
if isinstance(response, Exception):
raise response
elif isinstance(response, dict):
data = response
elif isinstance(response, str):
if (
call_type == "completion"
or call_type == "text_completion"
):
raise RejectedRequestError(
message=response,
model=data.get("model", ""),
llm_provider="",
request_data=data,
)
else:
raise HTTPException(
status_code=400, detail={"error": response}
)
print_verbose(f"final data being sent to {call_type} call: {data}")
return data
except Exception as e:
@ -252,8 +300,8 @@ class ProxyLogging:
"""
Runs the CustomLogger's async_moderation_hook()
"""
for callback in litellm.callbacks:
new_data = copy.deepcopy(data)
for callback in litellm.callbacks:
try:
if isinstance(callback, CustomLogger):
await callback.async_moderation_hook(
@ -265,30 +313,30 @@ class ProxyLogging:
raise e
return data
async def failed_tracking_alert(self, error_message: str):
if self.alerting is None:
return
await self.slack_alerting_instance.failed_tracking_alert(
error_message=error_message
)
async def budget_alerts(
self,
type: Literal[
"token_budget",
"user_budget",
"user_and_proxy_budget",
"failed_budgets",
"failed_tracking",
"team_budget",
"proxy_budget",
"projected_limit_exceeded",
],
user_max_budget: float,
user_current_spend: float,
user_info=None,
error_message="",
user_info: CallInfo,
):
if self.alerting is None:
# do nothing if alerting is not switched on
return
await self.slack_alerting_instance.budget_alerts(
type=type,
user_max_budget=user_max_budget,
user_current_spend=user_current_spend,
user_info=user_info,
error_message=error_message,
)
async def alerting_handler(
@ -344,7 +392,11 @@ class ProxyLogging:
for client in self.alerting:
if client == "slack":
await self.slack_alerting_instance.send_alert(
message=message, level=level, alert_type=alert_type, **extra_kwargs
message=message,
level=level,
alert_type=alert_type,
user_info=None,
**extra_kwargs,
)
elif client == "sentry":
if litellm.utils.sentry_sdk_instance is not None:
@ -418,9 +470,14 @@ class ProxyLogging:
Related issue - https://github.com/BerriAI/litellm/issues/3395
"""
litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
exception_str = str(original_exception)
if litellm_debug_info is not None:
exception_str += litellm_debug_info
asyncio.create_task(
self.alerting_handler(
message=f"LLM API call failed: {str(original_exception)}",
message=f"LLM API call failed: `{exception_str}`",
level="High",
alert_type="llm_exceptions",
request_data=request_data,
@ -1787,7 +1844,9 @@ def hash_token(token: str):
return hashed_token
def get_logging_payload(kwargs, response_obj, start_time, end_time):
def get_logging_payload(
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
):
from litellm.proxy._types import LiteLLM_SpendLogs
from pydantic import Json
import uuid
@ -1865,7 +1924,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"request_tags": metadata.get("tags", []),
"end_user": kwargs.get("user", ""),
"end_user": end_user_id or "",
"api_base": litellm_params.get("api_base", ""),
}
@ -2028,6 +2087,11 @@ async def update_spend(
raise e
### UPDATE END-USER TABLE ###
verbose_proxy_logger.debug(
"End-User Spend transactions: {}".format(
len(prisma_client.end_user_list_transactons.keys())
)
)
if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
@ -2043,13 +2107,18 @@ async def update_spend(
max_end_user_budget = None
if litellm.max_end_user_budget is not None:
max_end_user_budget = litellm.max_end_user_budget
new_user_obj = LiteLLM_EndUserTable(
user_id=end_user_id, spend=response_cost, blocked=False
)
batcher.litellm_endusertable.update_many(
batcher.litellm_endusertable.upsert(
where={"user_id": end_user_id},
data={"spend": {"increment": response_cost}},
data={
"create": {
"user_id": end_user_id,
"spend": response_cost,
"blocked": False,
},
"update": {"spend": {"increment": response_cost}},
},
)
prisma_client.end_user_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.

View file

@ -262,13 +262,22 @@ class Router:
self.retry_after = retry_after
self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
## SETTING FALLBACKS ##
### validate if it's set + in correct format
_fallbacks = fallbacks or litellm.fallbacks
self.validate_fallbacks(fallback_param=_fallbacks)
### set fallbacks
self.fallbacks = _fallbacks
if default_fallbacks is not None or litellm.default_fallbacks is not None:
_fallbacks = default_fallbacks or litellm.default_fallbacks
if self.fallbacks is not None:
self.fallbacks.append({"*": _fallbacks})
else:
self.fallbacks = [{"*": _fallbacks}]
self.context_window_fallbacks = (
context_window_fallbacks or litellm.context_window_fallbacks
)
@ -336,6 +345,21 @@ class Router:
if self.alerting_config is not None:
self._initialize_alerting()
def validate_fallbacks(self, fallback_param: Optional[List]):
if fallback_param is None:
return
if len(fallback_param) > 0: # if set
## for dictionary in list, check if only 1 key in dict
for _dict in fallback_param:
assert isinstance(_dict, dict), "Item={}, not a dictionary".format(
_dict
)
assert (
len(_dict.keys()) == 1
), "Only 1 key allows in dictionary. You set={} for dict={}".format(
len(_dict.keys()), _dict
)
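
For reference, a sketch of the fallback shapes this validation accepts and rejects; model group names are illustrative:

```python
# Accepted: a list of single-key dicts, each mapping one model group to its
# ordered list of fallback groups.
valid_fallbacks = [
    {"gpt-3.5-turbo": ["gpt-4", "azure-gpt-35"]},
    {"*": ["gpt-3.5-turbo"]},  # wildcard entry, as appended for default_fallbacks above
]

# Rejected: more than one key in a single dict trips the assertion above.
invalid_fallbacks = [
    {"gpt-3.5-turbo": ["gpt-4"], "gpt-4": ["gpt-3.5-turbo"]},
]
```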
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy":
self.leastbusy_logger = LeastBusyLoggingHandler(
@ -638,6 +662,10 @@ class Router:
async def abatch_completion(
self, models: List[str], messages: List[Dict[str, str]], **kwargs
):
"""
Async Batch Completion - Batch Process 1 request to multiple model_group on litellm.Router
Use this for sending the same request to N models
"""
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
@ -662,6 +690,51 @@ class Router:
response = await asyncio.gather(*_tasks)
return response
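A small usage sketch for the method above, assuming `router` already holds deployments for each model group named here (the third group name is hypothetical):

```python
async def compare_model_groups(router):
    models = ["gpt-3.5-turbo", "gpt-4", "my-claude-group"]  # "my-claude-group" is a placeholder group
    responses = await router.abatch_completion(
        models=models,
        messages=[{"role": "user", "content": "hey, how's it going?"}],
        max_tokens=15,
    )
    # One result per model group; failed calls are returned as exception objects
    # rather than raised, so one bad deployment doesn't sink the whole batch.
    for model, resp in zip(models, responses):
        print(model, resp)
```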
async def abatch_completion_one_model_multiple_requests(
self, model: str, messages: List[List[Dict[str, str]]], **kwargs
):
"""
Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router
Use this for sending multiple requests to 1 model
Args:
model (List[str]): model group
messages (List[List[Dict[str, str]]]): list of messages. Each element in the list is one request
**kwargs: additional kwargs
Usage:
response = await self.abatch_completion_one_model_multiple_requests(
model="gpt-3.5-turbo",
messages=[
[{"role": "user", "content": "hello"}, {"role": "user", "content": "tell me something funny"}],
[{"role": "user", "content": "hello good mornign"}],
]
)
"""
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
):
"""
Wrapper around self.acompletion that catches exceptions and returns them as a result
"""
try:
return await self.acompletion(model=model, messages=messages, **kwargs)
except Exception as e:
return e
_tasks = []
for message_request in messages:
# add each task; if a task fails, the wrapper returns the exception instead of raising
_tasks.append(
_async_completion_no_exceptions(
model=model, messages=message_request, **kwargs
)
)
response = await asyncio.gather(*_tasks)
return response
def image_generation(self, prompt: str, model: str, **kwargs):
try:
kwargs["model"] = model
@ -1899,10 +1972,28 @@ class Router:
metadata = kwargs.get("litellm_params", {}).get("metadata", None)
_model_info = kwargs.get("litellm_params", {}).get("model_info", {})
exception_response = getattr(exception, "response", {})
exception_headers = getattr(exception_response, "headers", None)
_time_to_cooldown = self.cooldown_time
if exception_headers is not None:
_time_to_cooldown = (
litellm.utils._get_retry_after_from_exception_header(
response_headers=exception_headers
)
)
if _time_to_cooldown < 0:
# if no retry-after value could be read from the response headers -> fall back to the default cooldown time
_time_to_cooldown = self.cooldown_time
if isinstance(_model_info, dict):
deployment_id = _model_info.get("id", None)
self._set_cooldown_deployments(
exception_status=exception_status, deployment=deployment_id
exception_status=exception_status,
deployment=deployment_id,
time_to_cooldown=_time_to_cooldown,
) # setting deployment_id in cooldown deployments
if custom_llm_provider:
model_name = f"{custom_llm_provider}/{model_name}"
@ -1962,8 +2053,50 @@ class Router:
key=rpm_key, value=request_count, local_only=True
) # don't change existing ttl
def _is_cooldown_required(self, exception_status: Union[str, int]):
"""
A function to determine if a cooldown is required based on the exception status.
Parameters:
exception_status (Union[str, int]): The status of the exception.
Returns:
bool: True if a cooldown is required, False otherwise.
"""
try:
if isinstance(exception_status, str):
exception_status = int(exception_status)
if exception_status >= 400 and exception_status < 500:
if exception_status == 429:
# Cool down 429 Rate Limit Errors
return True
elif exception_status == 401:
# Cool down 401 Auth Errors
return True
elif exception_status == 408:
return True
else:
# Do NOT cool down all other 4XX Errors
return False
else:
# should cool down for all other errors
return True
except:
# Catch all - if any exceptions default to cooling down
return True
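A quick sanity-check sketch of the rules above (placeholder deployment and api_key), showing which statuses put a deployment into cooldown:

```python
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-fake"}}
    ]
)

assert router._is_cooldown_required(exception_status=429) is True    # rate limits cool down
assert router._is_cooldown_required(exception_status="401") is True  # strings are coerced to int
assert router._is_cooldown_required(exception_status=408) is True    # timeouts cool down
assert router._is_cooldown_required(exception_status=400) is False   # other 4XX do not
assert router._is_cooldown_required(exception_status=500) is True    # everything else does
```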
def _set_cooldown_deployments(
self, exception_status: Union[str, int], deployment: Optional[str] = None
self,
exception_status: Union[str, int],
deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None,
):
"""
Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
@ -1975,6 +2108,9 @@ class Router:
if deployment is None:
return
if self._is_cooldown_required(exception_status=exception_status) == False:
return
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get current fails for deployment
@ -1987,6 +2123,8 @@ class Router:
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
)
cooldown_time = self.cooldown_time or 1
if time_to_cooldown is not None:
cooldown_time = time_to_cooldown
if isinstance(exception_status, str):
try:
@ -2024,7 +2162,9 @@ class Router:
)
self.send_deployment_cooldown_alert(
deployment_id=deployment, exception_status=exception_status
deployment_id=deployment,
exception_status=exception_status,
cooldown_time=cooldown_time,
)
else:
self.failed_calls.set_cache(
@ -2309,7 +2449,7 @@ class Router:
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
if "azure" in model_name and isinstance(api_key, str):
if "azure" in model_name:
if api_base is None or not isinstance(api_base, str):
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
@ -3185,7 +3325,7 @@ class Router:
if _rate_limit_error == True: # allow generic fallback logic to take place
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
)
elif _context_window_error == True:
raise litellm.ContextWindowExceededError(
@ -3257,7 +3397,9 @@ class Router:
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
if len(healthy_deployments) == 0:
raise ValueError(f"No healthy deployment available, passed model={model}. ")
raise ValueError(
f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
)
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[
model
@ -3347,7 +3489,7 @@ class Router:
if _allowed_model_region is None:
_allowed_model_region = "n/a"
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
)
if (
@ -3415,7 +3557,7 @@ class Router:
f"get_available_deployment for model: {model}, No deployment available"
)
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
)
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@ -3545,7 +3687,7 @@ class Router:
f"get_available_deployment for model: {model}, No deployment available"
)
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
)
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
@ -3683,7 +3825,10 @@ class Router:
print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n") # noqa
def send_deployment_cooldown_alert(
self, deployment_id: str, exception_status: Union[str, int]
self,
deployment_id: str,
exception_status: Union[str, int],
cooldown_time: float,
):
try:
from litellm.proxy.proxy_server import proxy_logging_obj
@ -3707,7 +3852,7 @@ class Router:
)
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_alert(
message=f"Router: Cooling down deployment: {_api_base}, for {self.cooldown_time} seconds. Got exception: {str(exception_status)}. Change 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
message=f"Router: Cooling down Deployment:\nModel Name: `{_model_name}`\nAPI Base: `{_api_base}`\nCooldown Time: `{cooldown_time} seconds`\nException Status Code: `{str(exception_status)}`\n\nChange 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
alert_type="cooldown_deployment",
level="Low",
)

View file

@ -27,7 +27,7 @@ class LiteLLMBase(BaseModel):
class RoutingArgs(LiteLLMBase):
ttl: int = 1 * 60 * 60 # 1 hour
ttl: float = 1 * 60 * 60 # 1 hour
lowest_latency_buffer: float = 0
max_latency_list_size: int = 10

File diff suppressed because it is too large

View file

@ -242,12 +242,24 @@ async def test_langfuse_masked_input_output(langfuse_client):
response = await create_async_task(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "This is a test"}],
metadata={"trace_id": _unique_trace_name, "mask_input": mask_value, "mask_output": mask_value},
mock_response="This is a test response"
metadata={
"trace_id": _unique_trace_name,
"mask_input": mask_value,
"mask_output": mask_value,
},
mock_response="This is a test response",
)
print(response)
expected_input = "redacted-by-litellm" if mask_value else {'messages': [{'content': 'This is a test', 'role': 'user'}]}
expected_output = "redacted-by-litellm" if mask_value else {'content': 'This is a test response', 'role': 'assistant'}
expected_input = (
"redacted-by-litellm"
if mask_value
else {"messages": [{"content": "This is a test", "role": "user"}]}
)
expected_output = (
"redacted-by-litellm"
if mask_value
else {"content": "This is a test response", "role": "assistant"}
)
langfuse_client.flush()
await asyncio.sleep(2)
@ -262,6 +274,7 @@ async def test_langfuse_masked_input_output(langfuse_client):
assert generations[0].input == expected_input
assert generations[0].output == expected_output
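For context, a minimal sketch of the redaction knobs this test exercises (assumes Langfuse credentials are configured in the environment):

```python
import litellm

litellm.success_callback = ["langfuse"]  # assumes LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY are set

# With mask_input / mask_output in metadata, the logged trace records
# "redacted-by-litellm" instead of the real prompt and completion.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "some sensitive prompt"}],
    metadata={"mask_input": True, "mask_output": True},
    mock_response="some sensitive answer",  # mock_response avoids a real API call
)
```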
@pytest.mark.asyncio
async def test_langfuse_logging_metadata(langfuse_client):
"""
@ -523,7 +536,7 @@ def test_langfuse_logging_function_calling():
# test_langfuse_logging_function_calling()
def test_langfuse_existing_trace_id():
def test_aaalangfuse_existing_trace_id():
"""
When existing trace id is passed, don't set trace params -> prevents overwriting the trace
@ -577,7 +590,7 @@ def test_langfuse_existing_trace_id():
"verbose": False,
"custom_llm_provider": "openai",
"api_base": "https://api.openai.com/v1/",
"litellm_call_id": "508113a1-c6f1-48ce-a3e1-01c6cce9330e",
"litellm_call_id": None,
"model_alias_map": {},
"completion_call_id": None,
"metadata": None,
@ -593,7 +606,7 @@ def test_langfuse_existing_trace_id():
"stream": False,
"user": None,
"call_type": "completion",
"litellm_call_id": "508113a1-c6f1-48ce-a3e1-01c6cce9330e",
"litellm_call_id": None,
"completion_start_time": "2024-05-01 07:31:29.903685",
"temperature": 0.1,
"extra_body": {},
@ -633,6 +646,8 @@ def test_langfuse_existing_trace_id():
trace_id = langfuse_response_object["trace_id"]
assert trace_id is not None
langfuse_client.flush()
time.sleep(2)

View file

@ -1,7 +1,7 @@
# What is this?
## Tests slack alerting on proxy logging object
import sys, json
import sys, json, uuid, random
import os
import io, asyncio
from datetime import datetime, timedelta
@ -22,6 +22,7 @@ import unittest.mock
from unittest.mock import AsyncMock
import pytest
from litellm.router import AlertingConfig, Router
from litellm.proxy._types import CallInfo
@pytest.mark.parametrize(
@ -123,7 +124,9 @@ from datetime import datetime, timedelta
@pytest.fixture
def slack_alerting():
return SlackAlerting(alerting_threshold=1, internal_usage_cache=DualCache())
return SlackAlerting(
alerting_threshold=1, internal_usage_cache=DualCache(), alerting=["slack"]
)
# Test for hanging LLM responses
@ -161,7 +164,10 @@ async def test_budget_alerts_crossed(slack_alerting):
user_current_spend = 101
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
await slack_alerting.budget_alerts(
"user_budget", user_max_budget, user_current_spend
"user_budget",
user_info=CallInfo(
token="", spend=user_current_spend, max_budget=user_max_budget
),
)
mock_send_alert.assert_awaited_once()
@ -173,12 +179,18 @@ async def test_budget_alerts_crossed_again(slack_alerting):
user_current_spend = 101
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
await slack_alerting.budget_alerts(
"user_budget", user_max_budget, user_current_spend
"user_budget",
user_info=CallInfo(
token="", spend=user_current_spend, max_budget=user_max_budget
),
)
mock_send_alert.assert_awaited_once()
mock_send_alert.reset_mock()
await slack_alerting.budget_alerts(
"user_budget", user_max_budget, user_current_spend
"user_budget",
user_info=CallInfo(
token="", spend=user_current_spend, max_budget=user_max_budget
),
)
mock_send_alert.assert_not_awaited()
@ -365,21 +377,23 @@ async def test_send_llm_exception_to_slack():
@pytest.mark.asyncio
async def test_send_daily_reports_ignores_zero_values():
router = MagicMock()
router.get_model_ids.return_value = ['model1', 'model2', 'model3']
router.get_model_ids.return_value = ["model1", "model2", "model3"]
slack_alerting = SlackAlerting(internal_usage_cache=MagicMock())
# model1:failed=None, model2:failed=0, model3:failed=10, model1:latency=0; model2:latency=0; model3:latency=None
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(return_value=[None, 0, 10, 0, 0, None])
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(
return_value=[None, 0, 10, 0, 0, None]
)
slack_alerting.internal_usage_cache.async_batch_set_cache = AsyncMock()
router.get_model_info.side_effect = lambda x: {"litellm_params": {"model": x}}
with patch.object(slack_alerting, 'send_alert', new=AsyncMock()) as mock_send_alert:
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
result = await slack_alerting.send_daily_reports(router)
# Check that the send_alert method was called
mock_send_alert.assert_called_once()
message = mock_send_alert.call_args[1]['message']
message = mock_send_alert.call_args[1]["message"]
# Ensure the message includes only the non-zero, non-None metrics
assert "model3" in message
@ -393,15 +407,91 @@ async def test_send_daily_reports_ignores_zero_values():
@pytest.mark.asyncio
async def test_send_daily_reports_all_zero_or_none():
router = MagicMock()
router.get_model_ids.return_value = ['model1', 'model2', 'model3']
router.get_model_ids.return_value = ["model1", "model2", "model3"]
slack_alerting = SlackAlerting(internal_usage_cache=MagicMock())
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(return_value=[None, 0, None, 0, None, 0])
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(
return_value=[None, 0, None, 0, None, 0]
)
with patch.object(slack_alerting, 'send_alert', new=AsyncMock()) as mock_send_alert:
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
result = await slack_alerting.send_daily_reports(router)
# Check that the send_alert method was not called
mock_send_alert.assert_not_called()
assert result == False
# test user budget crossed alert sent only once, even if user makes multiple calls
@pytest.mark.parametrize(
"alerting_type",
[
"token_budget",
"user_budget",
"team_budget",
"proxy_budget",
"projected_limit_exceeded",
],
)
@pytest.mark.asyncio
async def test_send_token_budget_crossed_alerts(alerting_type):
slack_alerting = SlackAlerting()
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
user_info = {
"token": "50e55ca5bfbd0759697538e8d23c0cd5031f52d9e19e176d7233b20c7c4d3403",
"spend": 86,
"max_budget": 100,
"user_id": "ishaan@berri.ai",
"user_email": "ishaan@berri.ai",
"key_alias": "my-test-key",
"projected_exceeded_date": "10/20/2024",
"projected_spend": 200,
}
user_info = CallInfo(**user_info)
for _ in range(50):
await slack_alerting.budget_alerts(
type=alerting_type,
user_info=user_info,
)
mock_send_alert.assert_awaited_once()
@pytest.mark.parametrize(
"alerting_type",
[
"token_budget",
"user_budget",
"team_budget",
"proxy_budget",
"projected_limit_exceeded",
],
)
@pytest.mark.asyncio
async def test_webhook_alerting(alerting_type):
slack_alerting = SlackAlerting(alerting=["webhook"])
with patch.object(
slack_alerting, "send_webhook_alert", new=AsyncMock()
) as mock_send_alert:
user_info = {
"token": "50e55ca5bfbd0759697538e8d23c0cd5031f52d9e19e176d7233b20c7c4d3403",
"spend": 1,
"max_budget": 0,
"user_id": "ishaan@berri.ai",
"user_email": "ishaan@berri.ai",
"key_alias": "my-test-key",
"projected_exceeded_date": "10/20/2024",
"projected_spend": 200,
}
user_info = CallInfo(**user_info)
for _ in range(50):
await slack_alerting.budget_alerts(
type=alerting_type,
user_info=user_info,
)
mock_send_alert.assert_awaited_once()

View file

@ -16,6 +16,7 @@ from litellm.tests.test_streaming import streaming_format_tests
import json
import os
import tempfile
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
litellm.num_retries = 3
litellm.cache = None
@ -98,7 +99,7 @@ def load_vertex_ai_credentials():
@pytest.mark.asyncio
async def get_response():
async def test_get_response():
load_vertex_ai_credentials()
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
try:
@ -376,9 +377,8 @@ def test_vertex_ai_stream():
print("making request", model)
response = completion(
model=model,
messages=[
{"role": "user", "content": "write 10 line code code for saying hi"}
],
messages=[{"role": "user", "content": "hello tell me a short story"}],
max_tokens=15,
stream=True,
)
completed_str = ""
@ -389,7 +389,7 @@ def test_vertex_ai_stream():
completed_str += content
assert type(content) == str
# pass
assert len(completed_str) > 4
assert len(completed_str) > 1
except litellm.RateLimitError as e:
pass
except Exception as e:
@ -595,30 +595,68 @@ def test_gemini_pro_vision_base64():
async def test_gemini_pro_function_calling(sync_mode):
try:
load_vertex_ai_credentials()
data = {
"model": "vertex_ai/gemini-pro",
"messages": [
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Call the submit_cities function with San Francisco and New York",
"content": "Hello, what is your name and can you tell me the weather?",
},
# Assistant replies with a tool call
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_123",
"type": "function",
"index": 0,
"function": {
"name": "get_weather",
"arguments": '{"location":"San Francisco, CA"}',
},
}
],
"tools": [
},
# The result of the tool call is added to the history
{
"role": "tool",
"tool_call_id": "call_123",
"name": "get_weather",
"content": "27 degrees celsius and clear in San Francisco, CA",
},
# Now the assistant can reply with the result of the tool call.
]
tools = [
{
"type": "function",
"function": {
"name": "submit_cities",
"description": "Submits a list of cities",
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"cities": {"type": "array", "items": {"type": "string"}}
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}
},
"required": ["cities"],
"required": ["location"],
},
},
}
],
]
data = {
"model": "vertex_ai/gemini-1.5-pro-preview-0514",
"messages": messages,
"tools": tools,
}
if sync_mode:
response = litellm.completion(**data)
@ -638,7 +676,7 @@ async def test_gemini_pro_function_calling(sync_mode):
# gemini_pro_function_calling()
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.parametrize("sync_mode", [True])
@pytest.mark.asyncio
async def test_gemini_pro_function_calling_streaming(sync_mode):
load_vertex_ai_credentials()
@ -713,7 +751,7 @@ async def test_gemini_pro_async_function_calling():
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"description": "Get the current weather in a given location.",
"parameters": {
"type": "object",
"properties": {
@ -743,8 +781,9 @@ async def test_gemini_pro_async_function_calling():
print(f"completion: {completion}")
assert completion.choices[0].message.content is None
assert len(completion.choices[0].message.tool_calls) == 1
except litellm.APIError as e:
pass
# except litellm.APIError as e:
# pass
except litellm.RateLimitError as e:
pass
except Exception as e:
@ -894,3 +933,45 @@ async def test_vertexai_aembedding():
# traceback.print_exc()
# raise e
# test_gemini_pro_vision_async()
def test_prompt_factory():
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Hello, what is your name and can you tell me the weather?",
},
# Assistant replies with a tool call
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_123",
"type": "function",
"index": 0,
"function": {
"name": "get_weather",
"arguments": '{"location":"San Francisco, CA"}',
},
}
],
},
# The result of the tool call is added to the history
{
"role": "tool",
"tool_call_id": "call_123",
"name": "get_weather",
"content": "27 degrees celsius and clear in San Francisco, CA",
},
# Now the assistant can reply with the result of the tool call.
]
translated_messages = _gemini_convert_messages_with_history(messages=messages)
print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")

View file

@ -206,6 +206,7 @@ def test_completion_bedrock_claude_sts_client_auth():
# test_completion_bedrock_claude_sts_client_auth()
@pytest.mark.skip(reason="We don't have Circle CI OIDC credentials as yet")
def test_completion_bedrock_claude_sts_oidc_auth():
print("\ncalling bedrock claude with oidc auth")
@ -244,7 +245,7 @@ def test_bedrock_extra_headers():
messages=messages,
max_tokens=10,
temperature=0.78,
extra_headers={"x-key": "x_key_value"}
extra_headers={"x-key": "x_key_value"},
)
# Add any assertions here to check the response
assert len(response.choices) > 0
@ -259,7 +260,7 @@ def test_bedrock_claude_3():
try:
litellm.set_verbose = True
data = {
"max_tokens": 2000,
"max_tokens": 100,
"stream": False,
"temperature": 0.3,
"messages": [
@ -282,6 +283,7 @@ def test_bedrock_claude_3():
}
response: ModelResponse = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
num_retries=3,
# messages=messages,
# max_tokens=10,
# temperature=0.78,

View file

@ -1,4 +1,4 @@
import sys, os
import sys, os, json
import traceback
from dotenv import load_dotenv
@ -7,7 +7,7 @@ import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the, system path
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
@ -38,7 +38,7 @@ def reset_callbacks():
@pytest.mark.skip(reason="Local test")
def test_response_model_none():
"""
Addresses - https://github.com/BerriAI/litellm/issues/2972
Addresses: https://github.com/BerriAI/litellm/issues/2972
"""
x = completion(
model="mymodel",
@ -278,7 +278,8 @@ def test_completion_claude_3_function_call():
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
tool_choice={"type": "tool", "name": "get_weather"},
extra_headers={"anthropic-beta": "tools-2024-05-16"},
)
# Add any assertions here to check response args
print(response)
@ -1053,6 +1054,25 @@ def test_completion_azure_gpt4_vision():
# test_completion_azure_gpt4_vision()
@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "gpt-4", "gpt-4o"])
def test_completion_openai_params(model):
litellm.drop_params = True
messages = [
{
"role": "user",
"content": """Generate JSON about Bill Gates: { "full_name": "", "title": "" }""",
}
]
response = completion(
model=model,
messages=messages,
response_format={"type": "json_object"},
)
print(f"response: {response}")
def test_completion_fireworks_ai():
try:
litellm.set_verbose = True
@ -1161,28 +1181,28 @@ HF Tests we should pass
# Test util to sort models to TGI, conv, None
def test_get_hf_task_for_model():
model = "glaiveai/glaive-coder-7b"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation-inference"
model = "meta-llama/Llama-2-7b-hf"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation-inference"
model = "facebook/blenderbot-400M-distill"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "conversational"
model = "facebook/blenderbot-3B"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "conversational"
# neither Conv nor None
model = "roneneldan/TinyStories-3M"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == "text-generation"
@ -2300,36 +2320,30 @@ def test_completion_azure_deployment_id():
# test_completion_azure_deployment_id()
# Only works for local endpoint
# def test_completion_anthropic_openai_proxy():
# try:
# response = completion(
# model="custom_openai/claude-2",
# messages=messages,
# api_base="http://0.0.0.0:8000"
# )
# # Add any assertions here to check the response
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_anthropic_openai_proxy()
import asyncio
def test_completion_replicate_llama3():
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.asyncio
async def test_completion_replicate_llama3(sync_mode):
litellm.set_verbose = True
model_name = "replicate/meta/meta-llama-3-8b-instruct"
try:
if sync_mode:
response = completion(
model=model_name,
messages=messages,
)
else:
response = await litellm.acompletion(
model=model_name,
messages=messages,
)
print(f"ASYNC REPLICATE RESPONSE - {response}")
print(response)
# Add any assertions here to check the response
response_str = response["choices"][0]["message"]["content"]
print("RESPONSE STRING\n", response_str)
if type(response_str) != str:
pytest.fail(f"Error occurred: {e}")
assert isinstance(response, litellm.ModelResponse)
response_format_tests(response=response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@ -2670,14 +2684,29 @@ def response_format_tests(response: litellm.ModelResponse):
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"model",
[
"bedrock/cohere.command-r-plus-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-instant-v1",
"bedrock/ai21.j2-mid",
"mistral.mistral-7b-instruct-v0:2",
"bedrock/amazon.titan-tg1-large",
"meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14",
],
)
@pytest.mark.asyncio
async def test_completion_bedrock_command_r(sync_mode):
async def test_completion_bedrock_httpx_models(sync_mode, model):
litellm.set_verbose = True
if sync_mode:
response = completion(
model="bedrock/cohere.command-r-plus-v1:0",
model=model,
messages=[{"role": "user", "content": "Hey! how's it going?"}],
temperature=0.2,
max_tokens=200,
)
assert isinstance(response, litellm.ModelResponse)
@ -2685,8 +2714,10 @@ async def test_completion_bedrock_command_r(sync_mode):
response_format_tests(response=response)
else:
response = await litellm.acompletion(
model="bedrock/cohere.command-r-plus-v1:0",
model=model,
messages=[{"role": "user", "content": "Hey! how's it going?"}],
temperature=0.2,
max_tokens=100,
)
assert isinstance(response, litellm.ModelResponse)
@ -2722,69 +2753,12 @@ def test_completion_bedrock_titan_null_response():
pytest.fail(f"An error occurred - {str(e)}")
def test_completion_bedrock_titan():
try:
response = completion(
model="bedrock/amazon.titan-tg1-large",
messages=messages,
temperature=0.2,
max_tokens=200,
top_p=0.8,
logger_fn=logger_fn,
)
# Add any assertions here to check the response
print(response)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_bedrock_titan()
def test_completion_bedrock_claude():
print("calling claude")
try:
response = completion(
model="anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
logger_fn=logger_fn,
)
# Add any assertions here to check the response
print(response)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_bedrock_claude()
def test_completion_bedrock_cohere():
print("calling bedrock cohere")
litellm.set_verbose = True
try:
response = completion(
model="bedrock/cohere.command-text-v14",
messages=[{"role": "user", "content": "hi"}],
temperature=0.1,
max_tokens=10,
stream=True,
)
# Add any assertions here to check the response
print(response)
for chunk in response:
print(chunk)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_bedrock_cohere()
@ -2807,23 +2781,6 @@ def test_completion_bedrock_cohere():
# pytest.fail(f"Error occurred: {e}")
# test_completion_bedrock_claude_stream()
# def test_completion_bedrock_ai21():
# try:
# litellm.set_verbose = False
# response = completion(
# model="bedrock/ai21.j2-mid",
# messages=messages,
# temperature=0.2,
# top_p=0.2,
# max_tokens=20
# )
# # Add any assertions here to check the response
# print(response)
# except RateLimitError:
# pass
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
######## Test VLLM ########
# def test_completion_vllm():
@ -3096,7 +3053,6 @@ def test_mistral_anyscale_stream():
print(chunk["choices"][0]["delta"].get("content", ""), end="")
# test_mistral_anyscale_stream()
# test_completion_anyscale_2()
# def test_completion_with_fallbacks_multiple_keys():
# print(f"backup key 1: {os.getenv('BACKUP_OPENAI_API_KEY_1')}")
@ -3246,6 +3202,7 @@ def test_completion_gemini():
response = completion(model=model_name, messages=messages)
# Add any assertions here to check the response
print(response)
assert response.choices[0]["index"] == 0
except litellm.APIError as e:
pass
except Exception as e:

View file

@ -65,6 +65,42 @@ async def test_custom_pricing(sync_mode):
assert new_handler.response_cost == 0
def test_custom_pricing_as_completion_cost_param():
from litellm import ModelResponse, Choices, Message
from litellm.utils import Usage
resp = ModelResponse(
id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
choices=[
Choices(
finish_reason=None,
index=0,
message=Message(
content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
role="assistant",
),
)
],
created=1700775391,
model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
object="chat.completion",
system_fingerprint=None,
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
)
cost = litellm.completion_cost(
completion_response=resp,
custom_cost_per_token={
"input_cost_per_token": 1000,
"output_cost_per_token": 20,
},
)
expected_cost = 1000 * 21 + 17 * 20
assert round(cost, 5) == round(expected_cost, 5)
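The expected value is just straight per-token arithmetic:

```python
# custom_cost_per_token applies the supplied rates directly to the usage counts:
prompt_tokens, completion_tokens = 21, 17
input_cost_per_token, output_cost_per_token = 1000, 20
expected_cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
assert expected_cost == 21_340  # 21 * 1000 + 17 * 20
```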
def test_get_gpt3_tokens():
max_tokens = get_max_tokens("gpt-3.5-turbo")
print(max_tokens)

View file

@ -5,7 +5,6 @@
import sys, os
import traceback
from dotenv import load_dotenv
from pydantic import ConfigDict
load_dotenv()
import os, io
@ -14,36 +13,21 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, litellm
from pydantic import BaseModel, VERSION
from pydantic import BaseModel
from litellm.proxy.proxy_server import ProxyConfig
from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
from typing import Literal
# Function to get Pydantic version
def is_pydantic_v2() -> int:
return int(VERSION.split(".")[0])
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
# Version-specific configuration
if is_pydantic_v2() >= 2:
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
else:
from pydantic import Extra
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
return model_config
class DBModel(BaseModel):
model_id: str
model_name: str
model_info: dict
litellm_params: dict
model_config = get_model_config()
class Config:
protected_namespaces = ()
@pytest.mark.asyncio
@ -118,7 +102,7 @@ async def test_delete_deployment():
pc = ProxyConfig()
db_model = DBModel(
model_id="12340523",
model_id=deployment.model_info.id,
model_name="gpt-3.5-turbo",
litellm_params=encrypted_litellm_params,
model_info={"id": deployment.model_info.id},

View file

@ -558,7 +558,7 @@ async def test_async_chat_bedrock_stream():
continue
except:
pass
time.sleep(1)
await asyncio.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []

View file

@ -53,13 +53,6 @@ async def test_content_policy_exception_azure():
except litellm.ContentPolicyViolationError as e:
print("caught a content policy violation error! Passed")
print("exception", e)
# assert that the first 100 chars of the message is returned in the exception
assert (
"Messages: [{'role': 'user', 'content': 'where do I buy lethal drugs from'}]"
in str(e)
)
assert "Model: azure/chatgpt-v-2" in str(e)
pass
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@ -585,9 +578,6 @@ def test_router_completion_vertex_exception():
pytest.fail("Request should have failed - bad api key")
except Exception as e:
print("exception: ", e)
assert "Model: gemini-pro" in str(e)
assert "model_group: vertex-gemini-pro" in str(e)
assert "deployment: vertex_ai/gemini-pro" in str(e)
def test_litellm_completion_vertex_exception():
@ -604,8 +594,26 @@ def test_litellm_completion_vertex_exception():
pytest.fail("Request should have failed - bad api key")
except Exception as e:
print("exception: ", e)
assert "Model: gemini-pro" in str(e)
assert "vertex_project: bad-project" in str(e)
def test_litellm_predibase_exception():
"""
Test - Assert that the Predibase API Key is not returned on Authentication Errors
"""
try:
import litellm
litellm.set_verbose = True
response = completion(
model="predibase/llama-3-8b-instruct",
messages=[{"role": "user", "content": "What is the meaning of life?"}],
tenant_id="c4768f95",
api_key="hf-rawapikey",
)
pytest.fail("Request should have failed - bad api key")
except Exception as e:
assert "hf-rawapikey" not in str(e)
print("exception: ", e)
# # test_invalid_request_error(model="command-nightly")

View file

@ -105,6 +105,9 @@ def test_parallel_function_call(model):
# Step 4: send the info for each function call and function response to the model
for tool_call in tool_calls:
function_name = tool_call.function.name
if function_name not in available_functions:
# the model called a function that does not exist in available_functions - don't try calling anything
return
function_to_call = available_functions[function_name]
function_args = json.loads(tool_call.function.arguments)
function_response = function_to_call(
@ -124,7 +127,6 @@ def test_parallel_function_call(model):
model=model, messages=messages, temperature=0.2, seed=22
) # get a new response from the model where it can see the function response
print("second response\n", second_response)
return second_response
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -162,6 +162,39 @@ async def test_aimage_generation_bedrock_with_optional_params():
print(f"response: {response}")
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # Azure randomly raises these errors - skip when they occur
except Exception as e:
if "Your task failed as a result of our safety system." in str(e):
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.asyncio
async def test_aimage_generation_vertex_ai():
from test_amazing_vertex_completion import load_vertex_ai_credentials
litellm.set_verbose = True
load_vertex_ai_credentials()
try:
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
assert response.data is not None
assert len(response.data) > 0
for d in response.data:
assert isinstance(d, litellm.ImageObject)
print("data in response.data", d)
assert d.b64_json is not None
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # Azure randomly raises these errors - skip when they occur
except Exception as e:

View file

@ -1,7 +1,7 @@
#### What this tests ####
# Unit tests for JWT-Auth
import sys, os, asyncio, time, random
import sys, os, asyncio, time, random, uuid
import traceback
from dotenv import load_dotenv
@ -24,6 +24,7 @@ public_key = {
"alg": "RS256",
}
def test_load_config_with_custom_role_names():
config = {
"general_settings": {
@ -77,7 +78,8 @@ async def test_token_single_public_key():
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
)
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_valid_invalid_token(audience):
"""
@ -90,7 +92,7 @@ async def test_valid_invalid_token(audience):
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
os.environ.pop('JWT_AUDIENCE', None)
os.environ.pop("JWT_AUDIENCE", None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
@ -138,7 +140,7 @@ async def test_valid_invalid_token(audience):
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-proxy-admin",
"aud": audience
"aud": audience,
}
# Generate the JWT token
@ -166,7 +168,7 @@ async def test_valid_invalid_token(audience):
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm-NO-SCOPE",
"aud": audience
"aud": audience,
}
# Generate the JWT token
@ -183,6 +185,7 @@ async def test_valid_invalid_token(audience):
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.fixture
def prisma_client():
import litellm
@ -205,7 +208,7 @@ def prisma_client():
return prisma_client
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.asyncio
async def test_team_token_output(prisma_client, audience):
import jwt, json
@ -222,7 +225,7 @@ async def test_team_token_output(prisma_client, audience):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
os.environ.pop('JWT_AUDIENCE', None)
os.environ.pop("JWT_AUDIENCE", None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
@ -261,7 +264,7 @@ async def test_team_token_output(prisma_client, audience):
jwt_handler.user_api_key_cache = cache
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
# VALID TOKEN
## GENERATE A TOKEN
@ -274,7 +277,7 @@ async def test_team_token_output(prisma_client, audience):
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": audience
"aud": audience,
}
# Generate the JWT token
@ -289,7 +292,7 @@ async def test_team_token_output(prisma_client, audience):
"sub": "user123",
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin",
"aud": audience
"aud": audience,
}
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
@ -315,7 +318,13 @@ async def test_team_token_output(prisma_client, audience):
## 1. INITIAL TEAM CALL - should fail
# use generated key to auth in
setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
setattr(
litellm.proxy.proxy_server,
"general_settings",
{
"enable_jwt_auth": True,
},
)
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
try:
result = await user_api_key_auth(request=request, api_key=bearer_token)
@ -358,9 +367,22 @@ async def test_team_token_output(prisma_client, audience):
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
@pytest.mark.parametrize(
"team_id_set, default_team_id",
[(True, False), (False, True)],
)
@pytest.mark.parametrize("user_id_upsert", [True, False])
@pytest.mark.asyncio
async def test_user_token_output(prisma_client, audience):
async def test_user_token_output(
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
):
import uuid
args = locals()
print(f"received args - {args}")
if default_team_id:
default_team_id = "team_id_12344_{}".format(uuid.uuid4())
"""
- If user required, check if it exists
- fail initial request (when user doesn't exist)
@ -373,7 +395,12 @@ async def test_user_token_output(prisma_client, audience):
from cryptography.hazmat.backends import default_backend
from fastapi import Request
from starlette.datastructures import URL
from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user
from litellm.proxy.proxy_server import (
user_api_key_auth,
new_team,
new_user,
user_info,
)
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
import litellm
import uuid
@ -381,7 +408,7 @@ async def test_user_token_output(prisma_client, audience):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
os.environ.pop('JWT_AUDIENCE', None)
os.environ.pop("JWT_AUDIENCE", None)
if audience:
os.environ["JWT_AUDIENCE"] = audience
@ -423,6 +450,11 @@ async def test_user_token_output(prisma_client, audience):
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
jwt_handler.litellm_jwtauth.team_id_default = default_team_id
jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert
if team_id_set:
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
# VALID TOKEN
## GENERATE A TOKEN
@ -436,7 +468,7 @@ async def test_user_token_output(prisma_client, audience):
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_team",
"client_id": team_id,
"aud": audience
"aud": audience,
}
# Generate the JWT token
@ -451,7 +483,7 @@ async def test_user_token_output(prisma_client, audience):
"sub": user_id,
"exp": expiration_time, # set the token to expire in 10 minutes
"scope": "litellm_proxy_admin",
"aud": audience
"aud": audience,
}
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
@ -503,6 +535,16 @@ async def test_user_token_output(prisma_client, audience):
),
user_api_key_dict=result,
)
if default_team_id:
await new_team(
data=NewTeamRequest(
team_id=default_team_id,
tpm_limit=100,
rpm_limit=99,
models=["gpt-3.5-turbo", "gpt-4"],
),
user_api_key_dict=result,
)
except Exception as e:
pytest.fail(f"This should not fail - {str(e)}")
@ -513,11 +555,23 @@ async def test_user_token_output(prisma_client, audience):
team_result: UserAPIKeyAuth = await user_api_key_auth(
request=request, api_key=bearer_token
)
if user_id_upsert == False:
pytest.fail(f"User doesn't exist. this should fail")
except Exception as e:
pass
## 4. Create user
if user_id_upsert:
## check if user already exists
try:
bearer_token = "Bearer " + admin_token
request._url = URL(url="/team/new")
result = await user_api_key_auth(request=request, api_key=bearer_token)
await user_info(user_id=user_id)
except Exception as e:
pytest.fail(f"This should not fail - {str(e)}")
else:
try:
bearer_token = "Bearer " + admin_token
@ -543,6 +597,7 @@ async def test_user_token_output(prisma_client, audience):
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
if team_id_set or default_team_id is not None:
assert team_result.team_tpm_limit == 100
assert team_result.team_rpm_limit == 99
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
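A condensed sketch of the JWT-auth knobs these tests exercise, reusing the `jwt_handler` object built above; the default team id is a placeholder:

```python
# Map team membership to a custom claim, fall back to a default team when the
# claim is missing, and auto-create ("upsert") users seen for the first time.
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
jwt_handler.litellm_jwtauth.team_id_default = "team_id_default_xyz"  # placeholder team id
jwt_handler.litellm_jwtauth.user_id_upsert = True  # create unknown users instead of rejecting the request
```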

View file

@ -23,6 +23,7 @@ import sys, os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute
from datetime import datetime
load_dotenv()
@ -51,6 +52,13 @@ from litellm.proxy.proxy_server import (
user_info,
info_key_fn,
new_team,
chat_completion,
completion,
embeddings,
image_generation,
audio_transcriptions,
moderations,
model_list,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
from litellm._logging import verbose_proxy_logger
@ -146,7 +154,38 @@ async def test_new_user_response(prisma_client):
pytest.fail(f"Got exception {e}")
def test_generate_and_call_with_valid_key(prisma_client):
@pytest.mark.parametrize(
"api_route", [
# chat_completion
APIRoute(path="/engines/{model}/chat/completions", endpoint=chat_completion),
APIRoute(path="/openai/deployments/{model}/chat/completions", endpoint=chat_completion),
APIRoute(path="/chat/completions", endpoint=chat_completion),
APIRoute(path="/v1/chat/completions", endpoint=chat_completion),
# completion
APIRoute(path="/completions", endpoint=completion),
APIRoute(path="/v1/completions", endpoint=completion),
APIRoute(path="/engines/{model}/completions", endpoint=completion),
APIRoute(path="/openai/deployments/{model}/completions", endpoint=completion),
# embeddings
APIRoute(path="/v1/embeddings", endpoint=embeddings),
APIRoute(path="/embeddings", endpoint=embeddings),
APIRoute(path="/openai/deployments/{model}/embeddings", endpoint=embeddings),
# image generation
APIRoute(path="/v1/images/generations", endpoint=image_generation),
APIRoute(path="/images/generations", endpoint=image_generation),
# audio transcriptions
APIRoute(path="/v1/audio/transcriptions", endpoint=audio_transcriptions),
APIRoute(path="/audio/transcriptions", endpoint=audio_transcriptions),
# moderations
APIRoute(path="/v1/moderations", endpoint=moderations),
APIRoute(path="/moderations", endpoint=moderations),
# model_list
APIRoute(path= "/v1/models", endpoint=model_list),
APIRoute(path= "/models", endpoint=model_list),
],
ids=lambda route: str(dict(route=route.endpoint.__name__, path=route.path)),
)
def test_generate_and_call_with_valid_key(prisma_client, api_route):
# 1. Generate a Key, and use it to make a call
print("prisma client=", prisma_client)
@ -181,8 +220,12 @@ def test_generate_and_call_with_valid_key(prisma_client):
)
print("token from prisma", value_from_prisma)
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
request = Request({
"type": "http",
"route": api_route,
"path": api_route.path,
"headers": [("Authorization", bearer_token)]
})
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)

View file

@ -705,7 +705,7 @@ async def test_lowest_latency_routing_first_pick():
) # type: ignore
deployments = {}
for _ in range(5):
for _ in range(10):
response = await router.acompletion(
model="azure-model", messages=[{"role": "user", "content": "hello"}]
)

View file

@ -28,6 +28,37 @@ from datetime import datetime
## On Request failure
@pytest.mark.asyncio
async def test_global_max_parallel_requests():
"""
Test if ParallelRequestHandler respects 'global_max_parallel_requests'
data["metadata"]["global_max_parallel_requests"]
"""
global_max_parallel_requests = 0
_api_key = "sk-12345"
_api_key = hash_token("sk-12345")
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100)
local_cache = DualCache()
parallel_request_handler = MaxParallelRequestsHandler()
for _ in range(3):
try:
await parallel_request_handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
data={
"metadata": {
"global_max_parallel_requests": global_max_parallel_requests
}
},
call_type="",
)
pytest.fail("Expected call to fail")
except Exception as e:
pass
@pytest.mark.asyncio
async def test_pre_call_hook():
"""

View file

@ -0,0 +1,138 @@
# Test the following scenarios:
# 1. Token counting for models served through the litellm proxy
import sys, os
import traceback
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
load_dotenv()
import os, io, time
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm, asyncio
from litellm.proxy.proxy_server import token_counter
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.setLevel(level=logging.DEBUG)
from litellm.proxy._types import TokenCountRequest, TokenCountResponse
from litellm import Router
@pytest.mark.asyncio
async def test_vLLM_token_counting():
"""
Test Token counter for vLLM models
- User passes model="special-alias"
- token_counter should infer that special-alias maps to wolfram/miquliz-120b-v2.0
-> token counter should use hugging face tokenizer
"""
llm_router = Router(
model_list=[
{
"model_name": "special-alias",
"litellm_params": {
"model": "openai/wolfram/miquliz-120b-v2.0",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
},
}
]
)
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
response = await token_counter(
request=TokenCountRequest(
model="special-alias",
messages=[{"role": "user", "content": "hello"}],
)
)
print("response: ", response)
assert (
response.tokenizer_type == "huggingface_tokenizer"
) # SHOULD use the hugging face tokenizer
assert response.model_used == "wolfram/miquliz-120b-v2.0"
@pytest.mark.asyncio
async def test_token_counting_model_not_in_model_list():
"""
Test Token counter - when a model is not in model_list
-> should use the default OpenAI tokenizer
"""
llm_router = Router(
model_list=[
{
"model_name": "gpt-4",
"litellm_params": {
"model": "gpt-4",
},
}
]
)
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
response = await token_counter(
request=TokenCountRequest(
model="special-alias",
messages=[{"role": "user", "content": "hello"}],
)
)
print("response: ", response)
assert (
response.tokenizer_type == "openai_tokenizer"
) # SHOULD use the OpenAI tokenizer
assert response.model_used == "special-alias"
@pytest.mark.asyncio
async def test_gpt_token_counting():
"""
Test Token counter
-> should work for gpt-4
"""
llm_router = Router(
model_list=[
{
"model_name": "gpt-4",
"litellm_params": {
"model": "gpt-4",
},
}
]
)
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
response = await token_counter(
request=TokenCountRequest(
model="gpt-4",
messages=[{"role": "user", "content": "hello"}],
)
)
print("response: ", response)
assert (
response.tokenizer_type == "openai_tokenizer"
) # SHOULD use the OpenAI tokenizer
assert response.request_model == "gpt-4"

View file

@ -0,0 +1,64 @@
#### What this tests ####
# This tests router cooldown behavior - deployments should not be cooled down on a BadRequestError
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger
import openai, httpx
@pytest.mark.asyncio
async def test_cooldown_badrequest_error():
"""
Test 1. It SHOULD NOT cooldown a deployment on a BadRequestError
"""
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
],
debug_level="DEBUG",
set_verbose=True,
cooldown_time=300,
num_retries=0,
allowed_fails=0,
)
# Act & Assert
try:
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "gm"}],
bad_param=200,
)
except:
pass
await asyncio.sleep(3) # wait for deployment to get cooled-down
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "gm"}],
mock_response="hello",
)
assert response is not None
print(response)

View file

@ -82,7 +82,7 @@ def test_async_fallbacks(caplog):
# Define the expected log messages
# - error request, falling back notice, success notice
expected_logs = [
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}} \nModel: gpt-3.5-turbo\nAPI Base: https://api.openai.com\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]\nmodel_group: gpt-3.5-turbo\n\ndeployment: gpt-3.5-turbo\n\x1b[0m",
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
"Successful fallback b/w models.",

View file

@ -950,7 +950,63 @@ def test_vertex_ai_stream():
# test_completion_vertexai_stream_bad_key()
# def test_completion_replicate_stream():
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.asyncio
async def test_completion_replicate_llama3_streaming(sync_mode):
litellm.set_verbose = True
model_name = "replicate/meta/meta-llama-3-8b-instruct"
try:
if sync_mode:
final_chunk: Optional[litellm.ModelResponse] = None
response: litellm.CustomStreamWrapper = completion( # type: ignore
model=model_name,
messages=messages,
max_tokens=10, # type: ignore
stream=True,
)
complete_response = ""
# Add any assertions here to check the response
has_finish_reason = False
for idx, chunk in enumerate(response):
final_chunk = chunk
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
has_finish_reason = True
break
complete_response += chunk
if has_finish_reason == False:
raise Exception("finish reason not set")
if complete_response.strip() == "":
raise Exception("Empty response received")
else:
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
model=model_name,
messages=messages,
max_tokens=100, # type: ignore
stream=True,
)
complete_response = ""
# Add any assertions here to check the response
has_finish_reason = False
idx = 0
final_chunk: Optional[litellm.ModelResponse] = None
async for chunk in response:
final_chunk = chunk
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
has_finish_reason = True
break
complete_response += chunk
idx += 1
if has_finish_reason == False:
raise Exception("finish reason not set")
if complete_response.strip() == "":
raise Exception("Empty response received")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# TEMP Commented out - replicate throwing an auth error
# try:
# litellm.set_verbose = True
@ -984,15 +1040,28 @@ def test_vertex_ai_stream():
# pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True])
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"model",
[
# "bedrock/cohere.command-r-plus-v1:0",
# "anthropic.claude-3-sonnet-20240229-v1:0",
# "anthropic.claude-instant-v1",
# "bedrock/ai21.j2-mid",
# "mistral.mistral-7b-instruct-v0:2",
# "bedrock/amazon.titan-tg1-large",
# "meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14"
],
)
@pytest.mark.asyncio
async def test_bedrock_cohere_command_r_streaming(sync_mode):
async def test_bedrock_httpx_streaming(sync_mode, model):
try:
litellm.set_verbose = True
if sync_mode:
final_chunk: Optional[litellm.ModelResponse] = None
response: litellm.CustomStreamWrapper = completion( # type: ignore
model="bedrock/cohere.command-r-plus-v1:0",
model=model,
messages=messages,
max_tokens=10, # type: ignore
stream=True,
@ -1013,7 +1082,7 @@ async def test_bedrock_cohere_command_r_streaming(sync_mode):
raise Exception("Empty response received")
else:
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
model="bedrock/cohere.command-r-plus-v1:0",
model=model,
messages=messages,
max_tokens=100, # type: ignore
stream=True,

View file

@ -174,7 +174,6 @@ def test_load_test_token_counter(model):
"""
import tiktoken
enc = tiktoken.get_encoding("cl100k_base")
messages = [{"role": "user", "content": text}] * 10
start_time = time.time()
@ -186,4 +185,4 @@ def test_load_test_token_counter(model):
total_time = end_time - start_time
print("model={}, total test time={}".format(model, total_time))
assert total_time < 2, f"Total encoding time > 1.5s, {total_time}"
assert total_time < 10, f"Total encoding time > 10s, {total_time}"
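
As a standalone illustration of the timing pattern this test uses, here is a hedged sketch with tiktoken; the sample text and payload size are arbitrary choices, while the 10-second bound mirrors the assertion above.

```python
# Sketch of the timing pattern used by the test above; the payload is illustrative.
import time

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "This is a sample sentence for token counting. " * 200  # arbitrary payload
messages = [{"role": "user", "content": text}] * 10

start_time = time.time()
token_count = sum(len(enc.encode(m["content"])) for m in messages)
total_time = time.time() - start_time

print("tokens={}, total test time={}".format(token_count, total_time))
assert total_time < 10, f"Total encoding time > 10s, {total_time}"
```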

View file

@ -1,27 +1,10 @@
from typing import List, Optional, Union, Iterable, cast
from typing import List, Optional, Union, Iterable
from pydantic import ConfigDict, BaseModel, validator, VERSION
from pydantic import BaseModel, validator
from typing_extensions import Literal, Required, TypedDict
# Function to get Pydantic version
def is_pydantic_v2() -> int:
return int(VERSION.split(".")[0])
def get_model_config() -> ConfigDict:
# Version-specific configuration
if is_pydantic_v2() >= 2:
model_config = ConfigDict(extra="allow", protected_namespaces=()) # type: ignore
else:
from pydantic import Extra
model_config = ConfigDict(extra=Extra.allow) # type: ignore
return model_config
class ChatCompletionSystemMessageParam(TypedDict, total=False):
content: Required[str]
"""The contents of the system message."""
@ -208,4 +191,6 @@ class CompletionRequest(BaseModel):
api_key: Optional[str] = None
model_list: Optional[List[str]] = None
model_config = get_model_config()
class Config:
extra = "allow"
protected_namespaces = ()
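
As a hedged aside, the snippet below shows what an `extra = "allow"` / `protected_namespaces = ()` Config like the one above does on a pydantic model; the class and field names are illustrative, not the repo's.

```python
# Sketch: effect of the Config options above on a pydantic model (names illustrative).
from typing import Optional

from pydantic import BaseModel


class ExampleCompletionRequest(BaseModel):
    model: Optional[str] = None

    class Config:
        extra = "allow"            # keep unknown kwargs instead of raising
        protected_namespaces = ()  # avoid pydantic v2 warnings on "model_*" names


req = ExampleCompletionRequest(model="gpt-3.5-turbo", custom_llm_provider="openai")
print(req.dict())  # the extra key "custom_llm_provider" is preserved
```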

View file

@ -1,23 +1,6 @@
from typing import List, Optional, Union
from pydantic import ConfigDict, BaseModel, validator, VERSION
# Function to get Pydantic version
def is_pydantic_v2() -> int:
return int(VERSION.split(".")[0])
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
# Version-specific configuration
if is_pydantic_v2() >= 2:
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
else:
from pydantic import Extra
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
return model_config
from pydantic import BaseModel, validator
class EmbeddingRequest(BaseModel):
@ -34,4 +17,7 @@ class EmbeddingRequest(BaseModel):
litellm_call_id: Optional[str] = None
litellm_logging_obj: Optional[dict] = None
logger_fn: Optional[str] = None
model_config = get_model_config()
class Config:
# allow kwargs
extra = "allow"

View file

@ -1,3 +0,0 @@
__all__ = ["openai"]
from . import openai

View file

@ -0,0 +1,53 @@
from typing import TypedDict, Any, Union, Optional, List, Literal, Dict
import json
from typing_extensions import (
Self,
Protocol,
TypeGuard,
override,
get_origin,
runtime_checkable,
Required,
)
class Field(TypedDict):
key: str
value: Dict[str, Any]
class FunctionCallArgs(TypedDict):
fields: Field
class FunctionResponse(TypedDict):
name: str
response: FunctionCallArgs
class FunctionCall(TypedDict):
name: str
args: FunctionCallArgs
class FileDataType(TypedDict):
mime_type: str
file_uri: str # the cloud storage uri of storing this file
class BlobType(TypedDict):
mime_type: Required[str]
data: Required[bytes]
class PartType(TypedDict, total=False):
text: str
inline_data: BlobType
file_data: FileDataType
function_call: FunctionCall
function_response: FunctionResponse
class ContentType(TypedDict, total=False):
role: Literal["user", "model"]
parts: Required[List[PartType]]

View file

@ -1,42 +1,19 @@
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
import httpx
from pydantic import (
ConfigDict,
BaseModel,
validator,
Field,
__version__ as pydantic_version,
VERSION,
)
from pydantic import BaseModel, validator, Field
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
import uuid, enum
# Function to get Pydantic version
def is_pydantic_v2() -> int:
return int(VERSION.split(".")[0])
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
# Version-specific configuration
if is_pydantic_v2() >= 2:
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
else:
from pydantic import Extra
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
return model_config
class ModelConfig(BaseModel):
model_name: str
litellm_params: Union[CompletionRequest, EmbeddingRequest]
tpm: int
rpm: int
model_config = get_model_config()
class Config:
protected_namespaces = ()
class RouterConfig(BaseModel):
@ -67,7 +44,8 @@ class RouterConfig(BaseModel):
"latency-based-routing",
] = "simple-shuffle"
model_config = get_model_config()
class Config:
protected_namespaces = ()
class UpdateRouterConfig(BaseModel):
@ -87,7 +65,8 @@ class UpdateRouterConfig(BaseModel):
fallbacks: Optional[List[dict]] = None
context_window_fallbacks: Optional[List[dict]] = None
model_config = get_model_config()
class Config:
protected_namespaces = ()
class ModelInfo(BaseModel):
@ -97,6 +76,9 @@ class ModelInfo(BaseModel):
db_model: bool = (
False # used for proxy - to separate models which are stored in the db vs. config.
)
base_model: Optional[str] = (
None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking
)
def __init__(self, id: Optional[Union[str, int]] = None, **params):
if id is None:
@ -105,7 +87,8 @@ class ModelInfo(BaseModel):
id = str(id)
super().__init__(id=id, **params)
model_config = get_model_config()
class Config:
extra = "allow"
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -200,14 +183,8 @@ class GenericLiteLLMParams(BaseModel):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
model_config = get_model_config(arbitrary_types_allowed=True)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
# RuntimeError: no validator found for <class 'openai.Timeout'>,
# see `arbitrary_types_allowed` in Config
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
@ -267,15 +244,8 @@ class LiteLLM_Params(GenericLiteLLMParams):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
model_config = get_model_config(arbitrary_types_allowed=True)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
# RuntimeError: no validator found for <class 'openai.Timeout'>,
# see `arbitrary_types_allowed` in Config
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
@ -306,7 +276,8 @@ class updateDeployment(BaseModel):
litellm_params: Optional[updateLiteLLMParams] = None
model_info: Optional[ModelInfo] = None
model_config = get_model_config()
class Config:
protected_namespaces = ()
class LiteLLMParamsTypedDict(TypedDict, total=False):
@ -380,7 +351,9 @@ class Deployment(BaseModel):
# if using pydantic v1
return self.dict(**kwargs)
model_config = get_model_config()
class Config:
extra = "allow"
protected_namespaces = ()
def __contains__(self, key):
# Define custom behavior for the 'in' operator
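
The `__contains__` overrides above are what make `"some_key" in deployment`-style checks work; a simplified, hedged stand-in is sketched below (not the real Deployment class).

```python
# Simplified stand-in showing the `in`-operator behaviour enabled by __contains__.
from pydantic import BaseModel


class SimpleDeployment(BaseModel):
    model_name: str

    class Config:
        extra = "allow"
        protected_namespaces = ()

    def __contains__(self, key: str) -> bool:
        # True if the key exists as a declared field or an extra attribute
        return hasattr(self, key)


d = SimpleDeployment(model_name="gpt-3.5-turbo", tpm=100000)
print("tpm" in d)  # True  - extra field kept by extra="allow"
print("rpm" in d)  # False - never set on this instance
```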

litellm/types/utils.py Normal file
View file

@ -0,0 +1,6 @@
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
class CostPerToken(TypedDict):
input_cost_per_token: float
output_cost_per_token: float
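
A hedged usage sketch for this TypedDict follows; the per-token rates and token counts are invented for illustration, not real provider pricing, and the CostPerToken definition above is assumed to be in scope.

```python
# Sketch: pricing a single response with a CostPerToken mapping (values illustrative).
custom_cost: CostPerToken = {
    "input_cost_per_token": 0.000002,
    "output_cost_per_token": 0.000004,
}

prompt_tokens, completion_tokens = 120, 35  # hypothetical usage
total_cost = (
    prompt_tokens * custom_cost["input_cost_per_token"]
    + completion_tokens * custom_cost["output_cost_per_token"]
)
print(f"estimated cost: ${total_cost:.6f}")
```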

File diff suppressed because it is too large

View file

@ -234,6 +234,24 @@
"litellm_provider": "openai",
"mode": "chat"
},
"ft:davinci-002": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000002,
"litellm_provider": "text-completion-openai",
"mode": "completion"
},
"ft:babbage-002": {
"max_tokens": 16384,
"max_input_tokens": 16384,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000004,
"litellm_provider": "text-completion-openai",
"mode": "completion"
},
"text-embedding-3-large": {
"max_tokens": 8191,
"max_input_tokens": 8191,
@ -1385,6 +1403,24 @@
"mode": "completion",
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-1.5-flash-latest": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-pro": {
"max_tokens": 8192,
"max_input_tokens": 32760,
@ -1744,6 +1780,30 @@
"litellm_provider": "openrouter",
"mode": "chat"
},
"openrouter/openai/gpt-4o": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-4o-2024-05-13": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000015,
"litellm_provider": "openrouter",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"openrouter/openai/gpt-4-vision-preview": {
"max_tokens": 130000,
"input_cost_per_token": 0.00001,
@ -2943,6 +3003,24 @@
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/llama3": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/llama3:70b": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mistral": {
"max_tokens": 8192,
"max_input_tokens": 8192,
@ -2952,6 +3030,42 @@
"litellm_provider": "ollama",
"mode": "completion"
},
"ollama/mistral-7B-Instruct-v0.1": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"max_output_tokens": 8192,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mistral-7B-Instruct-v0.2": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 32768,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mixtral-8x7B-Instruct-v0.1": {
"max_tokens": 32768,
"max_input_tokens": 32768,
"max_output_tokens": 32768,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/mixtral-8x22B-Instruct-v0.1": {
"max_tokens": 65536,
"max_input_tokens": 65536,
"max_output_tokens": 65536,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "ollama",
"mode": "chat"
},
"ollama/codellama": {
"max_tokens": 4096,
"max_input_tokens": 4096,

View file

@ -89,6 +89,7 @@ model_list:
litellm_params:
model: text-completion-openai/gpt-3.5-turbo-instruct
litellm_settings:
# set_verbose: True # Uncomment this if you want to see verbose logs; not recommended in production
drop_params: True
# max_budget: 100
# budget_duration: 30d
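
For reference, a hedged sketch of the SDK-level counterpart of `drop_params: True` is shown below; the model and message are placeholders, and the behaviour described in the comments reflects the documented intent of the flag rather than anything exercised in this diff.

```python
# Sketch: SDK-level equivalent of `drop_params: True` in the proxy config above.
# With drop_params enabled, litellm drops OpenAI-style params the target provider
# does not support instead of raising. Assumes OPENAI_API_KEY is set in the env.
import litellm

litellm.drop_params = True

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)
```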

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.37.10"
version = "1.37.19"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@ -79,7 +79,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.37.10"
version = "1.37.19"
version_files = [
"pyproject.toml:^version"
]

View file

@ -1,6 +1,6 @@
# LITELLM PROXY DEPENDENCIES #
anyio==4.2.0 # openai + http req.
openai==1.14.3 # openai req.
openai==1.27.0 # openai req.
fastapi==0.111.0 # server dep
backoff==2.2.1 # server dep
pyyaml==6.0.0 # server dep

View file

@ -129,7 +129,7 @@ async def test_check_num_callbacks():
set(all_litellm_callbacks_1) - set(all_litellm_callbacks_2),
)
assert num_callbacks_1 == num_callbacks_2
assert abs(num_callbacks_1 - num_callbacks_2) <= 4
await asyncio.sleep(30)
@ -142,7 +142,7 @@ async def test_check_num_callbacks():
set(all_litellm_callbacks_3) - set(all_litellm_callbacks_2),
)
assert num_callbacks_1 == num_callbacks_2 == num_callbacks_3
assert abs(num_callbacks_3 - num_callbacks_2) <= 4
@pytest.mark.asyncio
@ -183,7 +183,7 @@ async def test_check_num_callbacks_on_lowest_latency():
set(all_litellm_callbacks_2) - set(all_litellm_callbacks_1),
)
assert num_callbacks_1 == num_callbacks_2
assert abs(num_callbacks_1 - num_callbacks_2) <= 4
await asyncio.sleep(30)
@ -196,7 +196,7 @@ async def test_check_num_callbacks_on_lowest_latency():
set(all_litellm_callbacks_3) - set(all_litellm_callbacks_2),
)
assert num_callbacks_1 == num_callbacks_2 == num_callbacks_3
assert abs(num_callbacks_2 - num_callbacks_3) <= 4
assert num_alerts_1 == num_alerts_2 == num_alerts_3

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +1 @@
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[7926,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-6a39771cacf75ea6.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"obp5wqVSVDMiDTC414cR8\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid 
rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[4858,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-f20fdea77aed85ba.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"l-0LDfSCdaUCAbcLIx_QC\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid 
rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>

View file

@ -1,7 +1,7 @@
2:I[77831,[],""]
3:I[7926,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-6a39771cacf75ea6.js"],""]
3:I[4858,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-f20fdea77aed85ba.js"],""]
4:I[5613,[],""]
5:I[31778,[],""]
0:["obp5wqVSVDMiDTC414cR8",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
0:["l-0LDfSCdaUCAbcLIx_QC",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
1:null

View file

@ -23,12 +23,44 @@ import {
AccordionHeader,
AccordionList,
} from "@tremor/react";
import { TabPanel, TabPanels, TabGroup, TabList, Tab, Icon } from "@tremor/react";
import { getCallbacksCall, setCallbacksCall, serviceHealthCheck } from "./networking";
import { Modal, Form, Input, Select, Button as Button2, message } from "antd";
import { InformationCircleIcon, PencilAltIcon, PencilIcon, StatusOnlineIcon, TrashIcon, RefreshIcon } from "@heroicons/react/outline";
import {
TabPanel,
TabPanels,
TabGroup,
TabList,
Tab,
Icon,
} from "@tremor/react";
import {
getCallbacksCall,
setCallbacksCall,
getGeneralSettingsCall,
serviceHealthCheck,
updateConfigFieldSetting,
deleteConfigFieldSetting,
} from "./networking";
import {
Modal,
Form,
Input,
Select,
Button as Button2,
message,
InputNumber,
} from "antd";
import {
InformationCircleIcon,
PencilAltIcon,
PencilIcon,
StatusOnlineIcon,
TrashIcon,
RefreshIcon,
CheckCircleIcon,
XCircleIcon,
QuestionMarkCircleIcon,
} from "@heroicons/react/outline";
import StaticGenerationSearchParamsBailoutProvider from "next/dist/client/components/static-generation-searchparams-bailout-provider";
import AddFallbacks from "./add_fallbacks"
import AddFallbacks from "./add_fallbacks";
import openai from "openai";
import Paragraph from "antd/es/skeleton/Paragraph";
@ -36,7 +68,7 @@ interface GeneralSettingsPageProps {
accessToken: string | null;
userRole: string | null;
userID: string | null;
modelData: any
modelData: any;
}
async function testFallbackModelResponse(
@ -65,24 +97,39 @@ async function testFallbackModelResponse(
},
],
// @ts-ignore
mock_testing_fallbacks: true
mock_testing_fallbacks: true,
});
message.success(
<span>
Test model=<strong>{selectedModel}</strong>, received model=<strong>{response.model}</strong>.
See <a href="#" onClick={() => window.open('https://docs.litellm.ai/docs/proxy/reliability', '_blank')} style={{ textDecoration: 'underline', color: 'blue' }}>curl</a>
Test model=<strong>{selectedModel}</strong>, received model=
<strong>{response.model}</strong>. See{" "}
<a
href="#"
onClick={() =>
window.open(
"https://docs.litellm.ai/docs/proxy/reliability",
"_blank"
)
}
style={{ textDecoration: "underline", color: "blue" }}
>
curl
</a>
</span>
);
} catch (error) {
message.error(`Error occurred while generating model response. Please try again. Error: ${error}`, 20);
message.error(
`Error occurred while generating model response. Please try again. Error: ${error}`,
20
);
}
}
interface AccordionHeroProps {
selectedStrategy: string | null;
strategyArgs: routingStrategyArgs;
paramExplanation: { [key: string]: string }
paramExplanation: { [key: string]: string };
}
interface routingStrategyArgs {
@ -90,17 +137,30 @@ interface routingStrategyArgs {
lowest_latency_buffer?: number;
}
const defaultLowestLatencyArgs: routingStrategyArgs = {
"ttl": 3600,
"lowest_latency_buffer": 0
interface generalSettingsItem {
field_name: string;
field_type: string;
field_value: any;
field_description: string;
stored_in_db: boolean | null;
}
export const AccordionHero: React.FC<AccordionHeroProps> = ({ selectedStrategy, strategyArgs, paramExplanation }) => (
const defaultLowestLatencyArgs: routingStrategyArgs = {
ttl: 3600,
lowest_latency_buffer: 0,
};
export const AccordionHero: React.FC<AccordionHeroProps> = ({
selectedStrategy,
strategyArgs,
paramExplanation,
}) => (
<Accordion>
<AccordionHeader className="text-sm font-medium text-tremor-content-strong dark:text-dark-tremor-content-strong">Routing Strategy Specific Args</AccordionHeader>
<AccordionHeader className="text-sm font-medium text-tremor-content-strong dark:text-dark-tremor-content-strong">
Routing Strategy Specific Args
</AccordionHeader>
<AccordionBody>
{
selectedStrategy == "latency-based-routing" ?
{selectedStrategy == "latency-based-routing" ? (
<Card>
<Table>
<TableHead>
@ -114,13 +174,24 @@ export const AccordionHero: React.FC<AccordionHeroProps> = ({ selectedStrategy,
<TableRow key={param}>
<TableCell>
<Text>{param}</Text>
<p style={{fontSize: '0.65rem', color: '#808080', fontStyle: 'italic'}} className="mt-1">{paramExplanation[param]}</p>
<p
style={{
fontSize: "0.65rem",
color: "#808080",
fontStyle: "italic",
}}
className="mt-1"
>
{paramExplanation[param]}
</p>
</TableCell>
<TableCell>
<TextInput
name={param}
defaultValue={
typeof value === 'object' ? JSON.stringify(value, null, 2) : value.toString()
typeof value === "object"
? JSON.stringify(value, null, 2)
: value.toString()
}
/>
</TableCell>
@ -129,8 +200,9 @@ export const AccordionHero: React.FC<AccordionHeroProps> = ({ selectedStrategy,
</TableBody>
</Table>
</Card>
: <Text>No specific settings</Text>
}
) : (
<Text>No specific settings</Text>
)}
</AccordionBody>
</Accordion>
);
@ -139,26 +211,38 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
accessToken,
userRole,
userID,
modelData
modelData,
}) => {
const [routerSettings, setRouterSettings] = useState<{ [key: string]: any }>({});
const [routerSettings, setRouterSettings] = useState<{ [key: string]: any }>(
{}
);
const [generalSettingsDict, setGeneralSettingsDict] = useState<{
[key: string]: any;
}>({});
const [generalSettings, setGeneralSettings] = useState<generalSettingsItem[]>(
[]
);
const [isModalVisible, setIsModalVisible] = useState(false);
const [form] = Form.useForm();
const [selectedCallback, setSelectedCallback] = useState<string | null>(null);
const [selectedStrategy, setSelectedStrategy] = useState<string | null>(null)
const [strategySettings, setStrategySettings] = useState<routingStrategyArgs | null>(null);
const [selectedStrategy, setSelectedStrategy] = useState<string | null>(null);
const [strategySettings, setStrategySettings] =
useState<routingStrategyArgs | null>(null);
let paramExplanation: { [key: string]: string } = {
"routing_strategy_args": "(dict) Arguments to pass to the routing strategy",
"routing_strategy": "(string) Routing strategy to use",
"allowed_fails": "(int) Number of times a deployment can fail before being added to cooldown",
"cooldown_time": "(int) time in seconds to cooldown a deployment after failure",
"num_retries": "(int) Number of retries for failed requests. Defaults to 0.",
"timeout": "(float) Timeout for requests. Defaults to None.",
"retry_after": "(int) Minimum time to wait before retrying a failed request",
"ttl": "(int) Sliding window to look back over when calculating the average latency of a deployment. Default - 1 hour (in seconds).",
"lowest_latency_buffer": "(float) Shuffle between deployments within this % of the lowest latency. Default - 0 (i.e. always pick lowest latency)."
}
routing_strategy_args: "(dict) Arguments to pass to the routing strategy",
routing_strategy: "(string) Routing strategy to use",
allowed_fails:
"(int) Number of times a deployment can fail before being added to cooldown",
cooldown_time:
"(int) time in seconds to cooldown a deployment after failure",
num_retries: "(int) Number of retries for failed requests. Defaults to 0.",
timeout: "(float) Timeout for requests. Defaults to None.",
retry_after: "(int) Minimum time to wait before retrying a failed request",
ttl: "(int) Sliding window to look back over when calculating the average latency of a deployment. Default - 1 hour (in seconds).",
lowest_latency_buffer:
"(float) Shuffle between deployments within this % of the lowest latency. Default - 0 (i.e. always pick lowest latency).",
};
useEffect(() => {
if (!accessToken || !userRole || !userID) {
@ -169,6 +253,10 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
let router_settings = data.router_settings;
setRouterSettings(router_settings);
});
getGeneralSettingsCall(accessToken).then((data) => {
let general_settings = data;
setGeneralSettings(general_settings);
});
}, [accessToken, userRole, userID]);
const handleAddCallback = () => {
@ -190,8 +278,8 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
return;
}
console.log(`received key: ${key}`)
console.log(`routerSettings['fallbacks']: ${routerSettings['fallbacks']}`)
console.log(`received key: ${key}`);
console.log(`routerSettings['fallbacks']: ${routerSettings["fallbacks"]}`);
routerSettings["fallbacks"].map((dict: { [key: string]: any }) => {
// Check if the dictionary has the specified key and delete it if present
@ -202,19 +290,74 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
});
const payload = {
router_settings: routerSettings
router_settings: routerSettings,
};
try {
await setCallbacksCall(accessToken, payload);
setRouterSettings({ ...routerSettings });
setSelectedStrategy(routerSettings["routing_strategy"])
setSelectedStrategy(routerSettings["routing_strategy"]);
message.success("Router settings updated successfully");
} catch (error) {
message.error("Failed to update router settings: " + error, 20);
}
};
const handleInputChange = (fieldName: string, newValue: any) => {
// Update the value in the state
const updatedSettings = generalSettings.map((setting) =>
setting.field_name === fieldName
? { ...setting, field_value: newValue }
: setting
);
setGeneralSettings(updatedSettings);
};
const handleUpdateField = (fieldName: string, idx: number) => {
if (!accessToken) {
return;
}
let fieldValue = generalSettings[idx].field_value;
if (fieldValue == null || fieldValue == undefined) {
return;
}
try {
updateConfigFieldSetting(accessToken, fieldName, fieldValue);
// update value in state
const updatedSettings = generalSettings.map((setting) =>
setting.field_name === fieldName
? { ...setting, stored_in_db: true }
: setting
);
setGeneralSettings(updatedSettings);
} catch (error) {
// do something
}
};
const handleResetField = (fieldName: string, idx: number) => {
if (!accessToken) {
return;
}
try {
deleteConfigFieldSetting(accessToken, fieldName);
// update value in state
const updatedSettings = generalSettings.map((setting) =>
setting.field_name === fieldName
? { ...setting, stored_in_db: null, field_value: null }
: setting
);
setGeneralSettings(updatedSettings);
} catch (error) {
// do something
}
};
const handleSaveChanges = (router_settings: any) => {
if (!accessToken) {
return;
@ -223,39 +366,55 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
console.log("router_settings", router_settings);
const updatedVariables = Object.fromEntries(
Object.entries(router_settings).map(([key, value]) => {
if (key !== 'routing_strategy_args' && key !== "routing_strategy") {
return [key, (document.querySelector(`input[name="${key}"]`) as HTMLInputElement)?.value || value];
}
else if (key == "routing_strategy") {
return [key, selectedStrategy]
}
else if (key == "routing_strategy_args" && selectedStrategy == "latency-based-routing") {
let setRoutingStrategyArgs: routingStrategyArgs = {}
Object.entries(router_settings)
.map(([key, value]) => {
if (key !== "routing_strategy_args" && key !== "routing_strategy") {
return [
key,
(
document.querySelector(
`input[name="${key}"]`
) as HTMLInputElement
)?.value || value,
];
} else if (key == "routing_strategy") {
return [key, selectedStrategy];
} else if (
key == "routing_strategy_args" &&
selectedStrategy == "latency-based-routing"
) {
let setRoutingStrategyArgs: routingStrategyArgs = {};
const lowestLatencyBufferElement = document.querySelector(`input[name="lowest_latency_buffer"]`) as HTMLInputElement;
const ttlElement = document.querySelector(`input[name="ttl"]`) as HTMLInputElement;
const lowestLatencyBufferElement = document.querySelector(
`input[name="lowest_latency_buffer"]`
) as HTMLInputElement;
const ttlElement = document.querySelector(
`input[name="ttl"]`
) as HTMLInputElement;
if (lowestLatencyBufferElement?.value) {
setRoutingStrategyArgs["lowest_latency_buffer"] = Number(lowestLatencyBufferElement.value)
setRoutingStrategyArgs["lowest_latency_buffer"] = Number(
lowestLatencyBufferElement.value
);
}
if (ttlElement?.value) {
setRoutingStrategyArgs["ttl"] = Number(ttlElement.value)
setRoutingStrategyArgs["ttl"] = Number(ttlElement.value);
}
console.log(`setRoutingStrategyArgs: ${setRoutingStrategyArgs}`)
return [
"routing_strategy_args", setRoutingStrategyArgs
]
console.log(`setRoutingStrategyArgs: ${setRoutingStrategyArgs}`);
return ["routing_strategy_args", setRoutingStrategyArgs];
}
return null;
}).filter(entry => entry !== null && entry !== undefined) as Iterable<[string, unknown]>
})
.filter((entry) => entry !== null && entry !== undefined) as Iterable<
[string, unknown]
>
);
console.log("updatedVariables", updatedVariables);
const payload = {
router_settings: updatedVariables
router_settings: updatedVariables,
};
try {
@ -267,19 +426,17 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
message.success("router settings updated successfully");
};
if (!accessToken) {
return null;
}
return (
<div className="w-full mx-4">
<TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
<TabList variant="line" defaultValue="1">
<Tab value="1">General Settings</Tab>
<Tab value="1">Loadbalancing</Tab>
<Tab value="2">Fallbacks</Tab>
<Tab value="3">General</Tab>
</TabList>
<TabPanels>
<TabPanel>
@ -294,27 +451,55 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
</TableRow>
</TableHead>
<TableBody>
{Object.entries(routerSettings).filter(([param, value]) => param != "fallbacks" && param != "context_window_fallbacks" && param != "routing_strategy_args").map(([param, value]) => (
{Object.entries(routerSettings)
.filter(
([param, value]) =>
param != "fallbacks" &&
param != "context_window_fallbacks" &&
param != "routing_strategy_args"
)
.map(([param, value]) => (
<TableRow key={param}>
<TableCell>
<Text>{param}</Text>
<p style={{fontSize: '0.65rem', color: '#808080', fontStyle: 'italic'}} className="mt-1">{paramExplanation[param]}</p>
<p
style={{
fontSize: "0.65rem",
color: "#808080",
fontStyle: "italic",
}}
className="mt-1"
>
{paramExplanation[param]}
</p>
</TableCell>
<TableCell>
{
param == "routing_strategy" ?
<Select2 defaultValue={value} className="w-full max-w-md" onValueChange={setSelectedStrategy}>
<SelectItem value="usage-based-routing">usage-based-routing</SelectItem>
<SelectItem value="latency-based-routing">latency-based-routing</SelectItem>
<SelectItem value="simple-shuffle">simple-shuffle</SelectItem>
</Select2> :
{param == "routing_strategy" ? (
<Select2
defaultValue={value}
className="w-full max-w-md"
onValueChange={setSelectedStrategy}
>
<SelectItem value="usage-based-routing">
usage-based-routing
</SelectItem>
<SelectItem value="latency-based-routing">
latency-based-routing
</SelectItem>
<SelectItem value="simple-shuffle">
simple-shuffle
</SelectItem>
</Select2>
) : (
<TextInput
name={param}
defaultValue={
typeof value === 'object' ? JSON.stringify(value, null, 2) : value.toString()
typeof value === "object"
? JSON.stringify(value, null, 2)
: value.toString()
}
/>
}
)}
</TableCell>
</TableRow>
))}
@ -323,15 +508,21 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
<AccordionHero
selectedStrategy={selectedStrategy}
strategyArgs={
routerSettings && routerSettings['routing_strategy_args'] && Object.keys(routerSettings['routing_strategy_args']).length > 0
? routerSettings['routing_strategy_args']
routerSettings &&
routerSettings["routing_strategy_args"] &&
Object.keys(routerSettings["routing_strategy_args"])
.length > 0
? routerSettings["routing_strategy_args"]
: defaultLowestLatencyArgs // default value when keys length is 0
}
paramExplanation={paramExplanation}
/>
</Card>
<Col>
<Button className="mt-2" onClick={() => handleSaveChanges(routerSettings)}>
<Button
className="mt-2"
onClick={() => handleSaveChanges(routerSettings)}
>
Save Changes
</Button>
</Col>
@ -347,15 +538,21 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
</TableHead>
<TableBody>
{
routerSettings["fallbacks"] &&
routerSettings["fallbacks"].map((item: Object, index: number) =>
{routerSettings["fallbacks"] &&
routerSettings["fallbacks"].map(
(item: Object, index: number) =>
Object.entries(item).map(([key, value]) => (
<TableRow key={index.toString() + key}>
<TableCell>{key}</TableCell>
<TableCell>{Array.isArray(value) ? value.join(', ') : value}</TableCell>
<TableCell>
<Button onClick={() => testFallbackModelResponse(key, accessToken)}>
{Array.isArray(value) ? value.join(", ") : value}
</TableCell>
<TableCell>
<Button
onClick={() =>
testFallbackModelResponse(key, accessToken)
}
>
Test Fallback
</Button>
</TableCell>
@ -368,11 +565,96 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
</TableCell>
</TableRow>
))
)
}
)}
</TableBody>
</Table>
<AddFallbacks models={modelData?.data ? modelData.data.map((data: any) => data.model_name) : []} accessToken={accessToken} routerSettings={routerSettings} setRouterSettings={setRouterSettings}/>
<AddFallbacks
models={
modelData?.data
? modelData.data.map((data: any) => data.model_name)
: []
}
accessToken={accessToken}
routerSettings={routerSettings}
setRouterSettings={setRouterSettings}
/>
</TabPanel>
<TabPanel>
<Card>
<Table>
<TableHead>
<TableRow>
<TableHeaderCell>Setting</TableHeaderCell>
<TableHeaderCell>Value</TableHeaderCell>
<TableHeaderCell>Status</TableHeaderCell>
<TableHeaderCell>Action</TableHeaderCell>
</TableRow>
</TableHead>
<TableBody>
{generalSettings.map((value, index) => (
<TableRow key={index}>
<TableCell>
<Text>{value.field_name}</Text>
<p
style={{
fontSize: "0.65rem",
color: "#808080",
fontStyle: "italic",
}}
className="mt-1"
>
{value.field_description}
</p>
</TableCell>
<TableCell>
{value.field_type == "Integer" ? (
<InputNumber
step={1}
value={value.field_value}
onChange={(newValue) =>
handleInputChange(value.field_name, newValue)
} // Handle value change
/>
) : null}
</TableCell>
<TableCell>
{value.stored_in_db == true ? (
<Badge icon={CheckCircleIcon} className="text-white">
In DB
</Badge>
) : value.stored_in_db == false ? (
<Badge className="text-gray bg-white outline">
In Config
</Badge>
) : (
<Badge className="text-gray bg-white outline">
Not Set
</Badge>
)}
</TableCell>
<TableCell>
<Button
onClick={() =>
handleUpdateField(value.field_name, index)
}
>
Update
</Button>
<Icon
icon={TrashIcon}
color="red"
onClick={() =>
handleResetField(value.field_name, index)
}
>
Reset
</Icon>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</Card>
</TabPanel>
</TabPanels>
</TabGroup>

View file

@ -121,6 +121,7 @@ const handleSubmit = async (formValues: Record<string, any>, accessToken: string
// Iterate through the key-value pairs in formValues
litellmParamsObj["model"] = litellm_model
let modelName: string = "";
console.log("formValues add deployment:", formValues);
for (const [key, value] of Object.entries(formValues)) {
if (value === '') {
continue;
@ -628,6 +629,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
let input_cost = "Undefined";
let output_cost = "Undefined";
let max_tokens = "Undefined";
let max_input_tokens = "Undefined";
let cleanedLitellmParams = {};
const getProviderFromModel = (model: string) => {
@ -664,6 +666,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
input_cost = model_info?.input_cost_per_token;
output_cost = model_info?.output_cost_per_token;
max_tokens = model_info?.max_tokens;
max_input_tokens = model_info?.max_input_tokens;
}
if (curr_model?.litellm_params) {
@ -677,7 +680,19 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
modelData.data[i].provider = provider;
modelData.data[i].input_cost = input_cost;
modelData.data[i].output_cost = output_cost;
// Convert Cost in terms of Cost per 1M tokens
if (modelData.data[i].input_cost) {
modelData.data[i].input_cost = (Number(modelData.data[i].input_cost) * 1000000).toFixed(2);
}
if (modelData.data[i].output_cost) {
modelData.data[i].output_cost = (Number(modelData.data[i].output_cost) * 1000000).toFixed(2);
}
modelData.data[i].max_tokens = max_tokens;
modelData.data[i].max_input_tokens = max_input_tokens;
modelData.data[i].api_base = curr_model?.litellm_params?.api_base;
modelData.data[i].cleanedLitellmParams = cleanedLitellmParams;
@ -893,8 +908,9 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
<Text>Filter by Public Model Name</Text>
<Select
className="mb-4 mt-2 ml-2 w-50"
defaultValue="all"
defaultValue={selectedModelGroup? selectedModelGroup : availableModelGroups[0]}
onValueChange={(value) => setSelectedModelGroup(value === "all" ? "all" : value)}
value={selectedModelGroup ? selectedModelGroup : availableModelGroups[0]}
>
<SelectItem
value={"all"}
@ -913,85 +929,76 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
</Select>
</div>
<Card>
<Table className="mt-5">
<Table className="mt-5" style={{ maxWidth: '1500px', width: '100%' }}>
<TableHead>
<TableRow>
<TableHeaderCell>Public Model Name </TableHeaderCell>
<TableHeaderCell>
Provider
</TableHeaderCell>
{
userRole === "Admin" && (
<TableHeaderCell>
API Base
</TableHeaderCell>
)
}
<TableHeaderCell>
Extra litellm Params
</TableHeaderCell>
<TableHeaderCell>Input Price per token ($)</TableHeaderCell>
<TableHeaderCell>Output Price per token ($)</TableHeaderCell>
<TableHeaderCell>Max Tokens</TableHeaderCell>
<TableHeaderCell>Status</TableHeaderCell>
<TableHeaderCell style={{ maxWidth: '150px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Public Model Name</TableHeaderCell>
<TableHeaderCell style={{ maxWidth: '100px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Provider</TableHeaderCell>
{userRole === "Admin" && (
<TableHeaderCell style={{ maxWidth: '150px', whiteSpace: 'normal', wordBreak: 'break-word' }}>API Base</TableHeaderCell>
)}
<TableHeaderCell style={{ maxWidth: '200px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Extra litellm Params</TableHeaderCell>
<TableHeaderCell style={{ maxWidth: '85px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Input Price <p style={{ fontSize: '10px', color: 'gray' }}>/1M Tokens ($)</p></TableHeaderCell>
<TableHeaderCell style={{ maxWidth: '85px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Output Price <p style={{ fontSize: '10px', color: 'gray' }}>/1M Tokens ($)</p></TableHeaderCell>
<TableHeaderCell style={{ maxWidth: '120px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Max Tokens</TableHeaderCell>
<TableHeaderCell style={{ maxWidth: '50px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Status</TableHeaderCell>
</TableRow>
</TableHead>
<TableBody>
{modelData.data
.filter((model: any) =>
selectedModelGroup === "all" || model.model_name === selectedModelGroup || selectedModelGroup === null || selectedModelGroup === undefined || selectedModelGroup === ""
selectedModelGroup === "all" ||
model.model_name === selectedModelGroup ||
selectedModelGroup === null ||
selectedModelGroup === undefined ||
selectedModelGroup === ""
)
.map((model: any, index: number) => (
<TableRow key={index}>
<TableCell>
<TableCell style={{ maxWidth: '150px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
<Text>{model.model_name}</Text>
</TableCell>
<TableCell>{model.provider}</TableCell>
{
userRole === "Admin" && (
<TableCell>{model.api_base}</TableCell>
)
}
<TableCell>
<TableCell style={{ maxWidth: '100px', whiteSpace: 'normal', wordBreak: 'break-word' }}>{model.provider}</TableCell>
{userRole === "Admin" && (
<TableCell style={{ maxWidth: '150px', whiteSpace: 'normal', wordBreak: 'break-word' }}>{model.api_base}</TableCell>
)}
<TableCell style={{ maxWidth: '200px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
<Accordion>
<AccordionHeader>
<Text>Litellm params</Text>
</AccordionHeader>
<AccordionBody>
<pre>
{JSON.stringify(model.cleanedLitellmParams, null, 2)}
</pre>
<pre>{JSON.stringify(model.cleanedLitellmParams, null, 2)}</pre>
</AccordionBody>
</Accordion>
</TableCell>
<TableCell>{model.input_cost || model.litellm_params.input_cost_per_token || null}</TableCell>
<TableCell>{model.output_cost || model.litellm_params.output_cost_per_token || null}</TableCell>
<TableCell>{model.max_tokens}</TableCell>
<TableCell>
{
model.model_info.db_model ? <Badge icon={CheckCircleIcon} className="text-white">DB Model</Badge> : <Badge icon={XCircleIcon} className="text-black">Config Model</Badge>
}
<TableCell style={{ maxWidth: '80px', whiteSpace: 'normal', wordBreak: 'break-word' }}>{model.input_cost || model.litellm_params.input_cost_per_token || null}</TableCell>
<TableCell style={{ maxWidth: '80px', whiteSpace: 'normal', wordBreak: 'break-word' }}>{model.output_cost || model.litellm_params.output_cost_per_token || null}</TableCell>
<TableCell style={{ maxWidth: '120px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
<p style={{ fontSize: '10px' }}>
Max Tokens: {model.max_tokens} <br></br>
Max Input Tokens: {model.max_input_tokens}
</p>
</TableCell>
<TableCell>
<Icon
icon={PencilAltIcon}
size="sm"
onClick={() => handleEditClick(model)}
/>
<TableCell style={{ maxWidth: '100px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
{model.model_info.db_model ? (
<Badge icon={CheckCircleIcon} size="xs" className="text-white">
<p style={{ fontSize: '10px' }}>DB Model</p>
</Badge>
) : (
<Badge icon={XCircleIcon} size="xs" className="text-black">
<p style={{ fontSize: '10px' }}>Config Model</p>
</Badge>
)}
</TableCell>
<TableCell style={{ maxWidth: '100px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
<Icon icon={PencilAltIcon} size="sm" onClick={() => handleEditClick(model)} />
<DeleteModelButton modelID={model.model_info.id} accessToken={accessToken} />
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</Card>
</Grid>
@ -1116,13 +1123,22 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
</Form.Item>
}
{
selectedProvider == Providers.Azure && <Form.Item
selectedProvider == Providers.Azure &&
<div>
<Form.Item
label="Base Model"
name="base_model"
className="mb-0"
>
<TextInput placeholder="azure/gpt-3.5-turbo"/>
<Text>The actual model your azure deployment uses. Used for accurate cost tracking. Select name from <Link href="https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" target="_blank">here</Link></Text>
</Form.Item>
<Row>
<Col span={10}></Col>
<Col span={10}><Text className="mb-2">The actual model your azure deployment uses. Used for accurate cost tracking. Select name from <Link href="https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" target="_blank">here</Link></Text></Col>
</Row>
</div>
}
{
selectedProvider == Providers.Bedrock && <Form.Item

Some files were not shown because too many files have changed in this diff