mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Merge branch 'BerriAI:main' into main
This commit is contained in:
commit
b89b3d8c44
102 changed files with 8852 additions and 6557 deletions
|
@ -42,7 +42,7 @@ jobs:
|
|||
pip install lunary==0.2.5
|
||||
pip install "langfuse==2.27.1"
|
||||
pip install numpydoc
|
||||
pip install traceloop-sdk==0.0.69
|
||||
pip install traceloop-sdk==0.18.2
|
||||
pip install openai
|
||||
pip install prisma
|
||||
pip install "httpx==0.24.1"
|
||||
|
|
|
@ -34,7 +34,7 @@ LiteLLM manages:
|
|||
[**Jump to OpenAI Proxy Docs**](https://github.com/BerriAI/litellm?tab=readme-ov-file#openai-proxy---docs) <br>
|
||||
[**Jump to Supported LLM Providers**](https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-providers-docs)
|
||||
|
||||
🚨 **Stable Release:** Use docker images with: `main-stable` tag. These run through 12 hr load tests (1k req./min).
|
||||
🚨 **Stable Release:** Use docker images with the `-stable` tag. These have undergone 12 hour load tests, before being published.
|
||||
|
||||
Support for more providers. Missing a provider or LLM Platform, raise a [feature request](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+).
|
||||
|
||||
|
|
|
@ -151,3 +151,19 @@ response = image_generation(
|
|||
)
|
||||
print(f"response: {response}")
|
||||
```
|
||||
|
||||
## VertexAI - Image Generation Models
|
||||
|
||||
### Usage
|
||||
|
||||
Use this for image generation models on VertexAI
|
||||
|
||||
```python
|
||||
response = litellm.image_generation(
|
||||
prompt="An olympic size swimming pool",
|
||||
model="vertex_ai/imagegeneration@006",
|
||||
vertex_ai_project="adroit-crow-413218",
|
||||
vertex_ai_location="us-central1",
|
||||
)
|
||||
print(f"response: {response}")
|
||||
```
|
173
docs/my-website/docs/observability/lago.md
Normal file
173
docs/my-website/docs/observability/lago.md
Normal file
|
@ -0,0 +1,173 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Lago - Usage Based Billing
|
||||
|
||||
[Lago](https://www.getlago.com/) offers a self-hosted and cloud, metering and usage-based billing solution.
|
||||
|
||||
<Image img={require('../../img/lago.jpeg')} />
|
||||
|
||||
## Quick Start
|
||||
Use just 1 lines of code, to instantly log your responses **across all providers** with Lago
|
||||
|
||||
Get your Lago [API Key](https://docs.getlago.com/guide/self-hosted/docker#find-your-api-key)
|
||||
|
||||
```python
|
||||
litellm.callbacks = ["lago"] # logs cost + usage of successful calls to lago
|
||||
```
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
# pip install lago
|
||||
import litellm
|
||||
import os
|
||||
|
||||
os.environ["LAGO_API_BASE"] = "" # http://0.0.0.0:3000
|
||||
os.environ["LAGO_API_KEY"] = ""
|
||||
os.environ["LAGO_API_EVENT_CODE"] = "" # The billable metric's code - https://docs.getlago.com/guide/events/ingesting-usage#define-a-billable-metric
|
||||
|
||||
# LLM API Keys
|
||||
os.environ['OPENAI_API_KEY']=""
|
||||
|
||||
# set lago as a callback, litellm will send the data to lago
|
||||
litellm.success_callback = ["lago"]
|
||||
|
||||
# openai call
|
||||
response = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hi 👋 - i'm openai"}
|
||||
],
|
||||
user="your_customer_id" # 👈 SET YOUR CUSTOMER ID HERE
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add to Config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- litellm_params:
|
||||
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
|
||||
api_key: my-fake-key
|
||||
model: openai/my-fake-model
|
||||
model_name: fake-openai-endpoint
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["lago"] # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
|
||||
```
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "fake-openai-endpoint",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
"user": "your-customer-id" # 👈 SET YOUR CUSTOMER ID
|
||||
}
|
||||
'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai_python" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
], user="my_customer_id") # 👈 whatever your customer id is
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="langchain" label="Langchain">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "anything"
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:4000",
|
||||
model = "gpt-3.5-turbo",
|
||||
temperature=0.1,
|
||||
extra_body={
|
||||
"user": "my_customer_id" # 👈 whatever your customer id is
|
||||
}
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
response = chat(messages)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
<Image img={require('../../img/lago_2.png')} />
|
||||
|
||||
## Advanced - Lagos Logging object
|
||||
|
||||
This is what LiteLLM will log to Lagos
|
||||
|
||||
```
|
||||
{
|
||||
"event": {
|
||||
"transaction_id": "<generated_unique_id>",
|
||||
"external_customer_id": <litellm_end_user_id>, # passed via `user` param in /chat/completion call - https://platform.openai.com/docs/api-reference/chat/create
|
||||
"code": os.getenv("LAGO_API_EVENT_CODE"),
|
||||
"properties": {
|
||||
"input_tokens": <number>,
|
||||
"output_tokens": <number>,
|
||||
"model": <string>,
|
||||
"response_cost": <number>, # 👈 LITELLM CALCULATED RESPONSE COST - https://github.com/BerriAI/litellm/blob/d43f75150a65f91f60dc2c0c9462ce3ffc713c1f/litellm/utils.py#L1473
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
|
@ -71,6 +71,23 @@ response = litellm.completion(
|
|||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
### Make LiteLLM Proxy use Custom `LANGSMITH_BASE_URL`
|
||||
|
||||
If you're using a custom LangSmith instance, you can set the
|
||||
`LANGSMITH_BASE_URL` environment variable to point to your instance.
|
||||
For example, you can make LiteLLM Proxy log to a local LangSmith instance with
|
||||
this config:
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
success_callback: ["langsmith"]
|
||||
|
||||
environment_variables:
|
||||
LANGSMITH_BASE_URL: "http://localhost:1984"
|
||||
LANGSMITH_PROJECT: "litellm-proxy"
|
||||
```
|
||||
|
||||
## Support & Talk to Founders
|
||||
|
||||
- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
|
||||
|
|
|
@ -20,7 +20,7 @@ Use just 2 lines of code, to instantly log your responses **across all providers
|
|||
Get your OpenMeter API Key from https://openmeter.cloud/meters
|
||||
|
||||
```python
|
||||
litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls to openmeter
|
||||
litellm.callbacks = ["openmeter"] # logs cost + usage of successful calls to openmeter
|
||||
```
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@ litellm.success_callback = ["openmeter"] # logs cost + usage of successful calls
|
|||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
# pip install langfuse
|
||||
# pip install openmeter
|
||||
import litellm
|
||||
import os
|
||||
|
||||
|
@ -39,8 +39,8 @@ os.environ["OPENMETER_API_KEY"] = ""
|
|||
# LLM API Keys
|
||||
os.environ['OPENAI_API_KEY']=""
|
||||
|
||||
# set langfuse as a callback, litellm will send the data to langfuse
|
||||
litellm.success_callback = ["openmeter"]
|
||||
# set openmeter as a callback, litellm will send the data to openmeter
|
||||
litellm.callbacks = ["openmeter"]
|
||||
|
||||
# openai call
|
||||
response = litellm.completion(
|
||||
|
@ -64,7 +64,7 @@ model_list:
|
|||
model_name: fake-openai-endpoint
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["openmeter"] # 👈 KEY CHANGE
|
||||
callbacks: ["openmeter"] # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
|
|
|
@ -223,6 +223,32 @@ assert isinstance(
|
|||
|
||||
```
|
||||
|
||||
### Setting `anthropic-beta` Header in Requests
|
||||
|
||||
Pass the the `extra_headers` param to litellm, All headers will be forwarded to Anthropic API
|
||||
|
||||
```python
|
||||
response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
)
|
||||
```
|
||||
|
||||
### Forcing Anthropic Tool Use
|
||||
|
||||
If you want Claude to use a specific tool to answer the user’s question
|
||||
|
||||
You can do this by specifying the tool in the `tool_choice` field like so:
|
||||
```python
|
||||
response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice={"type": "tool", "name": "get_weather"},
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
### Parallel Function Calling
|
||||
|
||||
|
|
|
@ -102,12 +102,18 @@ Ollama supported models: https://github.com/ollama/ollama
|
|||
| Model Name | Function Call |
|
||||
|----------------------|-----------------------------------------------------------------------------------
|
||||
| Mistral | `completion(model='ollama/mistral', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Mistral-7B-Instruct-v0.1 | `completion(model='ollama/mistral-7B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Mistral-7B-Instruct-v0.2 | `completion(model='ollama/mistral-7B-Instruct-v0.2', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Mixtral-8x7B-Instruct-v0.1 | `completion(model='ollama/mistral-8x7B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Mixtral-8x22B-Instruct-v0.1 | `completion(model='ollama/mixtral-8x22B-Instruct-v0.1', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Llama2 7B | `completion(model='ollama/llama2', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Llama2 13B | `completion(model='ollama/llama2:13b', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Llama2 70B | `completion(model='ollama/llama2:70b', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Code Llama | `completion(model='ollama/codellama', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Llama2 Uncensored | `completion(model='ollama/llama2-uncensored', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
|Meta LLaMa3 8B | `completion(model='ollama/llama3', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Meta LLaMa3 70B | `completion(model='ollama/llama3:70b', messages, api_base="http://localhost:11434", stream=False)` |
|
||||
| Orca Mini | `completion(model='ollama/orca-mini', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Vicuna | `completion(model='ollama/vicuna', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
| Nous-Hermes | `completion(model='ollama/nous-hermes', messages, api_base="http://localhost:11434", stream=True)` |
|
||||
|
|
|
@ -188,6 +188,7 @@ These also support the `OPENAI_API_BASE` environment variable, which can be used
|
|||
## OpenAI Vision Models
|
||||
| Model Name | Function Call |
|
||||
|-----------------------|-----------------------------------------------------------------|
|
||||
| gpt-4o | `response = completion(model="gpt-4o", messages=messages)` |
|
||||
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
|
||||
| gpt-4-vision-preview | `response = completion(model="gpt-4-vision-preview", messages=messages)` |
|
||||
|
||||
|
|
|
@ -508,6 +508,31 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
|
|||
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||
|
||||
## Image Generation Models
|
||||
|
||||
Usage
|
||||
|
||||
```python
|
||||
response = await litellm.aimage_generation(
|
||||
prompt="An olympic size swimming pool",
|
||||
model="vertex_ai/imagegeneration@006",
|
||||
vertex_ai_project="adroit-crow-413218",
|
||||
vertex_ai_location="us-central1",
|
||||
)
|
||||
```
|
||||
|
||||
**Generating multiple images**
|
||||
|
||||
Use the `n` parameter to pass how many images you want generated
|
||||
```python
|
||||
response = await litellm.aimage_generation(
|
||||
prompt="An olympic size swimming pool",
|
||||
model="vertex_ai/imagegeneration@006",
|
||||
vertex_ai_project="adroit-crow-413218",
|
||||
vertex_ai_location="us-central1",
|
||||
n=1,
|
||||
)
|
||||
```
|
||||
|
||||
## Extra
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# 🚨 Alerting
|
||||
# 🚨 Alerting / Webhooks
|
||||
|
||||
Get alerts for:
|
||||
|
||||
|
@ -11,7 +11,7 @@ Get alerts for:
|
|||
- Daily Reports:
|
||||
- **LLM** Top 5 slowest deployments
|
||||
- **LLM** Top 5 deployments with most failed requests
|
||||
- **Spend** Weekly & Monthly spend per Team, Tag
|
||||
- **Spend** Weekly & Monthly spend per Team, Tag
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
@ -61,10 +61,38 @@ curl -X GET 'http://localhost:4000/health/services?service=slack' \
|
|||
-H 'Authorization: Bearer sk-1234'
|
||||
```
|
||||
|
||||
## Advanced - Opting into specific alert types
|
||||
|
||||
## Extras
|
||||
Set `alert_types` if you want to Opt into only specific alert types
|
||||
|
||||
### Using Discord Webhooks
|
||||
```shell
|
||||
general_settings:
|
||||
alerting: ["slack"]
|
||||
alert_types: ["spend_reports"]
|
||||
```
|
||||
|
||||
All Possible Alert Types
|
||||
|
||||
```python
|
||||
alert_types:
|
||||
Optional[
|
||||
List[
|
||||
Literal[
|
||||
"llm_exceptions",
|
||||
"llm_too_slow",
|
||||
"llm_requests_hanging",
|
||||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
"spend_reports",
|
||||
"cooldown_deployment",
|
||||
"new_model_added",
|
||||
]
|
||||
]
|
||||
```
|
||||
|
||||
|
||||
## Advanced - Using Discord Webhooks
|
||||
|
||||
Discord provides a slack compatible webhook url that you can use for alerting
|
||||
|
||||
|
@ -96,3 +124,80 @@ environment_variables:
|
|||
```
|
||||
|
||||
That's it ! You're ready to go !
|
||||
|
||||
## Advanced - [BETA] Webhooks for Budget Alerts
|
||||
|
||||
**Note**: This is a beta feature, so the spec might change.
|
||||
|
||||
Set a webhook to get notified for budget alerts.
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
Add url to your environment, for testing you can use a link from [here](https://webhook.site/)
|
||||
|
||||
```bash
|
||||
export WEBHOOK_URL="https://webhook.site/6ab090e8-c55f-4a23-b075-3209f5c57906"
|
||||
```
|
||||
|
||||
Add 'webhook' to config.yaml
|
||||
```yaml
|
||||
general_settings:
|
||||
alerting: ["webhook"] # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
2. Start proxy
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl -X GET --location 'http://0.0.0.0:4000/health/services?service=webhook' \
|
||||
--header 'Authorization: Bearer sk-1234'
|
||||
```
|
||||
|
||||
**Expected Response**
|
||||
|
||||
```bash
|
||||
{
|
||||
"spend": 1, # the spend for the 'event_group'
|
||||
"max_budget": 0, # the 'max_budget' set for the 'event_group'
|
||||
"token": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
|
||||
"user_id": "default_user_id",
|
||||
"team_id": null,
|
||||
"user_email": null,
|
||||
"key_alias": null,
|
||||
"projected_exceeded_data": null,
|
||||
"projected_spend": null,
|
||||
"event": "budget_crossed", # Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
|
||||
"event_group": "user",
|
||||
"event_message": "User Budget: Budget Crossed"
|
||||
}
|
||||
```
|
||||
|
||||
**API Spec for Webhook Event**
|
||||
|
||||
- `spend` *float*: The current spend amount for the 'event_group'.
|
||||
- `max_budget` *float*: The maximum allowed budget for the 'event_group'.
|
||||
- `token` *str*: A hashed value of the key, used for authentication or identification purposes.
|
||||
- `user_id` *str or null*: The ID of the user associated with the event (optional).
|
||||
- `team_id` *str or null*: The ID of the team associated with the event (optional).
|
||||
- `user_email` *str or null*: The email of the user associated with the event (optional).
|
||||
- `key_alias` *str or null*: An alias for the key associated with the event (optional).
|
||||
- `projected_exceeded_date` *str or null*: The date when the budget is projected to be exceeded, returned when 'soft_budget' is set for key (optional).
|
||||
- `projected_spend` *float or null*: The projected spend amount, returned when 'soft_budget' is set for key (optional).
|
||||
- `event` *Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]*: The type of event that triggered the webhook. Possible values are:
|
||||
* "budget_crossed": Indicates that the spend has exceeded the max budget.
|
||||
* "threshold_crossed": Indicates that spend has crossed a threshold (currently sent when 85% and 95% of budget is reached).
|
||||
* "projected_limit_exceeded": For "key" only - Indicates that the projected spend is expected to exceed the soft budget threshold.
|
||||
- `event_group` *Literal["user", "key", "team", "proxy"]*: The group associated with the event. Possible values are:
|
||||
* "user": The event is related to a specific user.
|
||||
* "key": The event is related to a specific key.
|
||||
* "team": The event is related to a team.
|
||||
* "proxy": The event is related to a proxy.
|
||||
|
||||
- `event_message` *str*: A human-readable description of the event.
|
319
docs/my-website/docs/proxy/billing.md
Normal file
319
docs/my-website/docs/proxy/billing.md
Normal file
|
@ -0,0 +1,319 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 💵 Billing
|
||||
|
||||
Bill internal teams, external customers for their usage
|
||||
|
||||
**🚨 Requirements**
|
||||
- [Setup Lago](https://docs.getlago.com/guide/self-hosted/docker#run-the-app), for usage-based billing. We recommend following [their Stripe tutorial](https://docs.getlago.com/templates/per-transaction/stripe#step-1-create-billable-metrics-for-transaction)
|
||||
|
||||
Steps:
|
||||
- Connect the proxy to Lago
|
||||
- Set the id you want to bill for (customers, internal users, teams)
|
||||
- Start!
|
||||
|
||||
## Quick Start
|
||||
|
||||
Bill internal teams for their usage
|
||||
|
||||
### 1. Connect proxy to Lago
|
||||
|
||||
Set 'lago' as a callback on your proxy config.yaml
|
||||
|
||||
```yaml
|
||||
model_name:
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["lago"] # 👈 KEY CHANGE
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
```
|
||||
|
||||
Add your Lago keys to the environment
|
||||
|
||||
```bash
|
||||
export LAGO_API_BASE="http://localhost:3000" # self-host - https://docs.getlago.com/guide/self-hosted/docker#run-the-app
|
||||
export LAGO_API_KEY="3e29d607-de54-49aa-a019-ecf585729070" # Get key - https://docs.getlago.com/guide/self-hosted/docker#find-your-api-key
|
||||
export LAGO_API_EVENT_CODE="openai_tokens" # name of lago billing code
|
||||
export LAGO_API_CHARGE_BY="team_id" # 👈 Charges 'team_id' attached to proxy key
|
||||
```
|
||||
|
||||
Start proxy
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
### 2. Create Key for Internal Team
|
||||
|
||||
```bash
|
||||
curl 'http://0.0.0.0:4000/key/generate' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{"team_id": "my-unique-id"}' # 👈 Internal Team's ID
|
||||
```
|
||||
|
||||
Response Object:
|
||||
|
||||
```bash
|
||||
{
|
||||
"key": "sk-tXL0wt5-lOOVK9sfY2UacA",
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
### 3. Start billing!
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-tXL0wt5-lOOVK9sfY2UacA' \ # 👈 Team's Key
|
||||
--data ' {
|
||||
"model": "fake-openai-endpoint",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}
|
||||
'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai_python" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-tXL0wt5-lOOVK9sfY2UacA", # 👈 Team's Key
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="langchain" label="Langchain">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-tXL0wt5-lOOVK9sfY2UacA" # 👈 Team's Key
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:4000",
|
||||
model = "gpt-3.5-turbo",
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
response = chat(messages)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**See Results on Lago**
|
||||
|
||||
|
||||
<Image img={require('../../img/lago_2.png')} style={{ width: '500px', height: 'auto' }} />
|
||||
|
||||
## Advanced - Lago Logging object
|
||||
|
||||
This is what LiteLLM will log to Lagos
|
||||
|
||||
```
|
||||
{
|
||||
"event": {
|
||||
"transaction_id": "<generated_unique_id>",
|
||||
"external_customer_id": <selected_id>, # either 'end_user_id', 'user_id', or 'team_id'. Default 'end_user_id'.
|
||||
"code": os.getenv("LAGO_API_EVENT_CODE"),
|
||||
"properties": {
|
||||
"input_tokens": <number>,
|
||||
"output_tokens": <number>,
|
||||
"model": <string>,
|
||||
"response_cost": <number>, # 👈 LITELLM CALCULATED RESPONSE COST - https://github.com/BerriAI/litellm/blob/d43f75150a65f91f60dc2c0c9462ce3ffc713c1f/litellm/utils.py#L1473
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced - Bill Customers, Internal Users
|
||||
|
||||
For:
|
||||
- Customers (id passed via 'user' param in /chat/completion call) = 'end_user_id'
|
||||
- Internal Users (id set when [creating keys](https://docs.litellm.ai/docs/proxy/virtual_keys#advanced---spend-tracking)) = 'user_id'
|
||||
- Teams (id set when [creating keys](https://docs.litellm.ai/docs/proxy/virtual_keys#advanced---spend-tracking)) = 'team_id'
|
||||
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="customers" label="Customer Billing">
|
||||
|
||||
1. Set 'LAGO_API_CHARGE_BY' to 'end_user_id'
|
||||
|
||||
```bash
|
||||
export LAGO_API_CHARGE_BY="end_user_id"
|
||||
```
|
||||
|
||||
2. Test it!
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
"user": "my_customer_id" # 👈 whatever your customer id is
|
||||
}
|
||||
'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai_sdk" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="anything",
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
], user="my_customer_id") # 👈 whatever your customer id is
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="langchain" label="Langchain">
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "anything"
|
||||
|
||||
chat = ChatOpenAI(
|
||||
openai_api_base="http://0.0.0.0:4000",
|
||||
model = "gpt-3.5-turbo",
|
||||
temperature=0.1,
|
||||
extra_body={
|
||||
"user": "my_customer_id" # 👈 whatever your customer id is
|
||||
}
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content="You are a helpful assistant that im using to make a test request to."
|
||||
),
|
||||
HumanMessage(
|
||||
content="test from litellm. tell me why it's amazing in 1 sentence"
|
||||
),
|
||||
]
|
||||
response = chat(messages)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="users" label="Internal User Billing">
|
||||
|
||||
1. Set 'LAGO_API_CHARGE_BY' to 'user_id'
|
||||
|
||||
```bash
|
||||
export LAGO_API_CHARGE_BY="user_id"
|
||||
```
|
||||
|
||||
2. Create a key for that user
|
||||
|
||||
```bash
|
||||
curl 'http://0.0.0.0:4000/key/generate' \
|
||||
--header 'Authorization: Bearer <your-master-key>' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{"user_id": "my-unique-id"}' # 👈 Internal User's id
|
||||
```
|
||||
|
||||
Response Object:
|
||||
|
||||
```bash
|
||||
{
|
||||
"key": "sk-tXL0wt5-lOOVK9sfY2UacA",
|
||||
}
|
||||
```
|
||||
|
||||
3. Make API Calls with that Key
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-tXL0wt5-lOOVK9sfY2UacA", # 👈 Generated key
|
||||
base_url="http://0.0.0.0:4000"
|
||||
)
|
||||
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
|
@ -25,26 +25,45 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
|
|||
def __init__(self):
|
||||
pass
|
||||
|
||||
#### ASYNC ####
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
|
||||
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
|
||||
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
]) -> Optional[dict, str, Exception]:
|
||||
data["model"] = "my-new-model"
|
||||
return data
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
|
||||
):
|
||||
pass
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
pass
|
||||
|
||||
async def async_moderation_hook( # call made in parallel to llm api call
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
pass
|
||||
|
||||
async def async_post_call_streaming_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: str,
|
||||
):
|
||||
pass
|
||||
proxy_handler_instance = MyCustomHandler()
|
||||
```
|
||||
|
||||
|
@ -191,3 +210,99 @@ general_settings:
|
|||
**Result**
|
||||
|
||||
<Image img={require('../../img/end_user_enforcement.png')}/>
|
||||
|
||||
## Advanced - Return rejected message as response
|
||||
|
||||
For chat completions and text completion calls, you can return a rejected message as a user response.
|
||||
|
||||
Do this by returning a string. LiteLLM takes care of returning the response in the correct format depending on the endpoint and if it's streaming/non-streaming.
|
||||
|
||||
For non-chat/text completion endpoints, this response is returned as a 400 status code exception.
|
||||
|
||||
|
||||
### 1. Create Custom Handler
|
||||
|
||||
```python
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import litellm
|
||||
from litellm.utils import get_formatted_prompt
|
||||
|
||||
# This file includes the custom callbacks for LiteLLM Proxy
|
||||
# Once defined, these can be passed in proxy_config.yaml
|
||||
class MyCustomHandler(CustomLogger):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
|
||||
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
]) -> Optional[dict, str, Exception]:
|
||||
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)
|
||||
|
||||
if "Hello world" in formatted_prompt:
|
||||
return "This is an invalid response"
|
||||
|
||||
return data
|
||||
|
||||
proxy_handler_instance = MyCustomHandler()
|
||||
```
|
||||
|
||||
### 2. Update config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
|
||||
litellm_settings:
|
||||
callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
||||
```
|
||||
|
||||
|
||||
### 3. Test it!
|
||||
|
||||
```shell
|
||||
$ litellm /path/to/config.yaml
|
||||
```
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello world"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
**Expected Response**
|
||||
|
||||
```
|
||||
{
|
||||
"id": "chatcmpl-d00bbede-2d90-4618-bf7b-11a1c23cf360",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "This is an invalid response.", # 👈 REJECTED RESPONSE
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1716234198,
|
||||
"model": null,
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": null,
|
||||
"usage": {}
|
||||
}
|
||||
```
|
|
@ -5,6 +5,8 @@
|
|||
- debug (prints info logs)
|
||||
- detailed debug (prints debug logs)
|
||||
|
||||
The proxy also supports json logs. [See here](#json-logs)
|
||||
|
||||
## `debug`
|
||||
|
||||
**via cli**
|
||||
|
@ -32,3 +34,19 @@ $ litellm --detailed_debug
|
|||
```python
|
||||
os.environ["LITELLM_LOG"] = "DEBUG"
|
||||
```
|
||||
|
||||
## JSON LOGS
|
||||
|
||||
Set `JSON_LOGS="True"` in your env:
|
||||
|
||||
```bash
|
||||
export JSON_LOGS="True"
|
||||
```
|
||||
|
||||
Start proxy
|
||||
|
||||
```bash
|
||||
$ litellm
|
||||
```
|
||||
|
||||
The proxy will now all logs in json format.
|
|
@ -1,7 +1,8 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# ✨ Enterprise Features - Content Mod, SSO
|
||||
# ✨ Enterprise Features - Content Mod, SSO, Custom Swagger
|
||||
|
||||
Features here are behind a commercial license in our `/enterprise` folder. [**See Code**](https://github.com/BerriAI/litellm/tree/main/enterprise)
|
||||
|
||||
|
@ -20,6 +21,7 @@ Features:
|
|||
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
|
||||
- ✅ Don't log/store specific requests to Langfuse, Sentry, etc. (eg confidential LLM requests)
|
||||
- ✅ Tracking Spend for Custom Tags
|
||||
- ✅ Custom Branding + Routes on Swagger Docs
|
||||
|
||||
|
||||
|
||||
|
@ -527,3 +529,38 @@ curl -X GET "http://0.0.0.0:4000/spend/tags" \
|
|||
<!-- ## Tracking Spend per Key
|
||||
|
||||
## Tracking Spend per User -->
|
||||
|
||||
## Swagger Docs - Custom Routes + Branding
|
||||
|
||||
:::info
|
||||
|
||||
Requires a LiteLLM Enterprise key to use. Request one [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
:::
|
||||
|
||||
Set LiteLLM Key in your environment
|
||||
|
||||
```bash
|
||||
LITELLM_LICENSE=""
|
||||
```
|
||||
|
||||
### Customize Title + Description
|
||||
|
||||
In your environment, set:
|
||||
|
||||
```bash
|
||||
DOCS_TITLE="TotalGPT"
|
||||
DOCS_DESCRIPTION="Sample Company Description"
|
||||
```
|
||||
|
||||
### Customize Routes
|
||||
|
||||
Hide admin routes from users.
|
||||
|
||||
In your environment, set:
|
||||
|
||||
```bash
|
||||
DOCS_FILTERED="True" # only shows openai routes to user
|
||||
```
|
||||
|
||||
<Image img={require('../../img/custom_swagger.png')} style={{ width: '900px', height: 'auto' }} />
|
BIN
docs/my-website/img/custom_swagger.png
Normal file
BIN
docs/my-website/img/custom_swagger.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 223 KiB |
BIN
docs/my-website/img/lago.jpeg
Normal file
BIN
docs/my-website/img/lago.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 344 KiB |
BIN
docs/my-website/img/lago_2.png
Normal file
BIN
docs/my-website/img/lago_2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 219 KiB |
|
@ -41,6 +41,7 @@ const sidebars = {
|
|||
"proxy/reliability",
|
||||
"proxy/cost_tracking",
|
||||
"proxy/users",
|
||||
"proxy/billing",
|
||||
"proxy/user_keys",
|
||||
"proxy/enterprise",
|
||||
"proxy/virtual_keys",
|
||||
|
@ -175,6 +176,7 @@ const sidebars = {
|
|||
"observability/custom_callback",
|
||||
"observability/langfuse_integration",
|
||||
"observability/sentry",
|
||||
"observability/lago",
|
||||
"observability/openmeter",
|
||||
"observability/promptlayer_integration",
|
||||
"observability/wandb_integration",
|
||||
|
|
|
@ -27,8 +27,8 @@ input_callback: List[Union[str, Callable]] = []
|
|||
success_callback: List[Union[str, Callable]] = []
|
||||
failure_callback: List[Union[str, Callable]] = []
|
||||
service_callback: List[Union[str, Callable]] = []
|
||||
callbacks: List[Callable] = []
|
||||
_custom_logger_compatible_callbacks: list = ["openmeter"]
|
||||
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter"]
|
||||
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
|
||||
_langfuse_default_tags: Optional[
|
||||
List[
|
||||
Literal[
|
||||
|
@ -724,6 +724,9 @@ from .utils import (
|
|||
get_supported_openai_params,
|
||||
get_api_base,
|
||||
get_first_chars_messages,
|
||||
ModelResponse,
|
||||
ImageResponse,
|
||||
ImageObject,
|
||||
)
|
||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||
from .llms.anthropic import AnthropicConfig
|
||||
|
|
|
@ -1,19 +1,33 @@
|
|||
import logging
|
||||
import logging, os, json
|
||||
from logging import Formatter
|
||||
|
||||
set_verbose = False
|
||||
json_logs = False
|
||||
json_logs = bool(os.getenv("JSON_LOGS", False))
|
||||
# Create a handler for the logger (you may need to adapt this based on your needs)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class JsonFormatter(Formatter):
|
||||
def __init__(self):
|
||||
super(JsonFormatter, self).__init__()
|
||||
|
||||
def format(self, record):
|
||||
json_record = {}
|
||||
json_record["message"] = record.getMessage()
|
||||
return json.dumps(json_record)
|
||||
|
||||
|
||||
# Create a formatter and set it for the handler
|
||||
formatter = logging.Formatter(
|
||||
"\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
if json_logs:
|
||||
handler.setFormatter(JsonFormatter())
|
||||
else:
|
||||
formatter = logging.Formatter(
|
||||
"\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
|
||||
verbose_router_logger = logging.getLogger("LiteLLM Router")
|
||||
|
|
|
@ -15,11 +15,19 @@ from typing import Optional
|
|||
|
||||
|
||||
class AuthenticationError(openai.AuthenticationError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 401
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
@ -27,11 +35,19 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
|
|||
|
||||
# raise when invalid models passed, example gpt-8
|
||||
class NotFoundError(openai.NotFoundError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 404
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
@ -39,12 +55,18 @@ class NotFoundError(openai.NotFoundError): # type: ignore
|
|||
|
||||
class BadRequestError(openai.BadRequestError): # type: ignore
|
||||
def __init__(
|
||||
self, message, model, llm_provider, response: Optional[httpx.Response] = None
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
response = response or httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
|
@ -57,18 +79,28 @@ class BadRequestError(openai.BadRequestError): # type: ignore
|
|||
|
||||
|
||||
class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 422
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class Timeout(openai.APITimeoutError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
def __init__(
|
||||
self, message, model, llm_provider, litellm_debug_info: Optional[str] = None
|
||||
):
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
super().__init__(
|
||||
request=request
|
||||
|
@ -77,6 +109,7 @@ class Timeout(openai.APITimeoutError): # type: ignore
|
|||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
|
||||
# custom function to convert to str
|
||||
def __str__(self):
|
||||
|
@ -84,22 +117,38 @@ class Timeout(openai.APITimeoutError): # type: ignore
|
|||
|
||||
|
||||
class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 403
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class RateLimitError(openai.RateLimitError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 429
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.modle = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
@ -107,11 +156,45 @@ class RateLimitError(openai.RateLimitError): # type: ignore
|
|||
|
||||
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
|
||||
class ContextWindowExceededError(BadRequestError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
llm_provider=self.llm_provider, # type: ignore
|
||||
response=response,
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
|
||||
class RejectedRequestError(BadRequestError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
request_data: dict,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.request_data = request_data
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
response = httpx.Response(status_code=500, request=request)
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
|
@ -122,11 +205,19 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
|
|||
|
||||
class ContentPolicyViolationError(BadRequestError): # type: ignore
|
||||
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
|
@ -136,11 +227,19 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
|
|||
|
||||
|
||||
class ServiceUnavailableError(openai.APIStatusError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 503
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
@ -149,33 +248,51 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
|
|||
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
|
||||
class APIError(openai.APIError): # type: ignore
|
||||
def __init__(
|
||||
self, status_code, message, llm_provider, model, request: httpx.Request
|
||||
self,
|
||||
status_code,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
request: httpx.Request,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(self.message, request=request, body=None) # type: ignore
|
||||
|
||||
|
||||
# raised if an invalid request (not get, delete, put, post) is made
|
||||
class APIConnectionError(openai.APIConnectionError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, request: httpx.Request):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
request: httpx.Request,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.status_code = 500
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(message=self.message, request=request)
|
||||
|
||||
|
||||
# raised if an invalid request (not get, delete, put, post) is made
|
||||
class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model):
|
||||
def __init__(
|
||||
self, message, llm_provider, model, litellm_debug_info: Optional[str] = None
|
||||
):
|
||||
self.message = message
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
response = httpx.Response(status_code=500, request=request)
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
super().__init__(response=response, body=None, message=message)
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import dotenv, os
|
|||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
|
||||
from typing import Literal, Union, Optional
|
||||
import traceback
|
||||
|
||||
|
@ -64,8 +63,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||
):
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
],
|
||||
) -> Optional[
|
||||
Union[Exception, str, dict]
|
||||
]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
|
||||
pass
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
|
|
179
litellm/integrations/lago.py
Normal file
179
litellm/integrations/lago.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
# What is this?
|
||||
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
|
||||
|
||||
import dotenv, os, json
|
||||
import litellm
|
||||
import traceback, httpx
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
import uuid
|
||||
from typing import Optional, Literal
|
||||
|
||||
|
||||
def get_utc_datetime():
|
||||
import datetime as dt
|
||||
from datetime import datetime
|
||||
|
||||
if hasattr(dt, "UTC"):
|
||||
return datetime.now(dt.UTC) # type: ignore
|
||||
else:
|
||||
return datetime.utcnow() # type: ignore
|
||||
|
||||
|
||||
class LagoLogger(CustomLogger):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.validate_environment()
|
||||
self.async_http_handler = AsyncHTTPHandler()
|
||||
self.sync_http_handler = HTTPHandler()
|
||||
|
||||
def validate_environment(self):
|
||||
"""
|
||||
Expects
|
||||
LAGO_API_BASE,
|
||||
LAGO_API_KEY,
|
||||
LAGO_API_EVENT_CODE,
|
||||
|
||||
Optional:
|
||||
LAGO_API_CHARGE_BY
|
||||
|
||||
in the environment
|
||||
"""
|
||||
missing_keys = []
|
||||
if os.getenv("LAGO_API_KEY", None) is None:
|
||||
missing_keys.append("LAGO_API_KEY")
|
||||
|
||||
if os.getenv("LAGO_API_BASE", None) is None:
|
||||
missing_keys.append("LAGO_API_BASE")
|
||||
|
||||
if os.getenv("LAGO_API_EVENT_CODE", None) is None:
|
||||
missing_keys.append("LAGO_API_EVENT_CODE")
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
raise Exception("Missing keys={} in environment.".format(missing_keys))
|
||||
|
||||
def _common_logic(self, kwargs: dict, response_obj) -> dict:
|
||||
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||
dt = get_utc_datetime().isoformat()
|
||||
cost = kwargs.get("response_cost", None)
|
||||
model = kwargs.get("model")
|
||||
usage = {}
|
||||
|
||||
if (
|
||||
isinstance(response_obj, litellm.ModelResponse)
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
) and hasattr(response_obj, "usage"):
|
||||
usage = {
|
||||
"prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
|
||||
"completion_tokens": response_obj["usage"].get("completion_tokens", 0),
|
||||
"total_tokens": response_obj["usage"].get("total_tokens"),
|
||||
}
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
|
||||
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
|
||||
org_id = litellm_params["metadata"].get("user_api_key_org_id", None)
|
||||
|
||||
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
|
||||
external_customer_id: Optional[str] = None
|
||||
|
||||
if os.getenv("LAGO_API_CHARGE_BY", None) is not None and isinstance(
|
||||
os.environ["LAGO_API_CHARGE_BY"], str
|
||||
):
|
||||
if os.environ["LAGO_API_CHARGE_BY"] in [
|
||||
"end_user_id",
|
||||
"user_id",
|
||||
"team_id",
|
||||
]:
|
||||
charge_by = os.environ["LAGO_API_CHARGE_BY"] # type: ignore
|
||||
else:
|
||||
raise Exception("invalid LAGO_API_CHARGE_BY set")
|
||||
|
||||
if charge_by == "end_user_id":
|
||||
external_customer_id = end_user_id
|
||||
elif charge_by == "team_id":
|
||||
external_customer_id = team_id
|
||||
elif charge_by == "user_id":
|
||||
external_customer_id = user_id
|
||||
|
||||
if external_customer_id is None:
|
||||
raise Exception("External Customer ID is not set")
|
||||
|
||||
return {
|
||||
"event": {
|
||||
"transaction_id": str(uuid.uuid4()),
|
||||
"external_customer_id": external_customer_id,
|
||||
"code": os.getenv("LAGO_API_EVENT_CODE"),
|
||||
"properties": {"model": model, "response_cost": cost, **usage},
|
||||
}
|
||||
}
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
_url = os.getenv("LAGO_API_BASE")
|
||||
assert _url is not None and isinstance(
|
||||
_url, str
|
||||
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
|
||||
if _url.endswith("/"):
|
||||
_url += "api/v1/events"
|
||||
else:
|
||||
_url += "/api/v1/events"
|
||||
|
||||
api_key = os.getenv("LAGO_API_KEY")
|
||||
|
||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||
_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.sync_http_handler.post(
|
||||
url=_url,
|
||||
data=json.dumps(_data),
|
||||
headers=_headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
if hasattr(response, "text"):
|
||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||
raise e
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
_url = os.getenv("LAGO_API_BASE")
|
||||
assert _url is not None and isinstance(
|
||||
_url, str
|
||||
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(
|
||||
_url
|
||||
)
|
||||
if _url.endswith("/"):
|
||||
_url += "api/v1/events"
|
||||
else:
|
||||
_url += "/api/v1/events"
|
||||
|
||||
api_key = os.getenv("LAGO_API_KEY")
|
||||
|
||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||
_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
response: Optional[httpx.Response] = None
|
||||
try:
|
||||
response = await self.async_http_handler.post(
|
||||
url=_url,
|
||||
data=json.dumps(_data),
|
||||
headers=_headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
if response is not None and hasattr(response, "text"):
|
||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||
raise e
|
|
@ -93,6 +93,7 @@ class LangFuseLogger:
|
|||
)
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
|
@ -161,6 +162,7 @@ class LangFuseLogger:
|
|||
response_obj,
|
||||
level,
|
||||
print_verbose,
|
||||
litellm_call_id,
|
||||
)
|
||||
elif response_obj is not None:
|
||||
self._log_langfuse_v1(
|
||||
|
@ -255,6 +257,7 @@ class LangFuseLogger:
|
|||
response_obj,
|
||||
level,
|
||||
print_verbose,
|
||||
litellm_call_id,
|
||||
) -> tuple:
|
||||
import langfuse
|
||||
|
||||
|
@ -318,7 +321,7 @@ class LangFuseLogger:
|
|||
|
||||
session_id = clean_metadata.pop("session_id", None)
|
||||
trace_name = clean_metadata.pop("trace_name", None)
|
||||
trace_id = clean_metadata.pop("trace_id", None)
|
||||
trace_id = clean_metadata.pop("trace_id", litellm_call_id)
|
||||
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
|
||||
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
|
||||
debug = clean_metadata.pop("debug_langfuse", None)
|
||||
|
@ -351,9 +354,13 @@ class LangFuseLogger:
|
|||
|
||||
# Special keys that are found in the function arguments and not the metadata
|
||||
if "input" in update_trace_keys:
|
||||
trace_params["input"] = input if not mask_input else "redacted-by-litellm"
|
||||
trace_params["input"] = (
|
||||
input if not mask_input else "redacted-by-litellm"
|
||||
)
|
||||
if "output" in update_trace_keys:
|
||||
trace_params["output"] = output if not mask_output else "redacted-by-litellm"
|
||||
trace_params["output"] = (
|
||||
output if not mask_output else "redacted-by-litellm"
|
||||
)
|
||||
else: # don't overwrite an existing trace
|
||||
trace_params = {
|
||||
"id": trace_id,
|
||||
|
@ -375,7 +382,9 @@ class LangFuseLogger:
|
|||
if level == "ERROR":
|
||||
trace_params["status_message"] = output
|
||||
else:
|
||||
trace_params["output"] = output if not mask_output else "redacted-by-litellm"
|
||||
trace_params["output"] = (
|
||||
output if not mask_output else "redacted-by-litellm"
|
||||
)
|
||||
|
||||
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
|
||||
if "metadata" in trace_params:
|
||||
|
|
|
@ -44,6 +44,8 @@ class LangsmithLogger:
|
|||
print_verbose(
|
||||
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
|
||||
)
|
||||
langsmith_base_url = os.getenv("LANGSMITH_BASE_URL", "https://api.smith.langchain.com")
|
||||
|
||||
try:
|
||||
print_verbose(
|
||||
f"Langsmith Logging - Enters logging function for model {kwargs}"
|
||||
|
@ -86,8 +88,12 @@ class LangsmithLogger:
|
|||
"end_time": end_time,
|
||||
}
|
||||
|
||||
url = f"{langsmith_base_url}/runs"
|
||||
print_verbose(
|
||||
f"Langsmith Logging - About to send data to {url} ..."
|
||||
)
|
||||
response = requests.post(
|
||||
"https://api.smith.langchain.com/runs",
|
||||
url=url,
|
||||
json=data,
|
||||
headers={"x-api-key": self.langsmith_api_key},
|
||||
)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#### What this does ####
|
||||
# Class for sending Slack Alerts #
|
||||
import dotenv, os
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy._types import UserAPIKeyAuth, CallInfo
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
import litellm, threading
|
||||
from typing import List, Literal, Any, Union, Optional, Dict
|
||||
|
@ -36,6 +36,13 @@ class SlackAlertingArgs(LiteLLMBase):
|
|||
os.getenv("SLACK_DAILY_REPORT_FREQUENCY", default_daily_report_frequency)
|
||||
)
|
||||
report_check_interval: int = 5 * 60 # 5 minutes
|
||||
budget_alert_ttl: int = 24 * 60 * 60 # 24 hours
|
||||
|
||||
|
||||
class WebhookEvent(CallInfo):
|
||||
event: Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
|
||||
event_group: Literal["user", "key", "team", "proxy"]
|
||||
event_message: str # human-readable description of event
|
||||
|
||||
|
||||
class DeploymentMetrics(LiteLLMBase):
|
||||
|
@ -87,6 +94,9 @@ class SlackAlerting(CustomLogger):
|
|||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
"spend_reports",
|
||||
"cooldown_deployment",
|
||||
"new_model_added",
|
||||
]
|
||||
] = [
|
||||
"llm_exceptions",
|
||||
|
@ -95,6 +105,9 @@ class SlackAlerting(CustomLogger):
|
|||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
"spend_reports",
|
||||
"cooldown_deployment",
|
||||
"new_model_added",
|
||||
],
|
||||
alert_to_webhook_url: Optional[
|
||||
Dict
|
||||
|
@ -158,13 +171,28 @@ class SlackAlerting(CustomLogger):
|
|||
) -> Optional[str]:
|
||||
"""
|
||||
Returns langfuse trace url
|
||||
|
||||
- check:
|
||||
-> existing_trace_id
|
||||
-> trace_id
|
||||
-> litellm_call_id
|
||||
"""
|
||||
# do nothing for now
|
||||
if (
|
||||
request_data is not None
|
||||
and request_data.get("metadata", {}).get("trace_id", None) is not None
|
||||
):
|
||||
trace_id = request_data["metadata"]["trace_id"]
|
||||
if request_data is not None:
|
||||
trace_id = None
|
||||
if (
|
||||
request_data.get("metadata", {}).get("existing_trace_id", None)
|
||||
is not None
|
||||
):
|
||||
trace_id = request_data["metadata"]["existing_trace_id"]
|
||||
elif request_data.get("metadata", {}).get("trace_id", None) is not None:
|
||||
trace_id = request_data["metadata"]["trace_id"]
|
||||
elif request_data.get("litellm_logging_obj", None) is not None and hasattr(
|
||||
request_data["litellm_logging_obj"], "model_call_details"
|
||||
):
|
||||
trace_id = request_data["litellm_logging_obj"].model_call_details[
|
||||
"litellm_call_id"
|
||||
]
|
||||
if litellm.utils.langFuseLogger is not None:
|
||||
base_url = litellm.utils.langFuseLogger.Langfuse.base_url
|
||||
return f"{base_url}/trace/{trace_id}"
|
||||
|
@ -549,127 +577,131 @@ class SlackAlerting(CustomLogger):
|
|||
alert_type="llm_requests_hanging",
|
||||
)
|
||||
|
||||
async def failed_tracking_alert(self, error_message: str):
|
||||
"""Raise alert when tracking failed for specific model"""
|
||||
_cache: DualCache = self.internal_usage_cache
|
||||
message = "Failed Tracking Cost for" + error_message
|
||||
_cache_key = "budget_alerts:failed_tracking:{}".format(message)
|
||||
result = await _cache.async_get_cache(key=_cache_key)
|
||||
if result is None:
|
||||
await self.send_alert(
|
||||
message=message, level="High", alert_type="budget_alerts"
|
||||
)
|
||||
await _cache.async_set_cache(
|
||||
key=_cache_key,
|
||||
value="SENT",
|
||||
ttl=self.alerting_args.budget_alert_ttl,
|
||||
)
|
||||
|
||||
async def budget_alerts(
|
||||
self,
|
||||
type: Literal[
|
||||
"token_budget",
|
||||
"user_budget",
|
||||
"user_and_proxy_budget",
|
||||
"failed_budgets",
|
||||
"failed_tracking",
|
||||
"team_budget",
|
||||
"proxy_budget",
|
||||
"projected_limit_exceeded",
|
||||
],
|
||||
user_max_budget: float,
|
||||
user_current_spend: float,
|
||||
user_info=None,
|
||||
error_message="",
|
||||
user_info: CallInfo,
|
||||
):
|
||||
## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
|
||||
# - Alert once within 24hr period
|
||||
# - Cache this information
|
||||
# - Don't re-alert, if alert already sent
|
||||
_cache: DualCache = self.internal_usage_cache
|
||||
|
||||
if self.alerting is None or self.alert_types is None:
|
||||
# do nothing if alerting is not switched on
|
||||
return
|
||||
if "budget_alerts" not in self.alert_types:
|
||||
return
|
||||
_id: str = "default_id" # used for caching
|
||||
if type == "user_and_proxy_budget":
|
||||
user_info = dict(user_info)
|
||||
user_id = user_info["user_id"]
|
||||
_id = user_id
|
||||
max_budget = user_info["max_budget"]
|
||||
spend = user_info["spend"]
|
||||
user_email = user_info["user_email"]
|
||||
user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
|
||||
user_info_json = user_info.model_dump(exclude_none=True)
|
||||
for k, v in user_info_json.items():
|
||||
user_info_str = "\n{}: {}\n".format(k, v)
|
||||
|
||||
event: Optional[
|
||||
Literal["budget_crossed", "threshold_crossed", "projected_limit_exceeded"]
|
||||
] = None
|
||||
event_group: Optional[Literal["user", "team", "key", "proxy"]] = None
|
||||
event_message: str = ""
|
||||
webhook_event: Optional[WebhookEvent] = None
|
||||
if type == "proxy_budget":
|
||||
event_group = "proxy"
|
||||
event_message += "Proxy Budget: "
|
||||
elif type == "user_budget":
|
||||
event_group = "user"
|
||||
event_message += "User Budget: "
|
||||
_id = user_info.user_id or _id
|
||||
elif type == "team_budget":
|
||||
event_group = "team"
|
||||
event_message += "Team Budget: "
|
||||
_id = user_info.team_id or _id
|
||||
elif type == "token_budget":
|
||||
token_info = dict(user_info)
|
||||
token = token_info["token"]
|
||||
_id = token
|
||||
spend = token_info["spend"]
|
||||
max_budget = token_info["max_budget"]
|
||||
user_id = token_info["user_id"]
|
||||
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
|
||||
elif type == "failed_tracking":
|
||||
user_id = str(user_info)
|
||||
_id = user_id
|
||||
user_info = f"\nUser ID: {user_id}\n Error {error_message}"
|
||||
message = "Failed Tracking Cost for" + user_info
|
||||
await self.send_alert(
|
||||
message=message, level="High", alert_type="budget_alerts"
|
||||
)
|
||||
return
|
||||
elif type == "projected_limit_exceeded" and user_info is not None:
|
||||
"""
|
||||
Input variables:
|
||||
user_info = {
|
||||
"key_alias": key_alias,
|
||||
"projected_spend": projected_spend,
|
||||
"projected_exceeded_date": projected_exceeded_date,
|
||||
}
|
||||
user_max_budget=soft_limit,
|
||||
user_current_spend=new_spend
|
||||
"""
|
||||
message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` {user_info["key_alias"]} \n`Expected Day of Error`: {user_info["projected_exceeded_date"]} \n`Current Spend`: {user_current_spend} \n`Projected Spend at end of month`: {user_info["projected_spend"]} \n`Soft Limit`: {user_max_budget}"""
|
||||
await self.send_alert(
|
||||
message=message, level="High", alert_type="budget_alerts"
|
||||
)
|
||||
return
|
||||
else:
|
||||
user_info = str(user_info)
|
||||
event_group = "key"
|
||||
event_message += "Key Budget: "
|
||||
_id = user_info.token
|
||||
elif type == "projected_limit_exceeded":
|
||||
event_group = "key"
|
||||
event_message += "Key Budget: Projected Limit Exceeded"
|
||||
event = "projected_limit_exceeded"
|
||||
_id = user_info.token
|
||||
|
||||
# percent of max_budget left to spend
|
||||
if user_max_budget > 0:
|
||||
percent_left = (user_max_budget - user_current_spend) / user_max_budget
|
||||
if user_info.max_budget > 0:
|
||||
percent_left = (
|
||||
user_info.max_budget - user_info.spend
|
||||
) / user_info.max_budget
|
||||
else:
|
||||
percent_left = 0
|
||||
verbose_proxy_logger.debug(
|
||||
f"Budget Alerts: Percent left: {percent_left} for {user_info}"
|
||||
)
|
||||
|
||||
## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
|
||||
# - Alert once within 28d period
|
||||
# - Cache this information
|
||||
# - Don't re-alert, if alert already sent
|
||||
_cache: DualCache = self.internal_usage_cache
|
||||
|
||||
# check if crossed budget
|
||||
if user_current_spend >= user_max_budget:
|
||||
verbose_proxy_logger.debug("Budget Crossed for %s", user_info)
|
||||
message = "Budget Crossed for" + user_info
|
||||
result = await _cache.async_get_cache(key=message)
|
||||
if result is None:
|
||||
await self.send_alert(
|
||||
message=message, level="High", alert_type="budget_alerts"
|
||||
)
|
||||
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
|
||||
return
|
||||
if user_info.spend >= user_info.max_budget:
|
||||
event = "budget_crossed"
|
||||
event_message += "Budget Crossed"
|
||||
elif percent_left <= 0.05:
|
||||
event = "threshold_crossed"
|
||||
event_message += "5% Threshold Crossed"
|
||||
elif percent_left <= 0.15:
|
||||
event = "threshold_crossed"
|
||||
event_message += "15% Threshold Crossed"
|
||||
|
||||
# check if 5% of max budget is left
|
||||
if percent_left <= 0.05:
|
||||
message = "5% budget left for" + user_info
|
||||
cache_key = "alerting:{}".format(_id)
|
||||
result = await _cache.async_get_cache(key=cache_key)
|
||||
if event is not None and event_group is not None:
|
||||
_cache_key = "budget_alerts:{}:{}".format(event, _id)
|
||||
result = await _cache.async_get_cache(key=_cache_key)
|
||||
if result is None:
|
||||
webhook_event = WebhookEvent(
|
||||
event=event,
|
||||
event_group=event_group,
|
||||
event_message=event_message,
|
||||
**user_info_json,
|
||||
)
|
||||
await self.send_alert(
|
||||
message=message, level="Medium", alert_type="budget_alerts"
|
||||
message=event_message + "\n\n" + user_info_str,
|
||||
level="High",
|
||||
alert_type="budget_alerts",
|
||||
user_info=webhook_event,
|
||||
)
|
||||
await _cache.async_set_cache(
|
||||
key=_cache_key,
|
||||
value="SENT",
|
||||
ttl=self.alerting_args.budget_alert_ttl,
|
||||
)
|
||||
|
||||
await _cache.async_set_cache(key=cache_key, value="SENT", ttl=2419200)
|
||||
|
||||
return
|
||||
|
||||
# check if 15% of max budget is left
|
||||
if percent_left <= 0.15:
|
||||
message = "15% budget left for" + user_info
|
||||
result = await _cache.async_get_cache(key=message)
|
||||
if result is None:
|
||||
await self.send_alert(
|
||||
message=message, level="Low", alert_type="budget_alerts"
|
||||
)
|
||||
await _cache.async_set_cache(key=message, value="SENT", ttl=2419200)
|
||||
return
|
||||
|
||||
return
|
||||
|
||||
async def model_added_alert(self, model_name: str, litellm_model_name: str):
|
||||
model_info = litellm.model_cost.get(litellm_model_name, {})
|
||||
async def model_added_alert(
|
||||
self, model_name: str, litellm_model_name: str, passed_model_info: Any
|
||||
):
|
||||
base_model_from_user = getattr(passed_model_info, "base_model", None)
|
||||
model_info = {}
|
||||
base_model = ""
|
||||
if base_model_from_user is not None:
|
||||
model_info = litellm.model_cost.get(base_model_from_user, {})
|
||||
base_model = f"Base Model: `{base_model_from_user}`\n"
|
||||
else:
|
||||
model_info = litellm.model_cost.get(litellm_model_name, {})
|
||||
model_info_str = ""
|
||||
for k, v in model_info.items():
|
||||
if k == "input_cost_per_token" or k == "output_cost_per_token":
|
||||
|
@ -681,6 +713,7 @@ class SlackAlerting(CustomLogger):
|
|||
message = f"""
|
||||
*🚅 New Model Added*
|
||||
Model Name: `{model_name}`
|
||||
{base_model}
|
||||
|
||||
Usage OpenAI Python SDK:
|
||||
```
|
||||
|
@ -715,6 +748,34 @@ Model Info:
|
|||
async def model_removed_alert(self, model_name: str):
|
||||
pass
|
||||
|
||||
async def send_webhook_alert(self, webhook_event: WebhookEvent) -> bool:
|
||||
"""
|
||||
Sends structured alert to webhook, if set.
|
||||
|
||||
Currently only implemented for budget alerts
|
||||
|
||||
Returns -> True if sent, False if not.
|
||||
"""
|
||||
|
||||
webhook_url = os.getenv("WEBHOOK_URL", None)
|
||||
if webhook_url is None:
|
||||
raise Exception("Missing webhook_url from environment")
|
||||
|
||||
payload = webhook_event.model_dump_json()
|
||||
headers = {"Content-type": "application/json"}
|
||||
|
||||
response = await self.async_http_handler.post(
|
||||
url=webhook_url,
|
||||
headers=headers,
|
||||
data=payload,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
print("Error sending webhook alert. Error=", response.text) # noqa
|
||||
|
||||
return False
|
||||
|
||||
async def send_alert(
|
||||
self,
|
||||
message: str,
|
||||
|
@ -726,9 +787,11 @@ Model Info:
|
|||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
"spend_reports",
|
||||
"new_model_added",
|
||||
"cooldown_deployment",
|
||||
],
|
||||
user_info: Optional[WebhookEvent] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -748,6 +811,19 @@ Model Info:
|
|||
if self.alerting is None:
|
||||
return
|
||||
|
||||
if (
|
||||
"webhook" in self.alerting
|
||||
and alert_type == "budget_alerts"
|
||||
and user_info is not None
|
||||
):
|
||||
await self.send_webhook_alert(webhook_event=user_info)
|
||||
|
||||
if "slack" not in self.alerting:
|
||||
return
|
||||
|
||||
if alert_type not in self.alert_types:
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
|
@ -795,27 +871,37 @@ Model Info:
|
|||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""Log deployment latency"""
|
||||
if "daily_reports" in self.alert_types:
|
||||
model_id = (
|
||||
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
|
||||
)
|
||||
response_s: timedelta = end_time - start_time
|
||||
|
||||
final_value = response_s
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, litellm.ModelResponse):
|
||||
completion_tokens = response_obj.usage.completion_tokens
|
||||
final_value = float(response_s.total_seconds() / completion_tokens)
|
||||
|
||||
await self.async_update_daily_reports(
|
||||
DeploymentMetrics(
|
||||
id=model_id,
|
||||
failed_request=False,
|
||||
latency_per_output_token=final_value,
|
||||
updated_at=litellm.utils.get_utc_datetime(),
|
||||
try:
|
||||
if "daily_reports" in self.alert_types:
|
||||
model_id = (
|
||||
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
|
||||
)
|
||||
response_s: timedelta = end_time - start_time
|
||||
|
||||
final_value = response_s
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, litellm.ModelResponse):
|
||||
completion_tokens = response_obj.usage.completion_tokens
|
||||
if completion_tokens is not None and completion_tokens > 0:
|
||||
final_value = float(
|
||||
response_s.total_seconds() / completion_tokens
|
||||
)
|
||||
|
||||
await self.async_update_daily_reports(
|
||||
DeploymentMetrics(
|
||||
id=model_id,
|
||||
failed_request=False,
|
||||
latency_per_output_token=final_value,
|
||||
updated_at=litellm.utils.get_utc_datetime(),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: ",
|
||||
e,
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""Log failure + deployment latency"""
|
||||
|
@ -942,7 +1028,7 @@ Model Info:
|
|||
await self.send_alert(
|
||||
message=_weekly_spend_message,
|
||||
level="Low",
|
||||
alert_type="daily_reports",
|
||||
alert_type="spend_reports",
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
||||
|
@ -993,7 +1079,7 @@ Model Info:
|
|||
await self.send_alert(
|
||||
message=_spend_message,
|
||||
level="Low",
|
||||
alert_type="daily_reports",
|
||||
alert_type="spend_reports",
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
||||
|
|
|
@ -93,6 +93,7 @@ class AnthropicConfig:
|
|||
"max_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
|
@ -504,7 +505,9 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
headers["anthropic-beta"] = "tools-2024-04-04"
|
||||
if "anthropic-beta" not in headers:
|
||||
# default to v1 of "anthropic-beta"
|
||||
headers["anthropic-beta"] = "tools-2024-05-16"
|
||||
|
||||
anthropic_tools = []
|
||||
for tool in optional_params["tools"]:
|
||||
|
|
|
@ -21,7 +21,7 @@ class BaseLLM:
|
|||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> litellm.utils.ModelResponse:
|
||||
) -> Union[litellm.utils.ModelResponse, litellm.utils.CustomStreamWrapper]:
|
||||
"""
|
||||
Helper function to process the response across sync + async completion calls
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# What is this?
|
||||
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
||||
## V0 - just covers cohere command-r support
|
||||
## V1 - covers cohere + anthropic claude-3 support
|
||||
|
||||
import os, types
|
||||
import json
|
||||
|
@ -29,12 +29,20 @@ from litellm.utils import (
|
|||
get_secret,
|
||||
Logging,
|
||||
)
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
|
||||
import litellm, uuid
|
||||
from .prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
cohere_message_pt,
|
||||
construct_tool_use_system_prompt,
|
||||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
contains_tag,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx # type: ignore
|
||||
from .bedrock import BedrockError, convert_messages_to_prompt
|
||||
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
|
||||
from litellm.types.llms.bedrock import *
|
||||
|
||||
|
||||
|
@ -280,7 +288,8 @@ class BedrockLLM(BaseLLM):
|
|||
messages: List,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
provider = model.split(".")[0]
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
|
@ -297,26 +306,210 @@ class BedrockLLM(BaseLLM):
|
|||
raise BedrockError(message=response.text, status_code=422)
|
||||
|
||||
try:
|
||||
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||
if provider == "cohere":
|
||||
if "text" in completion_response:
|
||||
outputText = completion_response["text"] # type: ignore
|
||||
elif "generations" in completion_response:
|
||||
outputText = completion_response["generations"][0]["text"]
|
||||
model_response["finish_reason"] = map_finish_reason(
|
||||
completion_response["generations"][0]["finish_reason"]
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
json_schemas: dict = {}
|
||||
_is_function_call = False
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
for tool in optional_params["tools"]:
|
||||
json_schemas[tool["function"]["name"]] = tool[
|
||||
"function"
|
||||
].get("parameters", None)
|
||||
outputText = completion_response.get("content")[0].get("text", None)
|
||||
if outputText is not None and contains_tag(
|
||||
"invoke", outputText
|
||||
): # OUTPUT PARSE FUNCTION CALL
|
||||
function_name = extract_between_tags("tool_name", outputText)[0]
|
||||
function_arguments_str = extract_between_tags(
|
||||
"invoke", outputText
|
||||
)[0].strip()
|
||||
function_arguments_str = (
|
||||
f"<invoke>{function_arguments_str}</invoke>"
|
||||
)
|
||||
function_arguments = parse_xml_params(
|
||||
function_arguments_str,
|
||||
json_schema=json_schemas.get(
|
||||
function_name, None
|
||||
), # check if we have a json schema for this function name)
|
||||
)
|
||||
_message = litellm.Message(
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{uuid.uuid4()}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_arguments),
|
||||
},
|
||||
}
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response._hidden_params["original_response"] = (
|
||||
outputText # allow user to access raw anthropic tool calling response
|
||||
)
|
||||
if (
|
||||
_is_function_call == True
|
||||
and stream is not None
|
||||
and stream == True
|
||||
):
|
||||
print_verbose(
|
||||
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
|
||||
)
|
||||
# return an iterator
|
||||
streaming_model_response = ModelResponse(stream=True)
|
||||
streaming_model_response.choices[0].finish_reason = getattr(
|
||||
model_response.choices[0], "finish_reason", "stop"
|
||||
)
|
||||
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
||||
streaming_choice = litellm.utils.StreamingChoices()
|
||||
streaming_choice.index = model_response.choices[0].index
|
||||
_tool_calls = []
|
||||
print_verbose(
|
||||
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
|
||||
)
|
||||
print_verbose(
|
||||
f"type of streaming_choice: {type(streaming_choice)}"
|
||||
)
|
||||
if isinstance(model_response.choices[0], litellm.Choices):
|
||||
if getattr(
|
||||
model_response.choices[0].message, "tool_calls", None
|
||||
) is not None and isinstance(
|
||||
model_response.choices[0].message.tool_calls, list
|
||||
):
|
||||
for tool_call in model_response.choices[
|
||||
0
|
||||
].message.tool_calls:
|
||||
_tool_call = {**tool_call.dict(), "index": 0}
|
||||
_tool_calls.append(_tool_call)
|
||||
delta_obj = litellm.utils.Delta(
|
||||
content=getattr(
|
||||
model_response.choices[0].message, "content", None
|
||||
),
|
||||
role=model_response.choices[0].message.role,
|
||||
tool_calls=_tool_calls,
|
||||
)
|
||||
streaming_choice.delta = delta_obj
|
||||
streaming_model_response.choices = [streaming_choice]
|
||||
completion_stream = ModelResponseIterator(
|
||||
model_response=streaming_model_response
|
||||
)
|
||||
print_verbose(
|
||||
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
)
|
||||
return litellm.CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
model_response["finish_reason"] = map_finish_reason(
|
||||
completion_response.get("stop_reason", "")
|
||||
)
|
||||
_usage = litellm.Usage(
|
||||
prompt_tokens=completion_response["usage"]["input_tokens"],
|
||||
completion_tokens=completion_response["usage"]["output_tokens"],
|
||||
total_tokens=completion_response["usage"]["input_tokens"]
|
||||
+ completion_response["usage"]["output_tokens"],
|
||||
)
|
||||
setattr(model_response, "usage", _usage)
|
||||
else:
|
||||
outputText = completion_response["completion"]
|
||||
|
||||
model_response["finish_reason"] = completion_response["stop_reason"]
|
||||
elif provider == "ai21":
|
||||
outputText = (
|
||||
completion_response.get("completions")[0].get("data").get("text")
|
||||
)
|
||||
elif provider == "meta":
|
||||
outputText = completion_response["generation"]
|
||||
elif provider == "mistral":
|
||||
outputText = completion_response["outputs"][0]["text"]
|
||||
model_response["finish_reason"] = completion_response["outputs"][0][
|
||||
"stop_reason"
|
||||
]
|
||||
else: # amazon titan
|
||||
outputText = completion_response.get("results")[0].get("outputText")
|
||||
except Exception as e:
|
||||
raise BedrockError(message=response.text, status_code=422)
|
||||
raise BedrockError(
|
||||
message="Error processing={}, Received error={}".format(
|
||||
response.text, str(e)
|
||||
),
|
||||
status_code=422,
|
||||
)
|
||||
|
||||
try:
|
||||
if (
|
||||
len(outputText) > 0
|
||||
and hasattr(model_response.choices[0], "message")
|
||||
and getattr(model_response.choices[0].message, "tool_calls", None)
|
||||
is None
|
||||
):
|
||||
model_response["choices"][0]["message"]["content"] = outputText
|
||||
elif (
|
||||
hasattr(model_response.choices[0], "message")
|
||||
and getattr(model_response.choices[0].message, "tool_calls", None)
|
||||
is not None
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise Exception()
|
||||
except:
|
||||
raise BedrockError(
|
||||
message=json.dumps(outputText), status_code=response.status_code
|
||||
)
|
||||
|
||||
if stream and provider == "ai21":
|
||||
streaming_model_response = ModelResponse(stream=True)
|
||||
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
|
||||
0
|
||||
].finish_reason
|
||||
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
||||
streaming_choice = litellm.utils.StreamingChoices()
|
||||
streaming_choice.index = model_response.choices[0].index
|
||||
delta_obj = litellm.utils.Delta(
|
||||
content=getattr(model_response.choices[0].message, "content", None),
|
||||
role=model_response.choices[0].message.role,
|
||||
)
|
||||
streaming_choice.delta = delta_obj
|
||||
streaming_model_response.choices = [streaming_choice]
|
||||
mri = ModelResponseIterator(model_response=streaming_model_response)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=mri,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
prompt_tokens = int(
|
||||
response.headers.get(
|
||||
"x-amzn-bedrock-input-token-count",
|
||||
len(encoding.encode("".join(m.get("content", "") for m in messages))),
|
||||
)
|
||||
bedrock_input_tokens = response.headers.get(
|
||||
"x-amzn-bedrock-input-token-count", None
|
||||
)
|
||||
bedrock_output_tokens = response.headers.get(
|
||||
"x-amzn-bedrock-output-token-count", None
|
||||
)
|
||||
|
||||
prompt_tokens = int(
|
||||
bedrock_input_tokens or litellm.token_counter(messages=messages)
|
||||
)
|
||||
|
||||
completion_tokens = int(
|
||||
response.headers.get(
|
||||
"x-amzn-bedrock-output-token-count",
|
||||
len(
|
||||
encoding.encode(
|
||||
model_response.choices[0].message.content, # type: ignore
|
||||
disallowed_special=(),
|
||||
)
|
||||
),
|
||||
bedrock_output_tokens
|
||||
or litellm.token_counter(
|
||||
text=model_response.choices[0].message.content, # type: ignore
|
||||
count_response_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -359,6 +552,7 @@ class BedrockLLM(BaseLLM):
|
|||
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
provider = model.split(".")[0]
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
|
@ -414,19 +608,18 @@ class BedrockLLM(BaseLLM):
|
|||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
if stream is not None and stream == True:
|
||||
if (stream is not None and stream == True) and provider != "ai21":
|
||||
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
||||
provider = model.split(".")[0]
|
||||
prompt, chat_history = self.convert_messages_to_prompt(
|
||||
model, messages, provider, custom_prompt_dict
|
||||
)
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
|
||||
json_schemas: dict = {}
|
||||
if provider == "cohere":
|
||||
if model.startswith("cohere.command-r"):
|
||||
## LOAD CONFIG
|
||||
|
@ -453,8 +646,114 @@ class BedrockLLM(BaseLLM):
|
|||
True # cohere requires stream = True in inference params
|
||||
)
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "anthropic":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
# Separate system prompt from rest of message
|
||||
system_prompt_idx: list[int] = []
|
||||
system_messages: list[str] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
system_messages.append(message["content"])
|
||||
system_prompt_idx.append(idx)
|
||||
if len(system_prompt_idx) > 0:
|
||||
inference_params["system"] = "\n".join(system_messages)
|
||||
messages = [
|
||||
i for j, i in enumerate(messages) if j not in system_prompt_idx
|
||||
]
|
||||
# Format rest of message according to anthropic guidelines
|
||||
messages = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic_xml"
|
||||
) # type: ignore
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAnthropicClaude3Config.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
## Handle Tool Calling
|
||||
if "tools" in inference_params:
|
||||
_is_function_call = True
|
||||
for tool in inference_params["tools"]:
|
||||
json_schemas[tool["function"]["name"]] = tool["function"].get(
|
||||
"parameters", None
|
||||
)
|
||||
tool_calling_system_prompt = construct_tool_use_system_prompt(
|
||||
tools=inference_params["tools"]
|
||||
)
|
||||
inference_params["system"] = (
|
||||
inference_params.get("system", "\n")
|
||||
+ tool_calling_system_prompt
|
||||
) # add the anthropic tool calling prompt to the system prompt
|
||||
inference_params.pop("tools")
|
||||
data = json.dumps({"messages": messages, **inference_params})
|
||||
else:
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "ai21":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAI21Config.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "mistral":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonMistralConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "amazon": # amazon titan
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonTitanConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
data = json.dumps(
|
||||
{
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig": inference_params,
|
||||
}
|
||||
)
|
||||
elif provider == "meta":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonLlamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
else:
|
||||
raise Exception("UNSUPPORTED PROVIDER")
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": inference_params,
|
||||
},
|
||||
)
|
||||
raise Exception(
|
||||
"Bedrock HTTPX: Unsupported provider={}, model={}".format(
|
||||
provider, model
|
||||
)
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
|
||||
|
@ -482,7 +781,7 @@ class BedrockLLM(BaseLLM):
|
|||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream:
|
||||
if stream == True and provider != "ai21":
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -511,7 +810,7 @@ class BedrockLLM(BaseLLM):
|
|||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=False,
|
||||
stream=stream, # type: ignore
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
|
@ -528,7 +827,7 @@ class BedrockLLM(BaseLLM):
|
|||
self.client = HTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client
|
||||
if stream is not None and stream == True:
|
||||
if (stream is not None and stream == True) and provider != "ai21":
|
||||
response = self.client.post(
|
||||
url=prepped.url,
|
||||
headers=prepped.headers, # type: ignore
|
||||
|
@ -541,7 +840,7 @@ class BedrockLLM(BaseLLM):
|
|||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
decoder = AWSEventStreamDecoder()
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
streaming_response = CustomStreamWrapper(
|
||||
|
@ -550,15 +849,24 @@ class BedrockLLM(BaseLLM):
|
|||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=streaming_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||
|
||||
try:
|
||||
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=response.text)
|
||||
except httpx.TimeoutException as e:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
|
@ -591,7 +899,7 @@ class BedrockLLM(BaseLLM):
|
|||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ModelResponse:
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
|
@ -602,12 +910,20 @@ class BedrockLLM(BaseLLM):
|
|||
else:
|
||||
self.client = client # type: ignore
|
||||
|
||||
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
try:
|
||||
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=response.text)
|
||||
except httpx.TimeoutException as e:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
stream=stream if isinstance(stream, bool) else False,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
|
@ -650,7 +966,7 @@ class BedrockLLM(BaseLLM):
|
|||
if response.status_code != 200:
|
||||
raise BedrockError(status_code=response.status_code, message=response.text)
|
||||
|
||||
decoder = AWSEventStreamDecoder()
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
|
||||
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
|
||||
streaming_response = CustomStreamWrapper(
|
||||
|
@ -659,6 +975,15 @@ class BedrockLLM(BaseLLM):
|
|||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=streaming_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return streaming_response
|
||||
|
||||
def embedding(self, *args, **kwargs):
|
||||
|
@ -676,11 +1001,70 @@ def get_response_stream_shape():
|
|||
|
||||
|
||||
class AWSEventStreamDecoder:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, model: str) -> None:
|
||||
from botocore.parsers import EventStreamJSONParser
|
||||
|
||||
self.model = model
|
||||
self.parser = EventStreamJSONParser()
|
||||
|
||||
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
|
||||
text = ""
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
if "outputText" in chunk_data:
|
||||
text = chunk_data["outputText"]
|
||||
# ai21 mapping
|
||||
if "ai21" in self.model: # fake ai21 streaming
|
||||
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
|
||||
is_finished = True
|
||||
finish_reason = "stop"
|
||||
######## bedrock.anthropic mappings ###############
|
||||
elif "completion" in chunk_data: # not claude-3
|
||||
text = chunk_data["completion"] # bedrock.anthropic
|
||||
stop_reason = chunk_data.get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
elif "delta" in chunk_data:
|
||||
if chunk_data["delta"].get("text", None) is not None:
|
||||
text = chunk_data["delta"]["text"]
|
||||
stop_reason = chunk_data["delta"].get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
######## bedrock.mistral mappings ###############
|
||||
elif "outputs" in chunk_data:
|
||||
if (
|
||||
len(chunk_data["outputs"]) == 1
|
||||
and chunk_data["outputs"][0].get("text", None) is not None
|
||||
):
|
||||
text = chunk_data["outputs"][0]["text"]
|
||||
stop_reason = chunk_data.get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
######## bedrock.cohere mappings ###############
|
||||
# meta mapping
|
||||
elif "generation" in chunk_data:
|
||||
text = chunk_data["generation"] # bedrock.meta
|
||||
# cohere mapping
|
||||
elif "text" in chunk_data:
|
||||
text = chunk_data["text"] # bedrock.cohere
|
||||
# cohere mapping for finish reason
|
||||
elif "finish_reason" in chunk_data:
|
||||
finish_reason = chunk_data["finish_reason"]
|
||||
is_finished = True
|
||||
elif chunk_data.get("completionReason", None):
|
||||
is_finished = True
|
||||
finish_reason = chunk_data["completionReason"]
|
||||
return GenericStreamingChunk(
|
||||
**{
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
)
|
||||
|
||||
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
|
||||
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
|
||||
from botocore.eventstream import EventStreamBuffer
|
||||
|
@ -693,12 +1077,7 @@ class AWSEventStreamDecoder:
|
|||
if message:
|
||||
# sse_event = ServerSentEvent(data=message, event="completion")
|
||||
_data = json.loads(message)
|
||||
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||
text=_data.get("text", ""),
|
||||
is_finished=_data.get("is_finished", False),
|
||||
finish_reason=_data.get("finish_reason", ""),
|
||||
)
|
||||
yield streaming_chunk
|
||||
yield self._chunk_parser(chunk_data=_data)
|
||||
|
||||
async def aiter_bytes(
|
||||
self, iterator: AsyncIterator[bytes]
|
||||
|
@ -713,12 +1092,7 @@ class AWSEventStreamDecoder:
|
|||
message = self._parse_message_from_event(event)
|
||||
if message:
|
||||
_data = json.loads(message)
|
||||
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||
text=_data.get("text", ""),
|
||||
is_finished=_data.get("is_finished", False),
|
||||
finish_reason=_data.get("finish_reason", ""),
|
||||
)
|
||||
yield streaming_chunk
|
||||
yield self._chunk_parser(chunk_data=_data)
|
||||
|
||||
def _parse_message_from_event(self, event) -> Optional[str]:
|
||||
response_dict = event.to_response_dict()
|
||||
|
|
|
@ -260,7 +260,7 @@ def completion(
|
|||
message_obj = Message(content=item.content.parts[0].text)
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(index=idx + 1, message=message_obj)
|
||||
choice_obj = Choices(index=idx, message=message_obj)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
|
@ -352,7 +352,7 @@ async def async_completion(
|
|||
message_obj = Message(content=item.content.parts[0].text)
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(index=idx + 1, message=message_obj)
|
||||
choice_obj = Choices(index=idx, message=message_obj)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"] = choices_list
|
||||
except Exception as e:
|
||||
|
|
|
@ -96,7 +96,7 @@ class MistralConfig:
|
|||
safe_prompt: Optional[bool] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
@ -211,7 +211,7 @@ class OpenAIConfig:
|
|||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
@ -234,6 +234,47 @@ class OpenAIConfig:
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
base_params = [
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"max_tokens",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"user",
|
||||
"function_call",
|
||||
"functions",
|
||||
"max_retries",
|
||||
"extra_headers",
|
||||
] # works across all models
|
||||
|
||||
model_specific_params = []
|
||||
if (
|
||||
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
|
||||
): # gpt-4 does not support 'response_format'
|
||||
model_specific_params.append("response_format")
|
||||
|
||||
return base_params + model_specific_params
|
||||
|
||||
def map_openai_params(
|
||||
self, non_default_params: dict, optional_params: dict, model: str
|
||||
) -> dict:
|
||||
supported_openai_params = self.get_supported_openai_params(model)
|
||||
for param, value in non_default_params.items():
|
||||
if param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
class OpenAITextCompletionConfig:
|
||||
"""
|
||||
|
@ -294,7 +335,7 @@ class OpenAITextCompletionConfig:
|
|||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
|
|
@ -12,6 +12,7 @@ from typing import (
|
|||
Sequence,
|
||||
)
|
||||
import litellm
|
||||
import litellm.types
|
||||
from litellm.types.completion import (
|
||||
ChatCompletionUserMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
|
@ -20,9 +21,12 @@ from litellm.types.completion import (
|
|||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
import litellm.types.llms
|
||||
from litellm.types.llms.anthropic import *
|
||||
import uuid
|
||||
|
||||
import litellm.types.llms.vertex_ai
|
||||
|
||||
|
||||
def default_pt(messages):
|
||||
return " ".join(message["content"] for message in messages)
|
||||
|
@ -841,6 +845,175 @@ def anthropic_messages_pt_xml(messages: list):
|
|||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_protocol_value(
|
||||
value: Any,
|
||||
) -> Literal[
|
||||
"string_value",
|
||||
"number_value",
|
||||
"bool_value",
|
||||
"struct_value",
|
||||
"list_value",
|
||||
"null_value",
|
||||
"unknown",
|
||||
]:
|
||||
if value is None:
|
||||
return "null_value"
|
||||
if isinstance(value, int) or isinstance(value, float):
|
||||
return "number_value"
|
||||
if isinstance(value, str):
|
||||
return "string_value"
|
||||
if isinstance(value, bool):
|
||||
return "bool_value"
|
||||
if isinstance(value, dict):
|
||||
return "struct_value"
|
||||
if isinstance(value, list):
|
||||
return "list_value"
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
def convert_to_gemini_tool_call_invoke(
|
||||
tool_calls: list,
|
||||
) -> List[litellm.types.llms.vertex_ai.PartType]:
|
||||
"""
|
||||
OpenAI tool invokes:
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"""
|
||||
"""
|
||||
Gemini tool call invokes: - https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#submit-api-output
|
||||
content {
|
||||
role: "model"
|
||||
parts [
|
||||
{
|
||||
function_call {
|
||||
name: "get_current_weather"
|
||||
args {
|
||||
fields {
|
||||
key: "unit"
|
||||
value {
|
||||
string_value: "fahrenheit"
|
||||
}
|
||||
}
|
||||
fields {
|
||||
key: "predicted_temperature"
|
||||
value {
|
||||
number_value: 45
|
||||
}
|
||||
}
|
||||
fields {
|
||||
key: "location"
|
||||
value {
|
||||
string_value: "Boston, MA"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
function_call {
|
||||
name: "get_current_weather"
|
||||
args {
|
||||
fields {
|
||||
key: "location"
|
||||
value {
|
||||
string_value: "San Francisco"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
"""
|
||||
- json.load the arguments
|
||||
- iterate through arguments -> create a FunctionCallArgs for each field
|
||||
"""
|
||||
try:
|
||||
_parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
|
||||
for tool in tool_calls:
|
||||
if "function" in tool:
|
||||
name = tool["function"].get("name", "")
|
||||
arguments = tool["function"].get("arguments", "")
|
||||
arguments_dict = json.loads(arguments)
|
||||
for k, v in arguments_dict.items():
|
||||
inferred_protocol_value = infer_protocol_value(value=v)
|
||||
_field = litellm.types.llms.vertex_ai.Field(
|
||||
key=k, value={inferred_protocol_value: v}
|
||||
)
|
||||
_fields = litellm.types.llms.vertex_ai.FunctionCallArgs(
|
||||
fields=_field
|
||||
)
|
||||
function_call = litellm.types.llms.vertex_ai.FunctionCall(
|
||||
name=name,
|
||||
args=_fields,
|
||||
)
|
||||
_parts_list.append(
|
||||
litellm.types.llms.vertex_ai.PartType(function_call=function_call)
|
||||
)
|
||||
return _parts_list
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
"Unable to convert openai tool calls={} to gemini tool calls. Received error={}".format(
|
||||
tool_calls, str(e)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def convert_to_gemini_tool_call_result(
|
||||
message: dict,
|
||||
) -> litellm.types.llms.vertex_ai.PartType:
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
{
|
||||
"tool_call_id": "tool_1",
|
||||
"role": "tool",
|
||||
"name": "get_current_weather",
|
||||
"content": "function result goes here",
|
||||
},
|
||||
|
||||
OpenAI message with a function call result looks like:
|
||||
{
|
||||
"role": "function",
|
||||
"name": "get_current_weather",
|
||||
"content": "function result goes here",
|
||||
}
|
||||
"""
|
||||
content = message.get("content", "")
|
||||
name = message.get("name", "")
|
||||
|
||||
# We can't determine from openai message format whether it's a successful or
|
||||
# error call result so default to the successful result template
|
||||
inferred_content_value = infer_protocol_value(value=content)
|
||||
|
||||
_field = litellm.types.llms.vertex_ai.Field(
|
||||
key="content", value={inferred_content_value: content}
|
||||
)
|
||||
|
||||
_function_call_args = litellm.types.llms.vertex_ai.FunctionCallArgs(fields=_field)
|
||||
|
||||
_function_response = litellm.types.llms.vertex_ai.FunctionResponse(
|
||||
name=name, response=_function_call_args
|
||||
)
|
||||
|
||||
_part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
|
||||
|
||||
return _part
|
||||
|
||||
|
||||
def convert_to_anthropic_tool_result(message: dict) -> dict:
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
|
@ -1328,6 +1501,7 @@ def _gemini_vision_convert_messages(messages: list):
|
|||
# Case 1: Image from URL
|
||||
image = _load_image_from_url(img)
|
||||
processed_images.append(image)
|
||||
|
||||
else:
|
||||
try:
|
||||
from PIL import Image
|
||||
|
@ -1335,8 +1509,22 @@ def _gemini_vision_convert_messages(messages: list):
|
|||
raise Exception(
|
||||
"gemini image conversion failed please run `pip install Pillow`"
|
||||
)
|
||||
# Case 2: Image filepath (e.g. temp.jpeg) given
|
||||
image = Image.open(img)
|
||||
|
||||
if "base64" in img:
|
||||
# Case 2: Base64 image data
|
||||
import base64
|
||||
import io
|
||||
# Extract the base64 image data
|
||||
base64_data = img.split("base64,")[1]
|
||||
|
||||
# Decode the base64 image data
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# Load the image from the decoded data
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
else:
|
||||
# Case 3: Image filepath (e.g. temp.jpeg) given
|
||||
image = Image.open(img)
|
||||
processed_images.append(image)
|
||||
content = [prompt] + processed_images
|
||||
return content
|
||||
|
|
|
@ -2,11 +2,12 @@ import os, types
|
|||
import json
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
import litellm
|
||||
from typing import Callable, Optional, Union, Tuple, Any
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
|
||||
import litellm, asyncio
|
||||
import httpx # type: ignore
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
|
||||
class ReplicateError(Exception):
|
||||
|
@ -145,6 +146,65 @@ def start_prediction(
|
|||
)
|
||||
|
||||
|
||||
async def async_start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
api_token,
|
||||
api_base,
|
||||
logging_obj,
|
||||
print_verbose,
|
||||
http_handler: AsyncHTTPHandler,
|
||||
) -> str:
|
||||
base_url = api_base
|
||||
if "deployments" in version_id:
|
||||
print_verbose("\nLiteLLM: Request to custom replicate deployment")
|
||||
version_id = version_id.replace("deployments/", "")
|
||||
base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
|
||||
print_verbose(f"Deployment base URL: {base_url}\n")
|
||||
else: # assume it's a model
|
||||
base_url = f"https://api.replicate.com/v1/models/{version_id}"
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
initial_prediction_data = {
|
||||
"input": input_data,
|
||||
}
|
||||
|
||||
if ":" in version_id and len(version_id) > 64:
|
||||
model_parts = version_id.split(":")
|
||||
if (
|
||||
len(model_parts) > 1 and len(model_parts[1]) == 64
|
||||
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
|
||||
initial_prediction_data["version"] = model_parts[1]
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input_data["prompt"],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": initial_prediction_data,
|
||||
"headers": headers,
|
||||
"api_base": base_url,
|
||||
},
|
||||
)
|
||||
|
||||
response = await http_handler.post(
|
||||
url="{}/predictions".format(base_url),
|
||||
data=json.dumps(initial_prediction_data),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code == 201:
|
||||
response_data = response.json()
|
||||
return response_data.get("urls", {}).get("get")
|
||||
else:
|
||||
raise ReplicateError(
|
||||
response.status_code, f"Failed to start prediction {response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Function to handle prediction response (non-streaming)
|
||||
def handle_prediction_response(prediction_url, api_token, print_verbose):
|
||||
output_string = ""
|
||||
|
@ -178,6 +238,40 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
|
|||
return output_string, logs
|
||||
|
||||
|
||||
async def async_handle_prediction_response(
|
||||
prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
|
||||
) -> Tuple[str, Any]:
|
||||
output_string = ""
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
status = ""
|
||||
logs = ""
|
||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||
await asyncio.sleep(0.5)
|
||||
response = await http_handler.get(prediction_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if "output" in response_data:
|
||||
output_string = "".join(response_data["output"])
|
||||
print_verbose(f"Non-streamed output:{output_string}")
|
||||
status = response_data.get("status", None)
|
||||
logs = response_data.get("logs", "")
|
||||
if status == "failed":
|
||||
replicate_error = response_data.get("error", "")
|
||||
raise ReplicateError(
|
||||
status_code=400,
|
||||
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
|
||||
)
|
||||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose("Replicate: Failed to fetch prediction status and output.")
|
||||
return output_string, logs
|
||||
|
||||
|
||||
# Function to handle prediction response (streaming)
|
||||
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
|
||||
previous_output = ""
|
||||
|
@ -214,6 +308,45 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
|
|||
)
|
||||
|
||||
|
||||
# Function to handle prediction response (streaming)
|
||||
async def async_handle_prediction_response_streaming(
|
||||
prediction_url, api_token, print_verbose
|
||||
):
|
||||
http_handler = AsyncHTTPHandler(concurrent_limit=1)
|
||||
previous_output = ""
|
||||
output_string = ""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
status = ""
|
||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||
await asyncio.sleep(0.5) # prevent being rate limited by replicate
|
||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||
response = await http_handler.get(prediction_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
status = response_data["status"]
|
||||
if "output" in response_data:
|
||||
output_string = "".join(response_data["output"])
|
||||
new_output = output_string[len(previous_output) :]
|
||||
print_verbose(f"New chunk: {new_output}")
|
||||
yield {"output": new_output, "status": status}
|
||||
previous_output = output_string
|
||||
status = response_data["status"]
|
||||
if status == "failed":
|
||||
replicate_error = response_data.get("error", "")
|
||||
raise ReplicateError(
|
||||
status_code=400, message=f"Error: {replicate_error}"
|
||||
)
|
||||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose(
|
||||
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Function to extract version ID from model string
|
||||
def model_to_version_id(model):
|
||||
if ":" in model:
|
||||
|
@ -222,6 +355,39 @@ def model_to_version_id(model):
|
|||
return model
|
||||
|
||||
|
||||
def process_response(
|
||||
model_response: ModelResponse,
|
||||
result: str,
|
||||
model: str,
|
||||
encoding: Any,
|
||||
prompt: str,
|
||||
) -> ModelResponse:
|
||||
if len(result) == 0: # edge case, where result from replicate is empty
|
||||
result = " "
|
||||
|
||||
## Building RESPONSE OBJECT
|
||||
if len(result) > 1:
|
||||
model_response["choices"][0]["message"]["content"] = result
|
||||
|
||||
# Calculate usage
|
||||
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
|
||||
completion_tokens = len(
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", ""),
|
||||
disallowed_special=(),
|
||||
)
|
||||
)
|
||||
model_response["model"] = "replicate/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
# Main function for prediction completion
|
||||
def completion(
|
||||
model: str,
|
||||
|
@ -229,14 +395,15 @@ def completion(
|
|||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
optional_params: dict,
|
||||
logging_obj,
|
||||
api_key,
|
||||
encoding,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
acompletion=None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
# Start a prediction and get the prediction URL
|
||||
version_id = model_to_version_id(model)
|
||||
## Load Config
|
||||
|
@ -274,6 +441,12 @@ def completion(
|
|||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
if prompt is None or not isinstance(prompt, str):
|
||||
raise ReplicateError(
|
||||
status_code=400,
|
||||
message="LiteLLM Error - prompt is not a string - {}".format(prompt),
|
||||
)
|
||||
|
||||
# If system prompt is supported, and a system prompt is provided, use it
|
||||
if system_prompt is not None:
|
||||
input_data = {
|
||||
|
@ -285,6 +458,20 @@ def completion(
|
|||
else:
|
||||
input_data = {"prompt": prompt, **optional_params}
|
||||
|
||||
if acompletion is not None and acompletion == True:
|
||||
return async_completion(
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
encoding=encoding,
|
||||
optional_params=optional_params,
|
||||
version_id=version_id,
|
||||
input_data=input_data,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
logging_obj=logging_obj,
|
||||
print_verbose=print_verbose,
|
||||
) # type: ignore
|
||||
## COMPLETION CALL
|
||||
## Replicate Compeltion calls have 2 steps
|
||||
## Step1: Start Prediction: gets a prediction url
|
||||
|
@ -293,6 +480,7 @@ def completion(
|
|||
model_response["created"] = int(
|
||||
time.time()
|
||||
) # for pricing this must remain right before calling api
|
||||
|
||||
prediction_url = start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
|
@ -306,9 +494,10 @@ def completion(
|
|||
# Handle the prediction response (streaming or non-streaming)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
print_verbose("streaming request")
|
||||
return handle_prediction_response_streaming(
|
||||
_response = handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
||||
else:
|
||||
result, logs = handle_prediction_response(
|
||||
prediction_url, api_key, print_verbose
|
||||
|
@ -328,29 +517,56 @@ def completion(
|
|||
|
||||
print_verbose(f"raw model_response: {result}")
|
||||
|
||||
if len(result) == 0: # edge case, where result from replicate is empty
|
||||
result = " "
|
||||
|
||||
## Building RESPONSE OBJECT
|
||||
if len(result) > 1:
|
||||
model_response["choices"][0]["message"]["content"] = result
|
||||
|
||||
# Calculate usage
|
||||
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
|
||||
completion_tokens = len(
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", ""),
|
||||
disallowed_special=(),
|
||||
)
|
||||
return process_response(
|
||||
model_response=model_response,
|
||||
result=result,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
prompt=prompt,
|
||||
)
|
||||
model_response["model"] = "replicate/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
|
||||
|
||||
async def async_completion(
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
prompt: str,
|
||||
encoding,
|
||||
optional_params: dict,
|
||||
version_id,
|
||||
input_data,
|
||||
api_key,
|
||||
api_base,
|
||||
logging_obj,
|
||||
print_verbose,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
http_handler = AsyncHTTPHandler(concurrent_limit=1)
|
||||
prediction_url = await async_start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
api_key,
|
||||
api_base,
|
||||
logging_obj=logging_obj,
|
||||
print_verbose=print_verbose,
|
||||
http_handler=http_handler,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
_response = async_handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
||||
|
||||
result, logs = await async_handle_prediction_response(
|
||||
prediction_url, api_key, print_verbose, http_handler=http_handler
|
||||
)
|
||||
|
||||
return process_response(
|
||||
model_response=model_response,
|
||||
result=result,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# # Example usage:
|
||||
|
|
|
@ -3,10 +3,15 @@ import json
|
|||
from enum import Enum
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, Union, List
|
||||
from typing import Callable, Optional, Union, List, Literal
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
|
||||
import litellm, uuid
|
||||
import httpx, inspect # type: ignore
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.llms.prompt_templates.factory import (
|
||||
convert_to_gemini_tool_call_result,
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
|
@ -283,6 +288,125 @@ def _load_image_from_url(image_url: str):
|
|||
return Image.from_bytes(data=image_bytes)
|
||||
|
||||
|
||||
def _convert_gemini_role(role: str) -> Literal["user", "model"]:
|
||||
if role == "user":
|
||||
return "user"
|
||||
else:
|
||||
return "model"
|
||||
|
||||
|
||||
def _process_gemini_image(image_url: str) -> PartType:
|
||||
try:
|
||||
if "gs://" in image_url:
|
||||
# Case 1: Images with Cloud Storage URIs
|
||||
# The supported MIME types for images include image/png and image/jpeg.
|
||||
part_mime = "image/png" if "png" in image_url else "image/jpeg"
|
||||
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
|
||||
return PartType(file_data=_file_data)
|
||||
elif "https:/" in image_url:
|
||||
# Case 2: Images with direct links
|
||||
image = _load_image_from_url(image_url)
|
||||
_blob = BlobType(data=image.data, mime_type=image._mime_type)
|
||||
return PartType(inline_data=_blob)
|
||||
elif ".mp4" in image_url and "gs://" in image_url:
|
||||
# Case 3: Videos with Cloud Storage URIs
|
||||
part_mime = "video/mp4"
|
||||
_file_data = FileDataType(mime_type=part_mime, file_uri=image_url)
|
||||
return PartType(file_data=_file_data)
|
||||
elif "base64" in image_url:
|
||||
# Case 4: Images with base64 encoding
|
||||
import base64, re
|
||||
|
||||
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
|
||||
image_metadata, img_without_base_64 = image_url.split(",")
|
||||
|
||||
# read mime_type from img_without_base_64=data:image/jpeg;base64
|
||||
# Extract MIME type using regular expression
|
||||
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
|
||||
|
||||
if mime_type_match:
|
||||
mime_type = mime_type_match.group(1)
|
||||
else:
|
||||
mime_type = "image/jpeg"
|
||||
decoded_img = base64.b64decode(img_without_base_64)
|
||||
_blob = BlobType(data=decoded_img, mime_type=mime_type)
|
||||
return PartType(inline_data=_blob)
|
||||
raise Exception("Invalid image received - {}".format(image_url))
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
|
||||
"""
|
||||
Converts given messages from OpenAI format to Gemini format
|
||||
|
||||
- Parts must be iterable
|
||||
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
|
||||
- Please ensure that function response turn comes immediately after a function call turn
|
||||
"""
|
||||
user_message_types = {"user", "system"}
|
||||
contents: List[ContentType] = []
|
||||
|
||||
msg_i = 0
|
||||
while msg_i < len(messages):
|
||||
user_content: List[PartType] = []
|
||||
init_msg_i = msg_i
|
||||
## MERGE CONSECUTIVE USER CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
|
||||
if isinstance(messages[msg_i]["content"], list):
|
||||
_parts: List[PartType] = []
|
||||
for element in messages[msg_i]["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
_part = PartType(text=element["text"])
|
||||
_parts.append(_part)
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
_part = _process_gemini_image(image_url=image_url)
|
||||
_parts.append(_part) # type: ignore
|
||||
user_content.extend(_parts)
|
||||
else:
|
||||
_part = PartType(text=messages[msg_i]["content"])
|
||||
user_content.append(_part)
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if user_content:
|
||||
contents.append(ContentType(role="user", parts=user_content))
|
||||
assistant_content = []
|
||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||
assistant_text = (
|
||||
messages[msg_i].get("content") or ""
|
||||
) # either string or none
|
||||
if assistant_text:
|
||||
assistant_content.append(PartType(text=assistant_text))
|
||||
if messages[msg_i].get(
|
||||
"tool_calls", []
|
||||
): # support assistant tool invoke convertion
|
||||
assistant_content.extend(
|
||||
convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
|
||||
)
|
||||
msg_i += 1
|
||||
|
||||
if assistant_content:
|
||||
contents.append(ContentType(role="model", parts=assistant_content))
|
||||
|
||||
## APPEND TOOL CALL MESSAGES ##
|
||||
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
|
||||
_part = convert_to_gemini_tool_call_result(messages[msg_i])
|
||||
contents.append(ContentType(parts=[_part])) # type: ignore
|
||||
msg_i += 1
|
||||
if msg_i == init_msg_i: # prevent infinite loops
|
||||
raise Exception(
|
||||
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
|
||||
messages[msg_i]
|
||||
)
|
||||
)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def _gemini_vision_convert_messages(messages: list):
|
||||
"""
|
||||
Converts given messages for GPT-4 Vision to Gemini format.
|
||||
|
@ -396,10 +520,10 @@ def completion(
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
acompletion: bool = False,
|
||||
|
@ -556,6 +680,7 @@ def completion(
|
|||
"model_response": model_response,
|
||||
"encoding": encoding,
|
||||
"messages": messages,
|
||||
"request_str": request_str,
|
||||
"print_verbose": print_verbose,
|
||||
"client_options": client_options,
|
||||
"instances": instances,
|
||||
|
@ -574,11 +699,9 @@ def completion(
|
|||
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
tools = optional_params.pop("tools", None)
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
content = _gemini_convert_messages_with_history(messages=messages)
|
||||
stream = optional_params.pop("stream", False)
|
||||
if stream == True:
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
@ -589,7 +712,7 @@ def completion(
|
|||
},
|
||||
)
|
||||
|
||||
model_response = llm_model.generate_content(
|
||||
_model_response = llm_model.generate_content(
|
||||
contents=content,
|
||||
generation_config=optional_params,
|
||||
safety_settings=safety_settings,
|
||||
|
@ -597,7 +720,7 @@ def completion(
|
|||
tools=tools,
|
||||
)
|
||||
|
||||
return model_response
|
||||
return _model_response
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content})\n"
|
||||
## LOGGING
|
||||
|
@ -850,12 +973,12 @@ async def async_completion(
|
|||
mode: str,
|
||||
prompt: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
model_response: ModelResponse,
|
||||
logging_obj=None,
|
||||
request_str=None,
|
||||
request_str: str,
|
||||
print_verbose: Callable,
|
||||
logging_obj,
|
||||
encoding=None,
|
||||
messages=None,
|
||||
print_verbose=None,
|
||||
client_options=None,
|
||||
instances=None,
|
||||
vertex_project=None,
|
||||
|
@ -875,8 +998,7 @@ async def async_completion(
|
|||
tools = optional_params.pop("tools", None)
|
||||
stream = optional_params.pop("stream", False)
|
||||
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
content = _gemini_convert_messages_with_history(messages=messages)
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content})\n"
|
||||
## LOGGING
|
||||
|
@ -1076,11 +1198,11 @@ async def async_streaming(
|
|||
prompt: str,
|
||||
model: str,
|
||||
model_response: ModelResponse,
|
||||
logging_obj=None,
|
||||
request_str=None,
|
||||
messages: list,
|
||||
print_verbose: Callable,
|
||||
logging_obj,
|
||||
request_str: str,
|
||||
encoding=None,
|
||||
messages=None,
|
||||
print_verbose=None,
|
||||
client_options=None,
|
||||
instances=None,
|
||||
vertex_project=None,
|
||||
|
@ -1097,8 +1219,8 @@ async def async_streaming(
|
|||
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
content = _gemini_convert_messages_with_history(messages=messages)
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
|
224
litellm/llms/vertex_httpx.py
Normal file
224
litellm/llms/vertex_httpx.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, Union, List, Any, Tuple
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
|
||||
import litellm, uuid
|
||||
import httpx, inspect # type: ignore
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from .base import BaseLLM
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url=" https://cloud.google.com/vertex-ai/"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class VertexLLM(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.access_token: Optional[str] = None
|
||||
self.refresh_token: Optional[str] = None
|
||||
self._credentials: Optional[Any] = None
|
||||
self.project_id: Optional[str] = None
|
||||
self.async_handler: Optional[AsyncHTTPHandler] = None
|
||||
|
||||
def load_auth(self) -> Tuple[Any, str]:
|
||||
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||
from google.auth.credentials import Credentials # type: ignore[import-untyped]
|
||||
import google.auth as google_auth
|
||||
|
||||
credentials, project_id = google_auth.default(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
|
||||
credentials.refresh(Request())
|
||||
|
||||
if not project_id:
|
||||
raise ValueError("Could not resolve project_id")
|
||||
|
||||
if not isinstance(project_id, str):
|
||||
raise TypeError(
|
||||
f"Expected project_id to be a str but got {type(project_id)}"
|
||||
)
|
||||
|
||||
return credentials, project_id
|
||||
|
||||
def refresh_auth(self, credentials: Any) -> None:
|
||||
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||
|
||||
credentials.refresh(Request())
|
||||
|
||||
def _prepare_request(self, request: httpx.Request) -> None:
|
||||
access_token = self._ensure_access_token()
|
||||
|
||||
if request.headers.get("Authorization"):
|
||||
# already authenticated, nothing for us to do
|
||||
return
|
||||
|
||||
request.headers["Authorization"] = f"Bearer {access_token}"
|
||||
|
||||
def _ensure_access_token(self) -> str:
|
||||
if self.access_token is not None:
|
||||
return self.access_token
|
||||
|
||||
if not self._credentials:
|
||||
self._credentials, project_id = self.load_auth()
|
||||
if not self.project_id:
|
||||
self.project_id = project_id
|
||||
else:
|
||||
self.refresh_auth(self._credentials)
|
||||
|
||||
if not self._credentials.token:
|
||||
raise RuntimeError("Could not resolve API token from the environment")
|
||||
|
||||
assert isinstance(self._credentials.token, str)
|
||||
return self._credentials.token
|
||||
|
||||
def image_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
vertex_project: str,
|
||||
vertex_location: str,
|
||||
model: Optional[
|
||||
str
|
||||
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
timeout: Optional[int] = None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
aimg_generation=False,
|
||||
):
|
||||
if aimg_generation == True:
|
||||
response = self.aimage_generation(
|
||||
prompt=prompt,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
model=model,
|
||||
client=client,
|
||||
optional_params=optional_params,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
model_response=model_response,
|
||||
)
|
||||
return response
|
||||
|
||||
async def aimage_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
vertex_project: str,
|
||||
vertex_location: str,
|
||||
model_response: litellm.ImageResponse,
|
||||
model: Optional[
|
||||
str
|
||||
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
timeout: Optional[int] = None,
|
||||
logging_obj=None,
|
||||
):
|
||||
response = None
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
_httpx_timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = _httpx_timeout
|
||||
else:
|
||||
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
|
||||
self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.async_handler = client # type: ignore
|
||||
|
||||
# make POST request to
|
||||
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||
|
||||
"""
|
||||
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
-d {
|
||||
"instances": [
|
||||
{
|
||||
"prompt": "a cat"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"sampleCount": 1
|
||||
}
|
||||
} \
|
||||
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
||||
"""
|
||||
auth_header = self._ensure_access_token()
|
||||
optional_params = optional_params or {
|
||||
"sampleCount": 1
|
||||
} # default optional params
|
||||
|
||||
request_data = {
|
||||
"instances": [{"prompt": prompt}],
|
||||
"parameters": optional_params,
|
||||
}
|
||||
|
||||
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
response = await self.async_handler.post(
|
||||
url=url,
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
},
|
||||
data=json.dumps(request_data),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
"""
|
||||
Vertex AI Image generation response example:
|
||||
{
|
||||
"predictions": [
|
||||
{
|
||||
"bytesBase64Encoded": "BASE64_IMG_BYTES",
|
||||
"mimeType": "image/png"
|
||||
},
|
||||
{
|
||||
"mimeType": "image/png",
|
||||
"bytesBase64Encoded": "BASE64_IMG_BYTES"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
_json_response = response.json()
|
||||
_predictions = _json_response["predictions"]
|
||||
|
||||
_response_data: List[litellm.ImageObject] = []
|
||||
for _prediction in _predictions:
|
||||
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
|
||||
image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
|
||||
_response_data.append(image_object)
|
||||
|
||||
model_response.data = _response_data
|
||||
|
||||
return model_response
|
108
litellm/main.py
108
litellm/main.py
|
@ -79,6 +79,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
|
|||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.bedrock_httpx import BedrockLLM
|
||||
from .llms.vertex_httpx import VertexLLM
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
|
@ -118,6 +119,7 @@ huggingface = Huggingface()
|
|||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
bedrock_chat_completion = BedrockLLM()
|
||||
vertex_chat_completion = VertexLLM()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -320,12 +322,13 @@ async def acompletion(
|
|||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "ollama_chat"
|
||||
or custom_llm_provider == "replicate"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
or custom_llm_provider == "gemini"
|
||||
or custom_llm_provider == "sagemaker"
|
||||
or custom_llm_provider == "anthropic"
|
||||
or custom_llm_provider == "predibase"
|
||||
or (custom_llm_provider == "bedrock" and "cohere" in model)
|
||||
or custom_llm_provider == "bedrock"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -367,6 +370,8 @@ async def acompletion(
|
|||
async def _async_streaming(response, model, custom_llm_provider, args):
|
||||
try:
|
||||
print_verbose(f"received response in _async_streaming: {response}")
|
||||
if asyncio.iscoroutine(response):
|
||||
response = await response
|
||||
async for line in response:
|
||||
print_verbose(f"line in async streaming: {line}")
|
||||
yield line
|
||||
|
@ -552,7 +557,7 @@ def completion(
|
|||
model_info = kwargs.get("model_info", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
fallbacks = kwargs.get("fallbacks", None)
|
||||
headers = kwargs.get("headers", None)
|
||||
headers = kwargs.get("headers", None) or extra_headers
|
||||
num_retries = kwargs.get("num_retries", None) ## deprecated
|
||||
max_retries = kwargs.get("max_retries", None)
|
||||
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
|
||||
|
@ -674,20 +679,6 @@ def completion(
|
|||
k: v for k, v in kwargs.items() if k not in default_params
|
||||
} # model-specific params - pass them straight to the model/provider
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
|
||||
try:
|
||||
if base_url is not None:
|
||||
api_base = base_url
|
||||
|
@ -727,6 +718,16 @@ def completion(
|
|||
"aws_region_name", None
|
||||
) # support region-based pricing for bedrock
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(
|
||||
custom_llm_provider
|
||||
):
|
||||
timeout = timeout.read or 600 # default 10 min timeout
|
||||
elif not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
|
||||
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
|
||||
if input_cost_per_token is not None and output_cost_per_token is not None:
|
||||
litellm.register_model(
|
||||
|
@ -1192,7 +1193,7 @@ def completion(
|
|||
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
|
||||
model_response = replicate.completion(
|
||||
model_response = replicate.completion( # type: ignore
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
|
@ -1205,12 +1206,10 @@ def completion(
|
|||
api_key=replicate_key,
|
||||
logging_obj=logging,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
if optional_params.get("stream", False) == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
@ -1984,23 +1983,9 @@ def completion(
|
|||
# boto3 reads keys from .env
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
|
||||
if "cohere" in model:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
"aws_bedrock_client" in optional_params
|
||||
): # use old bedrock flow for aws_bedrock_client users.
|
||||
response = bedrock.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -2036,7 +2021,22 @@ def completion(
|
|||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging,
|
||||
)
|
||||
|
||||
else:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
|
@ -3856,6 +3856,36 @@ def image_generation(
|
|||
model_response=model_response,
|
||||
aimg_generation=aimg_generation,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
vertex_ai_project = (
|
||||
optional_params.pop("vertex_project", None)
|
||||
or optional_params.pop("vertex_ai_project", None)
|
||||
or litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.pop("vertex_location", None)
|
||||
or optional_params.pop("vertex_ai_location", None)
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = (
|
||||
optional_params.pop("vertex_credentials", None)
|
||||
or optional_params.pop("vertex_ai_credentials", None)
|
||||
or get_secret("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
model_response = vertex_chat_completion.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
timeout=timeout,
|
||||
logging_obj=litellm_logging_obj,
|
||||
optional_params=optional_params,
|
||||
model_response=model_response,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
aimg_generation=aimg_generation,
|
||||
)
|
||||
|
||||
return model_response
|
||||
except Exception as e:
|
||||
## Map to OpenAI Exception
|
||||
|
|
|
@ -234,6 +234,24 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ft:davinci-002": {
|
||||
"max_tokens": 16384,
|
||||
"max_input_tokens": 16384,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000002,
|
||||
"output_cost_per_token": 0.000002,
|
||||
"litellm_provider": "text-completion-openai",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ft:babbage-002": {
|
||||
"max_tokens": 16384,
|
||||
"max_input_tokens": 16384,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000004,
|
||||
"output_cost_per_token": 0.0000004,
|
||||
"litellm_provider": "text-completion-openai",
|
||||
"mode": "completion"
|
||||
},
|
||||
"text-embedding-3-large": {
|
||||
"max_tokens": 8191,
|
||||
"max_input_tokens": 8191,
|
||||
|
@ -1385,6 +1403,24 @@
|
|||
"mode": "completion",
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-latest": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_videos_per_prompt": 10,
|
||||
"max_video_length": 1,
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-pro": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 32760,
|
||||
|
@ -1744,6 +1780,30 @@
|
|||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/openai/gpt-4o": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/openai/gpt-4o-2024-05-13": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/openai/gpt-4-vision-preview": {
|
||||
"max_tokens": 130000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
|
@ -2943,6 +3003,24 @@
|
|||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/llama3": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/llama3:70b": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mistral": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
|
@ -2952,6 +3030,42 @@
|
|||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/mistral-7B-Instruct-v0.1": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mistral-7B-Instruct-v0.2": {
|
||||
"max_tokens": 32768,
|
||||
"max_input_tokens": 32768,
|
||||
"max_output_tokens": 32768,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mixtral-8x7B-Instruct-v0.1": {
|
||||
"max_tokens": 32768,
|
||||
"max_input_tokens": 32768,
|
||||
"max_output_tokens": 32768,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mixtral-8x22B-Instruct-v0.1": {
|
||||
"max_tokens": 65536,
|
||||
"max_input_tokens": 65536,
|
||||
"max_output_tokens": 65536,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/codellama": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1 +1 @@
|
|||
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[7926,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-6a39771cacf75ea6.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"obp5wqVSVDMiDTC414cR8\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
|
||||
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[4858,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-f20fdea77aed85ba.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"l-0LDfSCdaUCAbcLIx_QC\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
|
|
@ -1,7 +1,7 @@
|
|||
2:I[77831,[],""]
|
||||
3:I[7926,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-6a39771cacf75ea6.js"],""]
|
||||
3:I[4858,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-f20fdea77aed85ba.js"],""]
|
||||
4:I[5613,[],""]
|
||||
5:I[31778,[],""]
|
||||
0:["obp5wqVSVDMiDTC414cR8",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
|
||||
0:["l-0LDfSCdaUCAbcLIx_QC",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
|
||||
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
|
||||
1:null
|
||||
|
|
20
litellm/proxy/_logging.py
Normal file
20
litellm/proxy/_logging.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import json
|
||||
import logging
|
||||
from logging import Formatter
|
||||
|
||||
|
||||
class JsonFormatter(Formatter):
|
||||
def __init__(self):
|
||||
super(JsonFormatter, self).__init__()
|
||||
|
||||
def format(self, record):
|
||||
json_record = {}
|
||||
json_record["message"] = record.getMessage()
|
||||
return json.dumps(json_record)
|
||||
|
||||
|
||||
logger = logging.root
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(JsonFormatter())
|
||||
logger.handlers = [handler]
|
||||
logger.setLevel(logging.DEBUG)
|
|
@ -1,45 +1,20 @@
|
|||
model_list:
|
||||
- litellm_params:
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2023-07-01-preview
|
||||
model: azure/azure-embedding-model
|
||||
model_info:
|
||||
base_model: text-embedding-ada-002
|
||||
mode: embedding
|
||||
model_name: text-embedding-ada-002
|
||||
- model_name: gpt-3.5-turbo-012
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_base: http://0.0.0.0:8080
|
||||
api_key: ""
|
||||
- model_name: gpt-3.5-turbo-0125-preview
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
input_cost_per_token: 0.0
|
||||
output_cost_per_token: 0.0
|
||||
- model_name: bert-classifier
|
||||
litellm_params:
|
||||
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
|
||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||
- model_name: gpt-3.5-turbo-fake-model
|
||||
litellm_params:
|
||||
model: openai/my-fake-model
|
||||
api_base: http://0.0.0.0:8080
|
||||
api_key: ""
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-35-turbo
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key: os.environ/AZURE_EUROPE_API_KEY
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
|
||||
|
||||
router_settings:
|
||||
redis_host: redis
|
||||
# redis_password: <your redis password>
|
||||
redis_port: 6379
|
||||
enable_pre_call_checks: true
|
||||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
fallbacks: [{"gpt-3.5-turbo-012": ["gpt-3.5-turbo-0125-preview"]}]
|
||||
# service_callback: ["prometheus_system"]
|
||||
# success_callback: ["prometheus"]
|
||||
# failure_callback: ["prometheus"]
|
||||
|
||||
general_settings:
|
||||
enable_jwt_auth: True
|
||||
disable_reset_budget: True
|
||||
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||
routing_strategy: simple-shuffle # Literal["simple-shuffle", "least-busy", "usage-based-routing","latency-based-routing"], default="simple-shuffle"
|
||||
alerting: ["slack"]
|
||||
|
|
|
@ -1,37 +1,11 @@
|
|||
from pydantic import ConfigDict, BaseModel, Field, root_validator, Json, VERSION
|
||||
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
|
||||
from dataclasses import fields
|
||||
import enum
|
||||
from typing import Optional, List, Union, Dict, Literal, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import json
|
||||
import uuid, json, sys, os
|
||||
from litellm.types.router import UpdateRouterConfig
|
||||
|
||||
try:
|
||||
from pydantic import model_validator # type: ignore
|
||||
except ImportError:
|
||||
from pydantic import root_validator # pydantic v1
|
||||
|
||||
def model_validator(mode): # type: ignore
|
||||
pre = mode == "before"
|
||||
return root_validator(pre=pre)
|
||||
|
||||
|
||||
# Function to get Pydantic version
|
||||
def is_pydantic_v2() -> int:
|
||||
return int(VERSION.split(".")[0])
|
||||
|
||||
|
||||
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
|
||||
# Version-specific configuration
|
||||
if is_pydantic_v2() >= 2:
|
||||
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
|
||||
else:
|
||||
from pydantic import Extra
|
||||
|
||||
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
def hash_token(token: str):
|
||||
import hashlib
|
||||
|
@ -61,7 +35,8 @@ class LiteLLMBase(BaseModel):
|
|||
# if using pydantic v1
|
||||
return self.__fields_set__
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
|
||||
|
@ -77,8 +52,18 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
|
|||
|
||||
|
||||
class LiteLLMRoutes(enum.Enum):
|
||||
openai_route_names: List = [
|
||||
"chat_completion",
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"audio_transcriptions",
|
||||
"moderations",
|
||||
"model_list", # OpenAI /v1/models route
|
||||
]
|
||||
openai_routes: List = [
|
||||
# chat completions
|
||||
"/engines/{model}/chat/completions",
|
||||
"/openai/deployments/{model}/chat/completions",
|
||||
"/chat/completions",
|
||||
"/v1/chat/completions",
|
||||
|
@ -102,11 +87,8 @@ class LiteLLMRoutes(enum.Enum):
|
|||
# models
|
||||
"/models",
|
||||
"/v1/models",
|
||||
]
|
||||
|
||||
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
|
||||
master_key_only_routes: List = [
|
||||
"/global/spend/reset",
|
||||
# token counter
|
||||
"/utils/token_counter",
|
||||
]
|
||||
|
||||
info_routes: List = [
|
||||
|
@ -119,6 +101,11 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/v2/key/info",
|
||||
]
|
||||
|
||||
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
|
||||
master_key_only_routes: List = [
|
||||
"/global/spend/reset",
|
||||
]
|
||||
|
||||
sso_only_routes: List = [
|
||||
"/key/generate",
|
||||
"/key/update",
|
||||
|
@ -227,13 +214,19 @@ class LiteLLM_JWTAuth(LiteLLMBase):
|
|||
"global_spend_tracking_routes",
|
||||
"info_routes",
|
||||
]
|
||||
team_jwt_scope: str = "litellm_team"
|
||||
team_id_jwt_field: str = "client_id"
|
||||
team_id_jwt_field: Optional[str] = None
|
||||
team_allowed_routes: List[
|
||||
Literal["openai_routes", "info_routes", "management_routes"]
|
||||
] = ["openai_routes", "info_routes"]
|
||||
team_id_default: Optional[str] = Field(
|
||||
default=None,
|
||||
description="If no team_id given, default permissions/spend-tracking to this team.s",
|
||||
)
|
||||
org_id_jwt_field: Optional[str] = None
|
||||
user_id_jwt_field: Optional[str] = None
|
||||
user_id_upsert: bool = Field(
|
||||
default=False, description="If user doesn't exist, upsert them into the db."
|
||||
)
|
||||
end_user_id_jwt_field: Optional[str] = None
|
||||
public_key_ttl: float = 600
|
||||
|
||||
|
@ -258,8 +251,12 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
|
|||
llm_api_name: Optional[str] = None
|
||||
llm_api_system_prompt: Optional[str] = None
|
||||
llm_api_fail_call_string: Optional[str] = None
|
||||
reject_as_response: Optional[bool] = Field(
|
||||
default=False,
|
||||
description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def check_llm_api_params(cls, values):
|
||||
llm_api_check = values.get("llm_api_check")
|
||||
if llm_api_check is True:
|
||||
|
@ -317,7 +314,8 @@ class ProxyChatCompletionRequest(LiteLLMBase):
|
|||
deployment_id: Optional[str] = None
|
||||
request_timeout: Optional[int] = None
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
|
||||
|
||||
|
||||
class ModelInfoDelete(LiteLLMBase):
|
||||
|
@ -344,9 +342,11 @@ class ModelInfo(LiteLLMBase):
|
|||
]
|
||||
]
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
extra = Extra.allow # Allow extra fields
|
||||
protected_namespaces = ()
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("id") is None:
|
||||
values.update({"id": str(uuid.uuid4())})
|
||||
|
@ -372,9 +372,10 @@ class ModelParams(LiteLLMBase):
|
|||
litellm_params: dict
|
||||
model_info: ModelInfo
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("model_info") is None:
|
||||
values.update({"model_info": ModelInfo()})
|
||||
|
@ -410,7 +411,8 @@ class GenerateKeyRequest(GenerateRequestBase):
|
|||
{}
|
||||
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class GenerateKeyResponse(GenerateKeyRequest):
|
||||
|
@ -420,7 +422,7 @@ class GenerateKeyResponse(GenerateKeyRequest):
|
|||
user_id: Optional[str] = None
|
||||
token_id: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("token") is not None:
|
||||
values.update({"key": values.get("token")})
|
||||
|
@ -460,7 +462,8 @@ class LiteLLM_ModelTable(LiteLLMBase):
|
|||
created_by: str
|
||||
updated_by: str
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class NewUserRequest(GenerateKeyRequest):
|
||||
|
@ -488,7 +491,7 @@ class UpdateUserRequest(GenerateRequestBase):
|
|||
user_role: Optional[str] = None
|
||||
max_budget: Optional[float] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def check_user_info(cls, values):
|
||||
if values.get("user_id") is None and values.get("user_email") is None:
|
||||
raise ValueError("Either user id or user email must be provided")
|
||||
|
@ -508,7 +511,7 @@ class NewEndUserRequest(LiteLLMBase):
|
|||
None # if no equivalent model in allowed region - default all requests to this model
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def check_user_info(cls, values):
|
||||
if values.get("max_budget") is not None and values.get("budget_id") is not None:
|
||||
raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
|
||||
|
@ -521,7 +524,7 @@ class Member(LiteLLMBase):
|
|||
user_id: Optional[str] = None
|
||||
user_email: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def check_user_info(cls, values):
|
||||
if values.get("user_id") is None and values.get("user_email") is None:
|
||||
raise ValueError("Either user id or user email must be provided")
|
||||
|
@ -546,7 +549,8 @@ class TeamBase(LiteLLMBase):
|
|||
class NewTeamRequest(TeamBase):
|
||||
model_aliases: Optional[dict] = None
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class GlobalEndUsersSpend(LiteLLMBase):
|
||||
|
@ -565,7 +569,7 @@ class TeamMemberDeleteRequest(LiteLLMBase):
|
|||
user_id: Optional[str] = None
|
||||
user_email: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def check_user_info(cls, values):
|
||||
if values.get("user_id") is None and values.get("user_email") is None:
|
||||
raise ValueError("Either user id or user email must be provided")
|
||||
|
@ -599,9 +603,10 @@ class LiteLLM_TeamTable(TeamBase):
|
|||
budget_reset_at: Optional[datetime] = None
|
||||
model_id: Optional[int] = None
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
dict_fields = [
|
||||
"metadata",
|
||||
|
@ -637,7 +642,8 @@ class LiteLLM_BudgetTable(LiteLLMBase):
|
|||
model_max_budget: Optional[dict] = None
|
||||
budget_duration: Optional[str] = None
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class NewOrganizationRequest(LiteLLM_BudgetTable):
|
||||
|
@ -687,7 +693,8 @@ class KeyManagementSettings(LiteLLMBase):
|
|||
class TeamDefaultSettings(LiteLLMBase):
|
||||
team_id: str
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
|
||||
|
||||
|
||||
class DynamoDBArgs(LiteLLMBase):
|
||||
|
@ -711,6 +718,25 @@ class DynamoDBArgs(LiteLLMBase):
|
|||
assume_role_aws_session_name: Optional[str] = None
|
||||
|
||||
|
||||
class ConfigFieldUpdate(LiteLLMBase):
|
||||
field_name: str
|
||||
field_value: Any
|
||||
config_type: Literal["general_settings"]
|
||||
|
||||
|
||||
class ConfigFieldDelete(LiteLLMBase):
|
||||
config_type: Literal["general_settings"]
|
||||
field_name: str
|
||||
|
||||
|
||||
class ConfigList(LiteLLMBase):
|
||||
field_name: str
|
||||
field_type: str
|
||||
field_description: str
|
||||
field_value: Any
|
||||
stored_in_db: Optional[bool]
|
||||
|
||||
|
||||
class ConfigGeneralSettings(LiteLLMBase):
|
||||
"""
|
||||
Documents all the fields supported by `general_settings` in config.yaml
|
||||
|
@ -758,7 +784,11 @@ class ConfigGeneralSettings(LiteLLMBase):
|
|||
description="override user_api_key_auth with your own auth script - https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth",
|
||||
)
|
||||
max_parallel_requests: Optional[int] = Field(
|
||||
None, description="maximum parallel requests for each api key"
|
||||
None,
|
||||
description="maximum parallel requests for each api key",
|
||||
)
|
||||
global_max_parallel_requests: Optional[int] = Field(
|
||||
None, description="global max parallel requests to allow for a proxy instance."
|
||||
)
|
||||
infer_model_from_keys: Optional[bool] = Field(
|
||||
None,
|
||||
|
@ -828,7 +858,8 @@ class ConfigYAML(LiteLLMBase):
|
|||
description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5",
|
||||
)
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLM_VerificationToken(LiteLLMBase):
|
||||
|
@ -862,7 +893,8 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
|||
user_id_rate_limits: Optional[dict] = None
|
||||
team_id_rate_limits: Optional[dict] = None
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||
|
@ -892,7 +924,7 @@ class UserAPIKeyAuth(
|
|||
user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
|
||||
allowed_model_region: Optional[Literal["eu"]] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def check_api_key(cls, values):
|
||||
if values.get("api_key") is not None:
|
||||
values.update({"token": hash_token(values.get("api_key"))})
|
||||
|
@ -919,7 +951,7 @@ class LiteLLM_UserTable(LiteLLMBase):
|
|||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
|
@ -927,7 +959,8 @@ class LiteLLM_UserTable(LiteLLMBase):
|
|||
values.update({"models": []})
|
||||
return values
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLM_EndUserTable(LiteLLMBase):
|
||||
|
@ -939,13 +972,14 @@ class LiteLLM_EndUserTable(LiteLLMBase):
|
|||
default_model: Optional[str] = None
|
||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
return values
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLM_SpendLogs(LiteLLMBase):
|
||||
|
@ -983,3 +1017,30 @@ class LiteLLM_ErrorLogs(LiteLLMBase):
|
|||
|
||||
class LiteLLM_SpendLogs_ResponseObject(LiteLLMBase):
|
||||
response: Optional[List[Union[LiteLLM_SpendLogs, Any]]] = None
|
||||
|
||||
|
||||
class TokenCountRequest(LiteLLMBase):
|
||||
model: str
|
||||
prompt: Optional[str] = None
|
||||
messages: Optional[List[dict]] = None
|
||||
|
||||
|
||||
class TokenCountResponse(LiteLLMBase):
|
||||
total_tokens: int
|
||||
request_model: str
|
||||
model_used: str
|
||||
tokenizer_type: str
|
||||
|
||||
|
||||
class CallInfo(LiteLLMBase):
|
||||
"""Used for slack budget alerting"""
|
||||
|
||||
spend: float
|
||||
max_budget: float
|
||||
token: str = Field(description="Hashed value of that key")
|
||||
user_id: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
user_email: Optional[str] = None
|
||||
key_alias: Optional[str] = None
|
||||
projected_exceeded_date: Optional[str] = None
|
||||
projected_spend: Optional[float] = None
|
||||
|
|
|
@ -26,7 +26,7 @@ all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes
|
|||
|
||||
def common_checks(
|
||||
request_body: dict,
|
||||
team_object: LiteLLM_TeamTable,
|
||||
team_object: Optional[LiteLLM_TeamTable],
|
||||
user_object: Optional[LiteLLM_UserTable],
|
||||
end_user_object: Optional[LiteLLM_EndUserTable],
|
||||
global_proxy_spend: Optional[float],
|
||||
|
@ -45,13 +45,14 @@ def common_checks(
|
|||
6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
|
||||
"""
|
||||
_model = request_body.get("model", None)
|
||||
if team_object.blocked == True:
|
||||
if team_object is not None and team_object.blocked == True:
|
||||
raise Exception(
|
||||
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
||||
)
|
||||
# 2. If user can call model
|
||||
if (
|
||||
_model is not None
|
||||
and team_object is not None
|
||||
and len(team_object.models) > 0
|
||||
and _model not in team_object.models
|
||||
):
|
||||
|
@ -65,7 +66,8 @@ def common_checks(
|
|||
)
|
||||
# 3. If team is in budget
|
||||
if (
|
||||
team_object.max_budget is not None
|
||||
team_object is not None
|
||||
and team_object.max_budget is not None
|
||||
and team_object.spend is not None
|
||||
and team_object.spend > team_object.max_budget
|
||||
):
|
||||
|
@ -239,6 +241,7 @@ async def get_user_object(
|
|||
user_id: str,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
user_api_key_cache: DualCache,
|
||||
user_id_upsert: bool,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
"""
|
||||
- Check if user id in proxy User Table
|
||||
|
@ -252,7 +255,7 @@ async def get_user_object(
|
|||
return None
|
||||
|
||||
# check if in cache
|
||||
cached_user_obj = user_api_key_cache.async_get_cache(key=user_id)
|
||||
cached_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
|
||||
if cached_user_obj is not None:
|
||||
if isinstance(cached_user_obj, dict):
|
||||
return LiteLLM_UserTable(**cached_user_obj)
|
||||
|
@ -260,16 +263,27 @@ async def get_user_object(
|
|||
return cached_user_obj
|
||||
# else, check db
|
||||
try:
|
||||
|
||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise Exception
|
||||
if user_id_upsert:
|
||||
response = await prisma_client.db.litellm_usertable.create(
|
||||
data={"user_id": user_id}
|
||||
)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
return LiteLLM_UserTable(**response.dict())
|
||||
except Exception as e: # if end-user not in db
|
||||
raise Exception(
|
||||
_response = LiteLLM_UserTable(**dict(response))
|
||||
|
||||
# save the user object to cache
|
||||
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
|
||||
|
||||
return _response
|
||||
except Exception as e: # if user not in db
|
||||
raise ValueError(
|
||||
f"User doesn't exist in db. 'user_id'={user_id}. Create user via `/user/new` call."
|
||||
)
|
||||
|
||||
|
@ -290,7 +304,7 @@ async def get_team_object(
|
|||
)
|
||||
|
||||
# check if in cache
|
||||
cached_team_obj = user_api_key_cache.async_get_cache(key=team_id)
|
||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
|
||||
if cached_team_obj is not None:
|
||||
if isinstance(cached_team_obj, dict):
|
||||
return LiteLLM_TeamTable(**cached_team_obj)
|
||||
|
@ -305,7 +319,11 @@ async def get_team_object(
|
|||
if response is None:
|
||||
raise Exception
|
||||
|
||||
return LiteLLM_TeamTable(**response.dict())
|
||||
_response = LiteLLM_TeamTable(**response.dict())
|
||||
# save the team object to cache
|
||||
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
|
||||
|
||||
return _response
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
|
||||
|
|
|
@ -55,12 +55,9 @@ class JWTHandler:
|
|||
return True
|
||||
return False
|
||||
|
||||
def is_team(self, scopes: list) -> bool:
|
||||
if self.litellm_jwtauth.team_jwt_scope in scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
|
||||
def get_end_user_id(
|
||||
self, token: dict, default_value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.end_user_id_jwt_field is not None:
|
||||
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
|
||||
|
@ -70,13 +67,36 @@ class JWTHandler:
|
|||
user_id = default_value
|
||||
return user_id
|
||||
|
||||
def is_required_team_id(self) -> bool:
|
||||
"""
|
||||
Returns:
|
||||
- True: if 'team_id_jwt_field' is set
|
||||
- False: if not
|
||||
"""
|
||||
if self.litellm_jwtauth.team_id_jwt_field is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
||||
if self.litellm_jwtauth.team_id_jwt_field is not None:
|
||||
team_id = token[self.litellm_jwtauth.team_id_jwt_field]
|
||||
elif self.litellm_jwtauth.team_id_default is not None:
|
||||
team_id = self.litellm_jwtauth.team_id_default
|
||||
else:
|
||||
team_id = None
|
||||
except KeyError:
|
||||
team_id = default_value
|
||||
return team_id
|
||||
|
||||
def is_upsert_user_id(self) -> bool:
|
||||
"""
|
||||
Returns:
|
||||
- True: if 'user_id_upsert' is set
|
||||
- False: if not
|
||||
"""
|
||||
return self.litellm_jwtauth.user_id_upsert
|
||||
|
||||
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
if self.litellm_jwtauth.user_id_jwt_field is not None:
|
||||
|
@ -207,12 +227,14 @@ class JWTHandler:
|
|||
raise Exception(f"Validation fails: {str(e)}")
|
||||
elif public_key is not None and isinstance(public_key, str):
|
||||
try:
|
||||
cert = x509.load_pem_x509_certificate(public_key.encode(), default_backend())
|
||||
cert = x509.load_pem_x509_certificate(
|
||||
public_key.encode(), default_backend()
|
||||
)
|
||||
|
||||
# Extract public key
|
||||
key = cert.public_key().public_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
# decode the token using the public key
|
||||
|
@ -221,7 +243,7 @@ class JWTHandler:
|
|||
key,
|
||||
algorithms=algorithms,
|
||||
audience=audience,
|
||||
options=decode_options
|
||||
options=decode_options,
|
||||
)
|
||||
return payload
|
||||
|
||||
|
|
42
litellm/proxy/auth/litellm_license.py
Normal file
42
litellm/proxy/auth/litellm_license.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
# What is this?
|
||||
## If litellm license in env, checks if it's valid
|
||||
import os
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
class LicenseCheck:
|
||||
"""
|
||||
- Check if license in env
|
||||
- Returns if license is valid
|
||||
"""
|
||||
|
||||
base_url = "https://license.litellm.ai"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.license_str = os.getenv("LITELLM_LICENSE", None)
|
||||
self.http_handler = HTTPHandler()
|
||||
|
||||
def _verify(self, license_str: str) -> bool:
|
||||
url = "{}/verify_license/{}".format(self.base_url, license_str)
|
||||
|
||||
try: # don't impact user, if call fails
|
||||
response = self.http_handler.get(url=url)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
premium = response_json["verify"]
|
||||
|
||||
assert isinstance(premium, bool)
|
||||
|
||||
return premium
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def is_premium(self) -> bool:
|
||||
if self.license_str is None:
|
||||
return False
|
||||
elif self._verify(license_str=self.license_str):
|
||||
return True
|
||||
return False
|
|
@ -79,6 +79,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
||||
if max_parallel_requests is None:
|
||||
max_parallel_requests = sys.maxsize
|
||||
global_max_parallel_requests = data.get("metadata", {}).get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||
if tpm_limit is None:
|
||||
tpm_limit = sys.maxsize
|
||||
|
@ -91,6 +94,24 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
# Setup values
|
||||
# ------------
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
current_global_requests = await cache.async_get_cache(
|
||||
key=_key, local_only=True
|
||||
)
|
||||
# check if below limit
|
||||
if current_global_requests is None:
|
||||
current_global_requests = 1
|
||||
# if above -> raise error
|
||||
if current_global_requests >= global_max_parallel_requests:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="Max parallel request limit reached."
|
||||
)
|
||||
# if below -> increment
|
||||
else:
|
||||
await cache.async_increment_cache(key=_key, value=1, local_only=True)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
|
@ -207,6 +228,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
||||
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
||||
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_user_id", None
|
||||
|
@ -222,6 +246,14 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
# Setup values
|
||||
# ------------
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
# decrement
|
||||
await self.user_api_key_cache.async_increment_cache(
|
||||
key=_key, value=-1, local_only=True
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
|
@ -336,6 +368,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
|
||||
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
user_api_key = (
|
||||
kwargs["litellm_params"].get("metadata", {}).get("user_api_key", None)
|
||||
)
|
||||
|
@ -347,17 +382,26 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
|||
return
|
||||
|
||||
## decrement call count if call failed
|
||||
if (
|
||||
hasattr(kwargs["exception"], "status_code")
|
||||
and kwargs["exception"].status_code == 429
|
||||
and "Max parallel request limit reached" in str(kwargs["exception"])
|
||||
):
|
||||
if "Max parallel request limit reached" in str(kwargs["exception"]):
|
||||
pass # ignore failed calls due to max limit being reached
|
||||
else:
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
current_global_requests = (
|
||||
await self.user_api_key_cache.async_get_cache(
|
||||
key=_key, local_only=True
|
||||
)
|
||||
)
|
||||
# decrement
|
||||
await self.user_api_key_cache.async_increment_cache(
|
||||
key=_key, value=-1, local_only=True
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
|
|
|
@ -146,6 +146,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
|||
try:
|
||||
assert call_type in [
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
|
@ -192,6 +193,15 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
|||
return data
|
||||
|
||||
except HTTPException as e:
|
||||
|
||||
if (
|
||||
e.status_code == 400
|
||||
and isinstance(e.detail, dict)
|
||||
and "error" in e.detail
|
||||
and self.prompt_injection_params is not None
|
||||
and self.prompt_injection_params.reject_as_response
|
||||
):
|
||||
return e.detail["error"]
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
|
|
@ -17,6 +17,7 @@ if litellm_mode == "DEV":
|
|||
from importlib import resources
|
||||
import shutil
|
||||
|
||||
|
||||
telemetry = None
|
||||
|
||||
|
||||
|
@ -505,6 +506,7 @@ def run_server(
|
|||
port = random.randint(1024, 49152)
|
||||
|
||||
from litellm.proxy.proxy_server import app
|
||||
import litellm
|
||||
|
||||
if run_gunicorn == False:
|
||||
if ssl_certfile_path is not None and ssl_keyfile_path is not None:
|
||||
|
@ -519,7 +521,15 @@ def run_server(
|
|||
ssl_certfile=ssl_certfile_path,
|
||||
) # run uvicorn
|
||||
else:
|
||||
uvicorn.run(app, host=host, port=port) # run uvicorn
|
||||
print(f"litellm.json_logs: {litellm.json_logs}")
|
||||
if litellm.json_logs:
|
||||
from litellm.proxy._logging import logger
|
||||
|
||||
uvicorn.run(
|
||||
app, host=host, port=port, log_config=None
|
||||
) # run uvicorn w/ json
|
||||
else:
|
||||
uvicorn.run(app, host=host, port=port) # run uvicorn
|
||||
elif run_gunicorn == True:
|
||||
import gunicorn.app.base
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -11,6 +11,7 @@ from litellm.proxy._types import (
|
|||
LiteLLM_EndUserTable,
|
||||
LiteLLM_TeamTable,
|
||||
Member,
|
||||
CallInfo,
|
||||
)
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.router import Deployment, ModelInfo, LiteLLM_Params
|
||||
|
@ -18,8 +19,18 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
|||
from litellm.proxy.hooks.parallel_request_limiter import (
|
||||
_PROXY_MaxParallelRequestsHandler,
|
||||
)
|
||||
from litellm.exceptions import RejectedRequestError
|
||||
from litellm._service_logger import ServiceLogging, ServiceTypes
|
||||
from litellm import ModelResponse, EmbeddingResponse, ImageResponse
|
||||
from litellm import (
|
||||
ModelResponse,
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
TranscriptionResponse,
|
||||
TextCompletionResponse,
|
||||
CustomStreamWrapper,
|
||||
TextCompletionStreamWrapper,
|
||||
)
|
||||
from litellm.utils import ModelResponseIterator
|
||||
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
|
||||
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
|
||||
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
|
||||
|
@ -32,6 +43,7 @@ from email.mime.text import MIMEText
|
|||
from email.mime.multipart import MIMEMultipart
|
||||
from datetime import datetime, timedelta
|
||||
from litellm.integrations.slack_alerting import SlackAlerting
|
||||
from typing_extensions import overload
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
|
@ -74,6 +86,9 @@ class ProxyLogging:
|
|||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
"spend_reports",
|
||||
"cooldown_deployment",
|
||||
"new_model_added",
|
||||
]
|
||||
] = [
|
||||
"llm_exceptions",
|
||||
|
@ -82,6 +97,9 @@ class ProxyLogging:
|
|||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
"spend_reports",
|
||||
"cooldown_deployment",
|
||||
"new_model_added",
|
||||
]
|
||||
self.slack_alerting_instance = SlackAlerting(
|
||||
alerting_threshold=self.alerting_threshold,
|
||||
|
@ -104,6 +122,9 @@ class ProxyLogging:
|
|||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
"spend_reports",
|
||||
"cooldown_deployment",
|
||||
"new_model_added",
|
||||
]
|
||||
]
|
||||
] = None,
|
||||
|
@ -122,7 +143,13 @@ class ProxyLogging:
|
|||
alerting_args=alerting_args,
|
||||
)
|
||||
|
||||
if "daily_reports" in self.alert_types:
|
||||
if (
|
||||
self.alerting is not None
|
||||
and "slack" in self.alerting
|
||||
and "daily_reports" in self.alert_types
|
||||
):
|
||||
# NOTE: ENSURE we only add callbacks when alerting is on
|
||||
# We should NOT add callbacks when alerting is off
|
||||
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
|
||||
|
||||
if redis_cache is not None:
|
||||
|
@ -140,6 +167,8 @@ class ProxyLogging:
|
|||
self.slack_alerting_instance.response_taking_too_long_callback
|
||||
)
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, str):
|
||||
callback = litellm.utils._init_custom_logger_compatible_class(callback)
|
||||
if callback not in litellm.input_callback:
|
||||
litellm.input_callback.append(callback)
|
||||
if callback not in litellm.success_callback:
|
||||
|
@ -165,18 +194,20 @@ class ProxyLogging:
|
|||
)
|
||||
litellm.utils.set_callbacks(callback_list=callback_list)
|
||||
|
||||
# The actual implementation of the function
|
||||
async def pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
data: dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
],
|
||||
):
|
||||
) -> dict:
|
||||
"""
|
||||
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
|
||||
|
||||
|
@ -203,8 +234,25 @@ class ProxyLogging:
|
|||
call_type=call_type,
|
||||
)
|
||||
if response is not None:
|
||||
data = response
|
||||
|
||||
if isinstance(response, Exception):
|
||||
raise response
|
||||
elif isinstance(response, dict):
|
||||
data = response
|
||||
elif isinstance(response, str):
|
||||
if (
|
||||
call_type == "completion"
|
||||
or call_type == "text_completion"
|
||||
):
|
||||
raise RejectedRequestError(
|
||||
message=response,
|
||||
model=data.get("model", ""),
|
||||
llm_provider="",
|
||||
request_data=data,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail={"error": response}
|
||||
)
|
||||
print_verbose(f"final data being sent to {call_type} call: {data}")
|
||||
return data
|
||||
except Exception as e:
|
||||
|
@ -252,8 +300,8 @@ class ProxyLogging:
|
|||
"""
|
||||
Runs the CustomLogger's async_moderation_hook()
|
||||
"""
|
||||
new_data = copy.deepcopy(data)
|
||||
for callback in litellm.callbacks:
|
||||
new_data = copy.deepcopy(data)
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_moderation_hook(
|
||||
|
@ -265,30 +313,30 @@ class ProxyLogging:
|
|||
raise e
|
||||
return data
|
||||
|
||||
async def failed_tracking_alert(self, error_message: str):
|
||||
if self.alerting is None:
|
||||
return
|
||||
await self.slack_alerting_instance.failed_tracking_alert(
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
async def budget_alerts(
|
||||
self,
|
||||
type: Literal[
|
||||
"token_budget",
|
||||
"user_budget",
|
||||
"user_and_proxy_budget",
|
||||
"failed_budgets",
|
||||
"failed_tracking",
|
||||
"team_budget",
|
||||
"proxy_budget",
|
||||
"projected_limit_exceeded",
|
||||
],
|
||||
user_max_budget: float,
|
||||
user_current_spend: float,
|
||||
user_info=None,
|
||||
error_message="",
|
||||
user_info: CallInfo,
|
||||
):
|
||||
if self.alerting is None:
|
||||
# do nothing if alerting is not switched on
|
||||
return
|
||||
await self.slack_alerting_instance.budget_alerts(
|
||||
type=type,
|
||||
user_max_budget=user_max_budget,
|
||||
user_current_spend=user_current_spend,
|
||||
user_info=user_info,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
async def alerting_handler(
|
||||
|
@ -344,7 +392,11 @@ class ProxyLogging:
|
|||
for client in self.alerting:
|
||||
if client == "slack":
|
||||
await self.slack_alerting_instance.send_alert(
|
||||
message=message, level=level, alert_type=alert_type, **extra_kwargs
|
||||
message=message,
|
||||
level=level,
|
||||
alert_type=alert_type,
|
||||
user_info=None,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif client == "sentry":
|
||||
if litellm.utils.sentry_sdk_instance is not None:
|
||||
|
@ -418,9 +470,14 @@ class ProxyLogging:
|
|||
|
||||
Related issue - https://github.com/BerriAI/litellm/issues/3395
|
||||
"""
|
||||
litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
|
||||
exception_str = str(original_exception)
|
||||
if litellm_debug_info is not None:
|
||||
exception_str += litellm_debug_info
|
||||
|
||||
asyncio.create_task(
|
||||
self.alerting_handler(
|
||||
message=f"LLM API call failed: {str(original_exception)}",
|
||||
message=f"LLM API call failed: `{exception_str}`",
|
||||
level="High",
|
||||
alert_type="llm_exceptions",
|
||||
request_data=request_data,
|
||||
|
@ -1787,7 +1844,9 @@ def hash_token(token: str):
|
|||
return hashed_token
|
||||
|
||||
|
||||
def get_logging_payload(kwargs, response_obj, start_time, end_time):
|
||||
def get_logging_payload(
|
||||
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
|
||||
):
|
||||
from litellm.proxy._types import LiteLLM_SpendLogs
|
||||
from pydantic import Json
|
||||
import uuid
|
||||
|
@ -1865,7 +1924,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
|
|||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"request_tags": metadata.get("tags", []),
|
||||
"end_user": kwargs.get("user", ""),
|
||||
"end_user": end_user_id or "",
|
||||
"api_base": litellm_params.get("api_base", ""),
|
||||
}
|
||||
|
||||
|
@ -2028,6 +2087,11 @@ async def update_spend(
|
|||
raise e
|
||||
|
||||
### UPDATE END-USER TABLE ###
|
||||
verbose_proxy_logger.debug(
|
||||
"End-User Spend transactions: {}".format(
|
||||
len(prisma_client.end_user_list_transactons.keys())
|
||||
)
|
||||
)
|
||||
if len(prisma_client.end_user_list_transactons.keys()) > 0:
|
||||
for i in range(n_retry_times + 1):
|
||||
start_time = time.time()
|
||||
|
@ -2043,13 +2107,18 @@ async def update_spend(
|
|||
max_end_user_budget = None
|
||||
if litellm.max_end_user_budget is not None:
|
||||
max_end_user_budget = litellm.max_end_user_budget
|
||||
new_user_obj = LiteLLM_EndUserTable(
|
||||
user_id=end_user_id, spend=response_cost, blocked=False
|
||||
)
|
||||
batcher.litellm_endusertable.update_many(
|
||||
batcher.litellm_endusertable.upsert(
|
||||
where={"user_id": end_user_id},
|
||||
data={"spend": {"increment": response_cost}},
|
||||
data={
|
||||
"create": {
|
||||
"user_id": end_user_id,
|
||||
"spend": response_cost,
|
||||
"blocked": False,
|
||||
},
|
||||
"update": {"spend": {"increment": response_cost}},
|
||||
},
|
||||
)
|
||||
|
||||
prisma_client.end_user_list_transactons = (
|
||||
{}
|
||||
) # Clear the remaining transactions after processing all batches in the loop.
|
||||
|
|
|
@ -262,13 +262,22 @@ class Router:
|
|||
|
||||
self.retry_after = retry_after
|
||||
self.routing_strategy = routing_strategy
|
||||
self.fallbacks = fallbacks or litellm.fallbacks
|
||||
|
||||
## SETTING FALLBACKS ##
|
||||
### validate if it's set + in correct format
|
||||
_fallbacks = fallbacks or litellm.fallbacks
|
||||
|
||||
self.validate_fallbacks(fallback_param=_fallbacks)
|
||||
### set fallbacks
|
||||
self.fallbacks = _fallbacks
|
||||
|
||||
if default_fallbacks is not None or litellm.default_fallbacks is not None:
|
||||
_fallbacks = default_fallbacks or litellm.default_fallbacks
|
||||
if self.fallbacks is not None:
|
||||
self.fallbacks.append({"*": _fallbacks})
|
||||
else:
|
||||
self.fallbacks = [{"*": _fallbacks}]
|
||||
|
||||
self.context_window_fallbacks = (
|
||||
context_window_fallbacks or litellm.context_window_fallbacks
|
||||
)
|
||||
|
@ -336,6 +345,21 @@ class Router:
|
|||
if self.alerting_config is not None:
|
||||
self._initialize_alerting()
|
||||
|
||||
def validate_fallbacks(self, fallback_param: Optional[List]):
|
||||
if fallback_param is None:
|
||||
return
|
||||
if len(fallback_param) > 0: # if set
|
||||
## for dictionary in list, check if only 1 key in dict
|
||||
for _dict in fallback_param:
|
||||
assert isinstance(_dict, dict), "Item={}, not a dictionary".format(
|
||||
_dict
|
||||
)
|
||||
assert (
|
||||
len(_dict.keys()) == 1
|
||||
), "Only 1 key allows in dictionary. You set={} for dict={}".format(
|
||||
len(_dict.keys()), _dict
|
||||
)
|
||||
|
||||
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
|
||||
if routing_strategy == "least-busy":
|
||||
self.leastbusy_logger = LeastBusyLoggingHandler(
|
||||
|
@ -638,6 +662,10 @@ class Router:
|
|||
async def abatch_completion(
|
||||
self, models: List[str], messages: List[Dict[str, str]], **kwargs
|
||||
):
|
||||
"""
|
||||
Async Batch Completion - Batch Process 1 request to multiple model_group on litellm.Router
|
||||
Use this for sending the same request to N models
|
||||
"""
|
||||
|
||||
async def _async_completion_no_exceptions(
|
||||
model: str, messages: List[Dict[str, str]], **kwargs
|
||||
|
@ -662,6 +690,51 @@ class Router:
|
|||
response = await asyncio.gather(*_tasks)
|
||||
return response
|
||||
|
||||
async def abatch_completion_one_model_multiple_requests(
|
||||
self, model: str, messages: List[List[Dict[str, str]]], **kwargs
|
||||
):
|
||||
"""
|
||||
Async Batch Completion - Batch Process multiple Messages to one model_group on litellm.Router
|
||||
|
||||
Use this for sending multiple requests to 1 model
|
||||
|
||||
Args:
|
||||
model (List[str]): model group
|
||||
messages (List[List[Dict[str, str]]]): list of messages. Each element in the list is one request
|
||||
**kwargs: additional kwargs
|
||||
Usage:
|
||||
response = await self.abatch_completion_one_model_multiple_requests(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
[{"role": "user", "content": "hello"}, {"role": "user", "content": "tell me something funny"}],
|
||||
[{"role": "user", "content": "hello good mornign"}],
|
||||
]
|
||||
)
|
||||
"""
|
||||
|
||||
async def _async_completion_no_exceptions(
|
||||
model: str, messages: List[Dict[str, str]], **kwargs
|
||||
):
|
||||
"""
|
||||
Wrapper around self.async_completion that catches exceptions and returns them as a result
|
||||
"""
|
||||
try:
|
||||
return await self.acompletion(model=model, messages=messages, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
_tasks = []
|
||||
for message_request in messages:
|
||||
# add each task but if the task fails
|
||||
_tasks.append(
|
||||
_async_completion_no_exceptions(
|
||||
model=model, messages=message_request, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
response = await asyncio.gather(*_tasks)
|
||||
return response
|
||||
|
||||
def image_generation(self, prompt: str, model: str, **kwargs):
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
|
@ -1899,10 +1972,28 @@ class Router:
|
|||
metadata = kwargs.get("litellm_params", {}).get("metadata", None)
|
||||
_model_info = kwargs.get("litellm_params", {}).get("model_info", {})
|
||||
|
||||
exception_response = getattr(exception, "response", {})
|
||||
exception_headers = getattr(exception_response, "headers", None)
|
||||
_time_to_cooldown = self.cooldown_time
|
||||
|
||||
if exception_headers is not None:
|
||||
|
||||
_time_to_cooldown = (
|
||||
litellm.utils._get_retry_after_from_exception_header(
|
||||
response_headers=exception_headers
|
||||
)
|
||||
)
|
||||
|
||||
if _time_to_cooldown < 0:
|
||||
# if the response headers did not read it -> set to default cooldown time
|
||||
_time_to_cooldown = self.cooldown_time
|
||||
|
||||
if isinstance(_model_info, dict):
|
||||
deployment_id = _model_info.get("id", None)
|
||||
self._set_cooldown_deployments(
|
||||
exception_status=exception_status, deployment=deployment_id
|
||||
exception_status=exception_status,
|
||||
deployment=deployment_id,
|
||||
time_to_cooldown=_time_to_cooldown,
|
||||
) # setting deployment_id in cooldown deployments
|
||||
if custom_llm_provider:
|
||||
model_name = f"{custom_llm_provider}/{model_name}"
|
||||
|
@ -1962,8 +2053,50 @@ class Router:
|
|||
key=rpm_key, value=request_count, local_only=True
|
||||
) # don't change existing ttl
|
||||
|
||||
def _is_cooldown_required(self, exception_status: Union[str, int]):
|
||||
"""
|
||||
A function to determine if a cooldown is required based on the exception status.
|
||||
|
||||
Parameters:
|
||||
exception_status (Union[str, int]): The status of the exception.
|
||||
|
||||
Returns:
|
||||
bool: True if a cooldown is required, False otherwise.
|
||||
"""
|
||||
try:
|
||||
|
||||
if isinstance(exception_status, str):
|
||||
exception_status = int(exception_status)
|
||||
|
||||
if exception_status >= 400 and exception_status < 500:
|
||||
if exception_status == 429:
|
||||
# Cool down 429 Rate Limit Errors
|
||||
return True
|
||||
|
||||
elif exception_status == 401:
|
||||
# Cool down 401 Auth Errors
|
||||
return True
|
||||
|
||||
elif exception_status == 408:
|
||||
return True
|
||||
|
||||
else:
|
||||
# Do NOT cool down all other 4XX Errors
|
||||
return False
|
||||
|
||||
else:
|
||||
# should cool down for all other errors
|
||||
return True
|
||||
|
||||
except:
|
||||
# Catch all - if any exceptions default to cooling down
|
||||
return True
|
||||
|
||||
def _set_cooldown_deployments(
|
||||
self, exception_status: Union[str, int], deployment: Optional[str] = None
|
||||
self,
|
||||
exception_status: Union[str, int],
|
||||
deployment: Optional[str] = None,
|
||||
time_to_cooldown: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
|
||||
|
@ -1975,6 +2108,9 @@ class Router:
|
|||
if deployment is None:
|
||||
return
|
||||
|
||||
if self._is_cooldown_required(exception_status=exception_status) == False:
|
||||
return
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
# get current fails for deployment
|
||||
|
@ -1987,6 +2123,8 @@ class Router:
|
|||
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
|
||||
)
|
||||
cooldown_time = self.cooldown_time or 1
|
||||
if time_to_cooldown is not None:
|
||||
cooldown_time = time_to_cooldown
|
||||
|
||||
if isinstance(exception_status, str):
|
||||
try:
|
||||
|
@ -2024,7 +2162,9 @@ class Router:
|
|||
)
|
||||
|
||||
self.send_deployment_cooldown_alert(
|
||||
deployment_id=deployment, exception_status=exception_status
|
||||
deployment_id=deployment,
|
||||
exception_status=exception_status,
|
||||
cooldown_time=cooldown_time,
|
||||
)
|
||||
else:
|
||||
self.failed_calls.set_cache(
|
||||
|
@ -2309,7 +2449,7 @@ class Router:
|
|||
organization = litellm.get_secret(organization_env_name)
|
||||
litellm_params["organization"] = organization
|
||||
|
||||
if "azure" in model_name and isinstance(api_key, str):
|
||||
if "azure" in model_name:
|
||||
if api_base is None or not isinstance(api_base, str):
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
|
||||
|
@ -3185,7 +3325,7 @@ class Router:
|
|||
|
||||
if _rate_limit_error == True: # allow generic fallback logic to take place
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
|
||||
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Try again in {self.cooldown_time} seconds."
|
||||
)
|
||||
elif _context_window_error == True:
|
||||
raise litellm.ContextWindowExceededError(
|
||||
|
@ -3257,7 +3397,9 @@ class Router:
|
|||
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
|
||||
|
||||
if len(healthy_deployments) == 0:
|
||||
raise ValueError(f"No healthy deployment available, passed model={model}. ")
|
||||
raise ValueError(
|
||||
f"No healthy deployment available, passed model={model}. Try again in {self.cooldown_time} seconds"
|
||||
)
|
||||
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[
|
||||
model
|
||||
|
@ -3347,7 +3489,7 @@ class Router:
|
|||
if _allowed_model_region is None:
|
||||
_allowed_model_region = "n/a"
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_available.value}, passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
|
||||
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. Enable pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -3415,7 +3557,7 @@ class Router:
|
|||
f"get_available_deployment for model: {model}, No deployment available"
|
||||
)
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
|
||||
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
|
||||
|
@ -3545,7 +3687,7 @@ class Router:
|
|||
f"get_available_deployment for model: {model}, No deployment available"
|
||||
)
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_available.value}, passed model={model}"
|
||||
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}"
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
|
||||
|
@ -3683,7 +3825,10 @@ class Router:
|
|||
print("\033[94m\nInitialized Alerting for litellm.Router\033[0m\n") # noqa
|
||||
|
||||
def send_deployment_cooldown_alert(
|
||||
self, deployment_id: str, exception_status: Union[str, int]
|
||||
self,
|
||||
deployment_id: str,
|
||||
exception_status: Union[str, int],
|
||||
cooldown_time: float,
|
||||
):
|
||||
try:
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
@ -3707,7 +3852,7 @@ class Router:
|
|||
)
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.slack_alerting_instance.send_alert(
|
||||
message=f"Router: Cooling down deployment: {_api_base}, for {self.cooldown_time} seconds. Got exception: {str(exception_status)}. Change 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
|
||||
message=f"Router: Cooling down Deployment:\nModel Name: `{_model_name}`\nAPI Base: `{_api_base}`\nCooldown Time: `{cooldown_time} seconds`\nException Status Code: `{str(exception_status)}`\n\nChange 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
|
||||
alert_type="cooldown_deployment",
|
||||
level="Low",
|
||||
)
|
||||
|
|
|
@ -27,7 +27,7 @@ class LiteLLMBase(BaseModel):
|
|||
|
||||
|
||||
class RoutingArgs(LiteLLMBase):
|
||||
ttl: int = 1 * 60 * 60 # 1 hour
|
||||
ttl: float = 1 * 60 * 60 # 1 hour
|
||||
lowest_latency_buffer: float = 0
|
||||
max_latency_list_size: int = 10
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -242,12 +242,24 @@ async def test_langfuse_masked_input_output(langfuse_client):
|
|||
response = await create_async_task(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "This is a test"}],
|
||||
metadata={"trace_id": _unique_trace_name, "mask_input": mask_value, "mask_output": mask_value},
|
||||
mock_response="This is a test response"
|
||||
metadata={
|
||||
"trace_id": _unique_trace_name,
|
||||
"mask_input": mask_value,
|
||||
"mask_output": mask_value,
|
||||
},
|
||||
mock_response="This is a test response",
|
||||
)
|
||||
print(response)
|
||||
expected_input = "redacted-by-litellm" if mask_value else {'messages': [{'content': 'This is a test', 'role': 'user'}]}
|
||||
expected_output = "redacted-by-litellm" if mask_value else {'content': 'This is a test response', 'role': 'assistant'}
|
||||
expected_input = (
|
||||
"redacted-by-litellm"
|
||||
if mask_value
|
||||
else {"messages": [{"content": "This is a test", "role": "user"}]}
|
||||
)
|
||||
expected_output = (
|
||||
"redacted-by-litellm"
|
||||
if mask_value
|
||||
else {"content": "This is a test response", "role": "assistant"}
|
||||
)
|
||||
langfuse_client.flush()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
@ -262,6 +274,7 @@ async def test_langfuse_masked_input_output(langfuse_client):
|
|||
assert generations[0].input == expected_input
|
||||
assert generations[0].output == expected_output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_langfuse_logging_metadata(langfuse_client):
|
||||
"""
|
||||
|
@ -523,7 +536,7 @@ def test_langfuse_logging_function_calling():
|
|||
# test_langfuse_logging_function_calling()
|
||||
|
||||
|
||||
def test_langfuse_existing_trace_id():
|
||||
def test_aaalangfuse_existing_trace_id():
|
||||
"""
|
||||
When existing trace id is passed, don't set trace params -> prevents overwriting the trace
|
||||
|
||||
|
@ -577,7 +590,7 @@ def test_langfuse_existing_trace_id():
|
|||
"verbose": False,
|
||||
"custom_llm_provider": "openai",
|
||||
"api_base": "https://api.openai.com/v1/",
|
||||
"litellm_call_id": "508113a1-c6f1-48ce-a3e1-01c6cce9330e",
|
||||
"litellm_call_id": None,
|
||||
"model_alias_map": {},
|
||||
"completion_call_id": None,
|
||||
"metadata": None,
|
||||
|
@ -593,7 +606,7 @@ def test_langfuse_existing_trace_id():
|
|||
"stream": False,
|
||||
"user": None,
|
||||
"call_type": "completion",
|
||||
"litellm_call_id": "508113a1-c6f1-48ce-a3e1-01c6cce9330e",
|
||||
"litellm_call_id": None,
|
||||
"completion_start_time": "2024-05-01 07:31:29.903685",
|
||||
"temperature": 0.1,
|
||||
"extra_body": {},
|
||||
|
@ -633,6 +646,8 @@ def test_langfuse_existing_trace_id():
|
|||
|
||||
trace_id = langfuse_response_object["trace_id"]
|
||||
|
||||
assert trace_id is not None
|
||||
|
||||
langfuse_client.flush()
|
||||
|
||||
time.sleep(2)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# What is this?
|
||||
## Tests slack alerting on proxy logging object
|
||||
|
||||
import sys, json
|
||||
import sys, json, uuid, random
|
||||
import os
|
||||
import io, asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
@ -22,6 +22,7 @@ import unittest.mock
|
|||
from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from litellm.router import AlertingConfig, Router
|
||||
from litellm.proxy._types import CallInfo
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -123,7 +124,9 @@ from datetime import datetime, timedelta
|
|||
|
||||
@pytest.fixture
|
||||
def slack_alerting():
|
||||
return SlackAlerting(alerting_threshold=1, internal_usage_cache=DualCache())
|
||||
return SlackAlerting(
|
||||
alerting_threshold=1, internal_usage_cache=DualCache(), alerting=["slack"]
|
||||
)
|
||||
|
||||
|
||||
# Test for hanging LLM responses
|
||||
|
@ -161,7 +164,10 @@ async def test_budget_alerts_crossed(slack_alerting):
|
|||
user_current_spend = 101
|
||||
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
|
||||
await slack_alerting.budget_alerts(
|
||||
"user_budget", user_max_budget, user_current_spend
|
||||
"user_budget",
|
||||
user_info=CallInfo(
|
||||
token="", spend=user_current_spend, max_budget=user_max_budget
|
||||
),
|
||||
)
|
||||
mock_send_alert.assert_awaited_once()
|
||||
|
||||
|
@ -173,12 +179,18 @@ async def test_budget_alerts_crossed_again(slack_alerting):
|
|||
user_current_spend = 101
|
||||
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
|
||||
await slack_alerting.budget_alerts(
|
||||
"user_budget", user_max_budget, user_current_spend
|
||||
"user_budget",
|
||||
user_info=CallInfo(
|
||||
token="", spend=user_current_spend, max_budget=user_max_budget
|
||||
),
|
||||
)
|
||||
mock_send_alert.assert_awaited_once()
|
||||
mock_send_alert.reset_mock()
|
||||
await slack_alerting.budget_alerts(
|
||||
"user_budget", user_max_budget, user_current_spend
|
||||
"user_budget",
|
||||
user_info=CallInfo(
|
||||
token="", spend=user_current_spend, max_budget=user_max_budget
|
||||
),
|
||||
)
|
||||
mock_send_alert.assert_not_awaited()
|
||||
|
||||
|
@ -365,21 +377,23 @@ async def test_send_llm_exception_to_slack():
|
|||
@pytest.mark.asyncio
|
||||
async def test_send_daily_reports_ignores_zero_values():
|
||||
router = MagicMock()
|
||||
router.get_model_ids.return_value = ['model1', 'model2', 'model3']
|
||||
router.get_model_ids.return_value = ["model1", "model2", "model3"]
|
||||
|
||||
slack_alerting = SlackAlerting(internal_usage_cache=MagicMock())
|
||||
# model1:failed=None, model2:failed=0, model3:failed=10, model1:latency=0; model2:latency=0; model3:latency=None
|
||||
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(return_value=[None, 0, 10, 0, 0, None])
|
||||
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(
|
||||
return_value=[None, 0, 10, 0, 0, None]
|
||||
)
|
||||
slack_alerting.internal_usage_cache.async_batch_set_cache = AsyncMock()
|
||||
|
||||
router.get_model_info.side_effect = lambda x: {"litellm_params": {"model": x}}
|
||||
|
||||
with patch.object(slack_alerting, 'send_alert', new=AsyncMock()) as mock_send_alert:
|
||||
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
|
||||
result = await slack_alerting.send_daily_reports(router)
|
||||
|
||||
# Check that the send_alert method was called
|
||||
mock_send_alert.assert_called_once()
|
||||
message = mock_send_alert.call_args[1]['message']
|
||||
message = mock_send_alert.call_args[1]["message"]
|
||||
|
||||
# Ensure the message includes only the non-zero, non-None metrics
|
||||
assert "model3" in message
|
||||
|
@ -393,15 +407,91 @@ async def test_send_daily_reports_ignores_zero_values():
|
|||
@pytest.mark.asyncio
|
||||
async def test_send_daily_reports_all_zero_or_none():
|
||||
router = MagicMock()
|
||||
router.get_model_ids.return_value = ['model1', 'model2', 'model3']
|
||||
router.get_model_ids.return_value = ["model1", "model2", "model3"]
|
||||
|
||||
slack_alerting = SlackAlerting(internal_usage_cache=MagicMock())
|
||||
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(return_value=[None, 0, None, 0, None, 0])
|
||||
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(
|
||||
return_value=[None, 0, None, 0, None, 0]
|
||||
)
|
||||
|
||||
with patch.object(slack_alerting, 'send_alert', new=AsyncMock()) as mock_send_alert:
|
||||
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
|
||||
result = await slack_alerting.send_daily_reports(router)
|
||||
|
||||
# Check that the send_alert method was not called
|
||||
mock_send_alert.assert_not_called()
|
||||
|
||||
assert result == False
|
||||
|
||||
|
||||
# test user budget crossed alert sent only once, even if user makes multiple calls
|
||||
@pytest.mark.parametrize(
|
||||
"alerting_type",
|
||||
[
|
||||
"token_budget",
|
||||
"user_budget",
|
||||
"team_budget",
|
||||
"proxy_budget",
|
||||
"projected_limit_exceeded",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_token_budget_crossed_alerts(alerting_type):
|
||||
slack_alerting = SlackAlerting()
|
||||
|
||||
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
|
||||
user_info = {
|
||||
"token": "50e55ca5bfbd0759697538e8d23c0cd5031f52d9e19e176d7233b20c7c4d3403",
|
||||
"spend": 86,
|
||||
"max_budget": 100,
|
||||
"user_id": "ishaan@berri.ai",
|
||||
"user_email": "ishaan@berri.ai",
|
||||
"key_alias": "my-test-key",
|
||||
"projected_exceeded_date": "10/20/2024",
|
||||
"projected_spend": 200,
|
||||
}
|
||||
|
||||
user_info = CallInfo(**user_info)
|
||||
|
||||
for _ in range(50):
|
||||
await slack_alerting.budget_alerts(
|
||||
type=alerting_type,
|
||||
user_info=user_info,
|
||||
)
|
||||
mock_send_alert.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"alerting_type",
|
||||
[
|
||||
"token_budget",
|
||||
"user_budget",
|
||||
"team_budget",
|
||||
"proxy_budget",
|
||||
"projected_limit_exceeded",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_alerting(alerting_type):
|
||||
slack_alerting = SlackAlerting(alerting=["webhook"])
|
||||
|
||||
with patch.object(
|
||||
slack_alerting, "send_webhook_alert", new=AsyncMock()
|
||||
) as mock_send_alert:
|
||||
user_info = {
|
||||
"token": "50e55ca5bfbd0759697538e8d23c0cd5031f52d9e19e176d7233b20c7c4d3403",
|
||||
"spend": 1,
|
||||
"max_budget": 0,
|
||||
"user_id": "ishaan@berri.ai",
|
||||
"user_email": "ishaan@berri.ai",
|
||||
"key_alias": "my-test-key",
|
||||
"projected_exceeded_date": "10/20/2024",
|
||||
"projected_spend": 200,
|
||||
}
|
||||
|
||||
user_info = CallInfo(**user_info)
|
||||
for _ in range(50):
|
||||
await slack_alerting.budget_alerts(
|
||||
type=alerting_type,
|
||||
user_info=user_info,
|
||||
)
|
||||
mock_send_alert.assert_awaited_once()
|
||||
|
|
|
@ -16,6 +16,7 @@ from litellm.tests.test_streaming import streaming_format_tests
|
|||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
|
||||
|
||||
litellm.num_retries = 3
|
||||
litellm.cache = None
|
||||
|
@ -98,7 +99,7 @@ def load_vertex_ai_credentials():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def get_response():
|
||||
async def test_get_response():
|
||||
load_vertex_ai_credentials()
|
||||
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
|
||||
try:
|
||||
|
@ -371,14 +372,13 @@ def test_vertex_ai_stream():
|
|||
"gemini-1.5-pro",
|
||||
"gemini-1.5-pro-preview-0215",
|
||||
]:
|
||||
# our account does not have access to this model
|
||||
# ouraccount does not have access to this model
|
||||
continue
|
||||
print("making request", model)
|
||||
response = completion(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "user", "content": "write 10 line code code for saying hi"}
|
||||
],
|
||||
messages=[{"role": "user", "content": "hello tell me a short story"}],
|
||||
max_tokens=15,
|
||||
stream=True,
|
||||
)
|
||||
completed_str = ""
|
||||
|
@ -389,7 +389,7 @@ def test_vertex_ai_stream():
|
|||
completed_str += content
|
||||
assert type(content) == str
|
||||
# pass
|
||||
assert len(completed_str) > 4
|
||||
assert len(completed_str) > 1
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -595,30 +595,68 @@ def test_gemini_pro_vision_base64():
|
|||
async def test_gemini_pro_function_calling(sync_mode):
|
||||
try:
|
||||
load_vertex_ai_credentials()
|
||||
data = {
|
||||
"model": "vertex_ai/gemini-pro",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Call the submit_cities function with San Francisco and New York",
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "submit_cities",
|
||||
"description": "Submits a list of cities",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cities": {"type": "array", "items": {"type": "string"}}
|
||||
},
|
||||
"required": ["cities"],
|
||||
litellm.set_verbose = True
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||
},
|
||||
# User asks for their name and weather in San Francisco
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, what is your name and can you tell me the weather?",
|
||||
},
|
||||
# Assistant replies with a tool call
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"index": 0,
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location":"San Francisco, CA"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
# The result of the tool call is added to the history
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "get_weather",
|
||||
"content": "27 degrees celsius and clear in San Francisco, CA",
|
||||
},
|
||||
# Now the assistant can reply with the result of the tool call.
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
data = {
|
||||
"model": "vertex_ai/gemini-1.5-pro-preview-0514",
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
}
|
||||
if sync_mode:
|
||||
response = litellm.completion(**data)
|
||||
|
@ -638,7 +676,7 @@ async def test_gemini_pro_function_calling(sync_mode):
|
|||
# gemini_pro_function_calling()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [False, True])
|
||||
@pytest.mark.parametrize("sync_mode", [True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_pro_function_calling_streaming(sync_mode):
|
||||
load_vertex_ai_credentials()
|
||||
|
@ -713,7 +751,7 @@ async def test_gemini_pro_async_function_calling():
|
|||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"description": "Get the current weather in a given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -743,8 +781,9 @@ async def test_gemini_pro_async_function_calling():
|
|||
print(f"completion: {completion}")
|
||||
assert completion.choices[0].message.content is None
|
||||
assert len(completion.choices[0].message.tool_calls) == 1
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
|
||||
# except litellm.APIError as e:
|
||||
# pass
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -894,3 +933,45 @@ async def test_vertexai_aembedding():
|
|||
# traceback.print_exc()
|
||||
# raise e
|
||||
# test_gemini_pro_vision_async()
|
||||
|
||||
|
||||
def test_prompt_factory():
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Your name is Litellm Bot, you are a helpful assistant",
|
||||
},
|
||||
# User asks for their name and weather in San Francisco
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, what is your name and can you tell me the weather?",
|
||||
},
|
||||
# Assistant replies with a tool call
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"index": 0,
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location":"San Francisco, CA"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
# The result of the tool call is added to the history
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "get_weather",
|
||||
"content": "27 degrees celsius and clear in San Francisco, CA",
|
||||
},
|
||||
# Now the assistant can reply with the result of the tool call.
|
||||
]
|
||||
|
||||
translated_messages = _gemini_convert_messages_with_history(messages=messages)
|
||||
|
||||
print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
|
||||
|
|
|
@ -206,6 +206,7 @@ def test_completion_bedrock_claude_sts_client_auth():
|
|||
|
||||
# test_completion_bedrock_claude_sts_client_auth()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="We don't have Circle CI OIDC credentials as yet")
|
||||
def test_completion_bedrock_claude_sts_oidc_auth():
|
||||
print("\ncalling bedrock claude with oidc auth")
|
||||
|
@ -244,7 +245,7 @@ def test_bedrock_extra_headers():
|
|||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.78,
|
||||
extra_headers={"x-key": "x_key_value"}
|
||||
extra_headers={"x-key": "x_key_value"},
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
assert len(response.choices) > 0
|
||||
|
@ -259,7 +260,7 @@ def test_bedrock_claude_3():
|
|||
try:
|
||||
litellm.set_verbose = True
|
||||
data = {
|
||||
"max_tokens": 2000,
|
||||
"max_tokens": 100,
|
||||
"stream": False,
|
||||
"temperature": 0.3,
|
||||
"messages": [
|
||||
|
@ -282,6 +283,7 @@ def test_bedrock_claude_3():
|
|||
}
|
||||
response: ModelResponse = completion(
|
||||
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
num_retries=3,
|
||||
# messages=messages,
|
||||
# max_tokens=10,
|
||||
# temperature=0.78,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import sys, os
|
||||
import sys, os, json
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
@ -7,7 +7,7 @@ import os, io
|
|||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the, system path
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
|
@ -38,7 +38,7 @@ def reset_callbacks():
|
|||
@pytest.mark.skip(reason="Local test")
|
||||
def test_response_model_none():
|
||||
"""
|
||||
Addresses - https://github.com/BerriAI/litellm/issues/2972
|
||||
Addresses:https://github.com/BerriAI/litellm/issues/2972
|
||||
"""
|
||||
x = completion(
|
||||
model="mymodel",
|
||||
|
@ -278,7 +278,8 @@ def test_completion_claude_3_function_call():
|
|||
model="anthropic/claude-3-opus-20240229",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
tool_choice={"type": "tool", "name": "get_weather"},
|
||||
extra_headers={"anthropic-beta": "tools-2024-05-16"},
|
||||
)
|
||||
# Add any assertions, here to check response args
|
||||
print(response)
|
||||
|
@ -1053,6 +1054,25 @@ def test_completion_azure_gpt4_vision():
|
|||
# test_completion_azure_gpt4_vision()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "gpt-4", "gpt-4o"])
|
||||
def test_completion_openai_params(model):
|
||||
litellm.drop_params = True
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": """Generate JSON about Bill Gates: { "full_name": "", "title": "" }""",
|
||||
}
|
||||
]
|
||||
|
||||
response = completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
print(f"response: {response}")
|
||||
|
||||
|
||||
def test_completion_fireworks_ai():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
@ -1161,28 +1181,28 @@ HF Tests we should pass
|
|||
# Test util to sort models to TGI, conv, None
|
||||
def test_get_hf_task_for_model():
|
||||
model = "glaiveai/glaive-coder-7b"
|
||||
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
print(f"model:{model}, model type: {model_type}")
|
||||
assert model_type == "text-generation-inference"
|
||||
|
||||
model = "meta-llama/Llama-2-7b-hf"
|
||||
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
print(f"model:{model}, model type: {model_type}")
|
||||
assert model_type == "text-generation-inference"
|
||||
|
||||
model = "facebook/blenderbot-400M-distill"
|
||||
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
print(f"model:{model}, model type: {model_type}")
|
||||
assert model_type == "conversational"
|
||||
|
||||
model = "facebook/blenderbot-3B"
|
||||
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
print(f"model:{model}, model type: {model_type}")
|
||||
assert model_type == "conversational"
|
||||
|
||||
# neither Conv or None
|
||||
model = "roneneldan/TinyStories-3M"
|
||||
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
print(f"model:{model}, model type: {model_type}")
|
||||
assert model_type == "text-generation"
|
||||
|
||||
|
@ -2300,36 +2320,30 @@ def test_completion_azure_deployment_id():
|
|||
|
||||
# test_completion_azure_deployment_id()
|
||||
|
||||
# Only works for local endpoint
|
||||
# def test_completion_anthropic_openai_proxy():
|
||||
# try:
|
||||
# response = completion(
|
||||
# model="custom_openai/claude-2",
|
||||
# messages=messages,
|
||||
# api_base="http://0.0.0.0:8000"
|
||||
# )
|
||||
# # Add any assertions here to check the response
|
||||
# print(response)
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_anthropic_openai_proxy()
|
||||
import asyncio
|
||||
|
||||
|
||||
def test_completion_replicate_llama3():
|
||||
@pytest.mark.parametrize("sync_mode", [False, True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_replicate_llama3(sync_mode):
|
||||
litellm.set_verbose = True
|
||||
model_name = "replicate/meta/meta-llama-3-8b-instruct"
|
||||
try:
|
||||
response = completion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
)
|
||||
if sync_mode:
|
||||
response = completion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
response = await litellm.acompletion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
)
|
||||
print(f"ASYNC REPLICATE RESPONSE - {response}")
|
||||
print(response)
|
||||
# Add any assertions here to check the response
|
||||
response_str = response["choices"][0]["message"]["content"]
|
||||
print("RESPONSE STRING\n", response_str)
|
||||
if type(response_str) != str:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
response_format_tests(response=response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -2670,14 +2684,29 @@ def response_format_tests(response: litellm.ModelResponse):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"bedrock/cohere.command-r-plus-v1:0",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-instant-v1",
|
||||
"bedrock/ai21.j2-mid",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"bedrock/amazon.titan-tg1-large",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
"cohere.command-text-v14",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_bedrock_command_r(sync_mode):
|
||||
async def test_completion_bedrock_httpx_models(sync_mode, model):
|
||||
litellm.set_verbose = True
|
||||
|
||||
if sync_mode:
|
||||
response = completion(
|
||||
model="bedrock/cohere.command-r-plus-v1:0",
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||
temperature=0.2,
|
||||
max_tokens=200,
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
|
@ -2685,8 +2714,10 @@ async def test_completion_bedrock_command_r(sync_mode):
|
|||
response_format_tests(response=response)
|
||||
else:
|
||||
response = await litellm.acompletion(
|
||||
model="bedrock/cohere.command-r-plus-v1:0",
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||
temperature=0.2,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
|
@ -2722,69 +2753,12 @@ def test_completion_bedrock_titan_null_response():
|
|||
pytest.fail(f"An error occurred - {str(e)}")
|
||||
|
||||
|
||||
def test_completion_bedrock_titan():
|
||||
try:
|
||||
response = completion(
|
||||
model="bedrock/amazon.titan-tg1-large",
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=200,
|
||||
top_p=0.8,
|
||||
logger_fn=logger_fn,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_completion_bedrock_titan()
|
||||
|
||||
|
||||
def test_completion_bedrock_claude():
|
||||
print("calling claude")
|
||||
try:
|
||||
response = completion(
|
||||
model="anthropic.claude-instant-v1",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.1,
|
||||
logger_fn=logger_fn,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_completion_bedrock_claude()
|
||||
|
||||
|
||||
def test_completion_bedrock_cohere():
|
||||
print("calling bedrock cohere")
|
||||
litellm.set_verbose = True
|
||||
try:
|
||||
response = completion(
|
||||
model="bedrock/cohere.command-text-v14",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
temperature=0.1,
|
||||
max_tokens=10,
|
||||
stream=True,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_completion_bedrock_cohere()
|
||||
|
||||
|
||||
|
@ -2807,23 +2781,6 @@ def test_completion_bedrock_cohere():
|
|||
# pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_bedrock_claude_stream()
|
||||
|
||||
# def test_completion_bedrock_ai21():
|
||||
# try:
|
||||
# litellm.set_verbose = False
|
||||
# response = completion(
|
||||
# model="bedrock/ai21.j2-mid",
|
||||
# messages=messages,
|
||||
# temperature=0.2,
|
||||
# top_p=0.2,
|
||||
# max_tokens=20
|
||||
# )
|
||||
# # Add any assertions here to check the response
|
||||
# print(response)
|
||||
# except RateLimitError:
|
||||
# pass
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
######## Test VLLM ########
|
||||
# def test_completion_vllm():
|
||||
|
@ -3096,7 +3053,6 @@ def test_mistral_anyscale_stream():
|
|||
print(chunk["choices"][0]["delta"].get("content", ""), end="")
|
||||
|
||||
|
||||
# test_mistral_anyscale_stream()
|
||||
# test_completion_anyscale_2()
|
||||
# def test_completion_with_fallbacks_multiple_keys():
|
||||
# print(f"backup key 1: {os.getenv('BACKUP_OPENAI_API_KEY_1')}")
|
||||
|
@ -3246,6 +3202,7 @@ def test_completion_gemini():
|
|||
response = completion(model=model_name, messages=messages)
|
||||
# Add any assertions,here to check the response
|
||||
print(response)
|
||||
assert response.choices[0]["index"] == 0
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
|
|
@ -65,6 +65,42 @@ async def test_custom_pricing(sync_mode):
|
|||
assert new_handler.response_cost == 0
|
||||
|
||||
|
||||
def test_custom_pricing_as_completion_cost_param():
|
||||
from litellm import ModelResponse, Choices, Message
|
||||
from litellm.utils import Usage
|
||||
|
||||
resp = ModelResponse(
|
||||
id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
|
||||
choices=[
|
||||
Choices(
|
||||
finish_reason=None,
|
||||
index=0,
|
||||
message=Message(
|
||||
content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1700775391,
|
||||
model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
|
||||
)
|
||||
|
||||
cost = litellm.completion_cost(
|
||||
completion_response=resp,
|
||||
custom_cost_per_token={
|
||||
"input_cost_per_token": 1000,
|
||||
"output_cost_per_token": 20,
|
||||
},
|
||||
)
|
||||
|
||||
expected_cost = 1000 * 21 + 17 * 20
|
||||
|
||||
assert round(cost, 5) == round(expected_cost, 5)
|
||||
|
||||
|
||||
def test_get_gpt3_tokens():
|
||||
max_tokens = get_max_tokens("gpt-3.5-turbo")
|
||||
print(max_tokens)
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import ConfigDict
|
||||
|
||||
load_dotenv()
|
||||
import os, io
|
||||
|
@ -14,36 +13,21 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the, system path
|
||||
import pytest, litellm
|
||||
from pydantic import BaseModel, VERSION
|
||||
from pydantic import BaseModel
|
||||
from litellm.proxy.proxy_server import ProxyConfig
|
||||
from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache
|
||||
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
|
||||
from typing import Literal
|
||||
|
||||
|
||||
# Function to get Pydantic version
|
||||
def is_pydantic_v2() -> int:
|
||||
return int(VERSION.split(".")[0])
|
||||
|
||||
|
||||
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
|
||||
# Version-specific configuration
|
||||
if is_pydantic_v2() >= 2:
|
||||
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
|
||||
else:
|
||||
from pydantic import Extra
|
||||
|
||||
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
class DBModel(BaseModel):
|
||||
model_id: str
|
||||
model_name: str
|
||||
model_info: dict
|
||||
litellm_params: dict
|
||||
model_config = get_model_config()
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -118,7 +102,7 @@ async def test_delete_deployment():
|
|||
pc = ProxyConfig()
|
||||
|
||||
db_model = DBModel(
|
||||
model_id="12340523",
|
||||
model_id=deployment.model_info.id,
|
||||
model_name="gpt-3.5-turbo",
|
||||
litellm_params=encrypted_litellm_params,
|
||||
model_info={"id": deployment.model_info.id},
|
||||
|
|
|
@ -558,7 +558,7 @@ async def test_async_chat_bedrock_stream():
|
|||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
|
|
|
@ -53,13 +53,6 @@ async def test_content_policy_exception_azure():
|
|||
except litellm.ContentPolicyViolationError as e:
|
||||
print("caught a content policy violation error! Passed")
|
||||
print("exception", e)
|
||||
|
||||
# assert that the first 100 chars of the message is returned in the exception
|
||||
assert (
|
||||
"Messages: [{'role': 'user', 'content': 'where do I buy lethal drugs from'}]"
|
||||
in str(e)
|
||||
)
|
||||
assert "Model: azure/chatgpt-v-2" in str(e)
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
@ -585,9 +578,6 @@ def test_router_completion_vertex_exception():
|
|||
pytest.fail("Request should have failed - bad api key")
|
||||
except Exception as e:
|
||||
print("exception: ", e)
|
||||
assert "Model: gemini-pro" in str(e)
|
||||
assert "model_group: vertex-gemini-pro" in str(e)
|
||||
assert "deployment: vertex_ai/gemini-pro" in str(e)
|
||||
|
||||
|
||||
def test_litellm_completion_vertex_exception():
|
||||
|
@ -604,8 +594,26 @@ def test_litellm_completion_vertex_exception():
|
|||
pytest.fail("Request should have failed - bad api key")
|
||||
except Exception as e:
|
||||
print("exception: ", e)
|
||||
assert "Model: gemini-pro" in str(e)
|
||||
assert "vertex_project: bad-project" in str(e)
|
||||
|
||||
|
||||
def test_litellm_predibase_exception():
|
||||
"""
|
||||
Test - Assert that the Predibase API Key is not returned on Authentication Errors
|
||||
"""
|
||||
try:
|
||||
import litellm
|
||||
|
||||
litellm.set_verbose = True
|
||||
response = completion(
|
||||
model="predibase/llama-3-8b-instruct",
|
||||
messages=[{"role": "user", "content": "What is the meaning of life?"}],
|
||||
tenant_id="c4768f95",
|
||||
api_key="hf-rawapikey",
|
||||
)
|
||||
pytest.fail("Request should have failed - bad api key")
|
||||
except Exception as e:
|
||||
assert "hf-rawapikey" not in str(e)
|
||||
print("exception: ", e)
|
||||
|
||||
|
||||
# # test_invalid_request_error(model="command-nightly")
|
||||
|
|
|
@ -105,6 +105,9 @@ def test_parallel_function_call(model):
|
|||
# Step 4: send the info for each function call and function response to the model
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
if function_name not in available_functions:
|
||||
# the model called a function that does not exist in available_functions - don't try calling anything
|
||||
return
|
||||
function_to_call = available_functions[function_name]
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
function_response = function_to_call(
|
||||
|
@ -124,7 +127,6 @@ def test_parallel_function_call(model):
|
|||
model=model, messages=messages, temperature=0.2, seed=22
|
||||
) # get a new response from the model where it can see the function response
|
||||
print("second response\n", second_response)
|
||||
return second_response
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -162,6 +162,39 @@ async def test_aimage_generation_bedrock_with_optional_params():
|
|||
print(f"response: {response}")
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.ContentPolicyViolationError:
|
||||
pass # Azure randomly raises these errors skip when they occur
|
||||
except Exception as e:
|
||||
if "Your task failed as a result of our safety system." in str(e):
|
||||
pass
|
||||
else:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aimage_generation_vertex_ai():
|
||||
from test_amazing_vertex_completion import load_vertex_ai_credentials
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
load_vertex_ai_credentials()
|
||||
try:
|
||||
response = await litellm.aimage_generation(
|
||||
prompt="An olympic size swimming pool",
|
||||
model="vertex_ai/imagegeneration@006",
|
||||
vertex_ai_project="adroit-crow-413218",
|
||||
vertex_ai_location="us-central1",
|
||||
n=1,
|
||||
)
|
||||
assert response.data is not None
|
||||
assert len(response.data) > 0
|
||||
|
||||
for d in response.data:
|
||||
assert isinstance(d, litellm.ImageObject)
|
||||
print("data in response.data", d)
|
||||
assert d.b64_json is not None
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.ContentPolicyViolationError:
|
||||
pass # Azure randomly raises these errors - skip when they occur
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#### What this tests ####
|
||||
# Unit tests for JWT-Auth
|
||||
|
||||
import sys, os, asyncio, time, random
|
||||
import sys, os, asyncio, time, random, uuid
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
@ -24,6 +24,7 @@ public_key = {
|
|||
"alg": "RS256",
|
||||
}
|
||||
|
||||
|
||||
def test_load_config_with_custom_role_names():
|
||||
config = {
|
||||
"general_settings": {
|
||||
|
@ -77,7 +78,8 @@ async def test_token_single_public_key():
|
|||
== "qIgOQfEVrrErJC0E7gsHXi6rs_V0nyFY5qPFui2-tv0o4CwpwDzgfBtLO7o_wLiguq0lnu54sMT2eLNoRiiPuLvv6bg7Iy1H9yc5_4Jf5oYEOrqN5o9ZBOoYp1q68Pv0oNJYyZdGu5ZJfd7V4y953vB2XfEKgXCsAkhVhlvIUMiDNKWoMDWsyb2xela5tRURZ2mJAXcHfSC_sYdZxIA2YYrIHfoevq_vTlaz0qVSe_uOKjEpgOAS08UUrgda4CQL11nzICiIQzc6qmjIQt2cjzB2D_9zb4BYndzEtfl0kwAT0z_I85S3mkwTqHU-1BvKe_4MG4VG3dAAeffLPXJyXQ"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_invalid_token(audience):
|
||||
"""
|
||||
|
@ -90,7 +92,7 @@ async def test_valid_invalid_token(audience):
|
|||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
os.environ.pop("JWT_AUDIENCE", None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
|
@ -138,7 +140,7 @@ async def test_valid_invalid_token(audience):
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm-proxy-admin",
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -166,7 +168,7 @@ async def test_valid_invalid_token(audience):
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm-NO-SCOPE",
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -183,6 +185,7 @@ async def test_valid_invalid_token(audience):
|
|||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prisma_client():
|
||||
import litellm
|
||||
|
@ -205,7 +208,7 @@ def prisma_client():
|
|||
return prisma_client
|
||||
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_token_output(prisma_client, audience):
|
||||
import jwt, json
|
||||
|
@ -222,7 +225,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
os.environ.pop("JWT_AUDIENCE", None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
|
@ -261,7 +264,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
|
||||
jwt_handler.user_api_key_cache = cache
|
||||
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="client_id")
|
||||
|
||||
# VALID TOKEN
|
||||
## GENERATE A TOKEN
|
||||
|
@ -274,7 +277,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_team",
|
||||
"client_id": team_id,
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -289,7 +292,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
"sub": "user123",
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_proxy_admin",
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||
|
@ -315,7 +318,13 @@ async def test_team_token_output(prisma_client, audience):
|
|||
|
||||
## 1. INITIAL TEAM CALL - should fail
|
||||
# use generated key to auth in
|
||||
setattr(litellm.proxy.proxy_server, "general_settings", {"enable_jwt_auth": True})
|
||||
setattr(
|
||||
litellm.proxy.proxy_server,
|
||||
"general_settings",
|
||||
{
|
||||
"enable_jwt_auth": True,
|
||||
},
|
||||
)
|
||||
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
|
||||
try:
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
|
@ -358,9 +367,22 @@ async def test_team_token_output(prisma_client, audience):
|
|||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('audience', [None, "litellm-proxy"])
|
||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||
@pytest.mark.parametrize(
|
||||
"team_id_set, default_team_id",
|
||||
[(True, False), (False, True)],
|
||||
)
|
||||
@pytest.mark.parametrize("user_id_upsert", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_token_output(prisma_client, audience):
|
||||
async def test_user_token_output(
|
||||
prisma_client, audience, team_id_set, default_team_id, user_id_upsert
|
||||
):
|
||||
import uuid
|
||||
|
||||
args = locals()
|
||||
print(f"received args - {args}")
|
||||
if default_team_id:
|
||||
default_team_id = "team_id_12344_{}".format(uuid.uuid4())
|
||||
"""
|
||||
- If user required, check if it exists
|
||||
- fail initial request (when user doesn't exist)
|
||||
|
@ -373,7 +395,12 @@ async def test_user_token_output(prisma_client, audience):
|
|||
from cryptography.hazmat.backends import default_backend
|
||||
from fastapi import Request
|
||||
from starlette.datastructures import URL
|
||||
from litellm.proxy.proxy_server import user_api_key_auth, new_team, new_user
|
||||
from litellm.proxy.proxy_server import (
|
||||
user_api_key_auth,
|
||||
new_team,
|
||||
new_user,
|
||||
user_info,
|
||||
)
|
||||
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth, NewUserRequest
|
||||
import litellm
|
||||
import uuid
|
||||
|
@ -381,7 +408,7 @@ async def test_user_token_output(prisma_client, audience):
|
|||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
os.environ.pop('JWT_AUDIENCE', None)
|
||||
os.environ.pop("JWT_AUDIENCE", None)
|
||||
if audience:
|
||||
os.environ["JWT_AUDIENCE"] = audience
|
||||
|
||||
|
@ -423,6 +450,11 @@ async def test_user_token_output(prisma_client, audience):
|
|||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
|
||||
jwt_handler.litellm_jwtauth.user_id_jwt_field = "sub"
|
||||
jwt_handler.litellm_jwtauth.team_id_default = default_team_id
|
||||
jwt_handler.litellm_jwtauth.user_id_upsert = user_id_upsert
|
||||
|
||||
if team_id_set:
|
||||
jwt_handler.litellm_jwtauth.team_id_jwt_field = "client_id"
|
||||
|
||||
# VALID TOKEN
|
||||
## GENERATE A TOKEN
|
||||
|
@ -436,7 +468,7 @@ async def test_user_token_output(prisma_client, audience):
|
|||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_team",
|
||||
"client_id": team_id,
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
# Generate the JWT token
|
||||
|
@ -451,7 +483,7 @@ async def test_user_token_output(prisma_client, audience):
|
|||
"sub": user_id,
|
||||
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||
"scope": "litellm_proxy_admin",
|
||||
"aud": audience
|
||||
"aud": audience,
|
||||
}
|
||||
|
||||
admin_token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||
|
@ -503,6 +535,16 @@ async def test_user_token_output(prisma_client, audience):
|
|||
),
|
||||
user_api_key_dict=result,
|
||||
)
|
||||
if default_team_id:
|
||||
await new_team(
|
||||
data=NewTeamRequest(
|
||||
team_id=default_team_id,
|
||||
tpm_limit=100,
|
||||
rpm_limit=99,
|
||||
models=["gpt-3.5-turbo", "gpt-4"],
|
||||
),
|
||||
user_api_key_dict=result,
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"This should not fail - {str(e)}")
|
||||
|
||||
|
@ -513,23 +555,35 @@ async def test_user_token_output(prisma_client, audience):
|
|||
team_result: UserAPIKeyAuth = await user_api_key_auth(
|
||||
request=request, api_key=bearer_token
|
||||
)
|
||||
pytest.fail(f"User doesn't exist. this should fail")
|
||||
if user_id_upsert == False:
|
||||
pytest.fail(f"User doesn't exist. this should fail")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
## 4. Create user
|
||||
try:
|
||||
bearer_token = "Bearer " + admin_token
|
||||
if user_id_upsert:
|
||||
## check if user already exists
|
||||
try:
|
||||
bearer_token = "Bearer " + admin_token
|
||||
|
||||
request._url = URL(url="/team/new")
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
await new_user(
|
||||
data=NewUserRequest(
|
||||
user_id=user_id,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"This should not fail - {str(e)}")
|
||||
request._url = URL(url="/team/new")
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
await user_info(user_id=user_id)
|
||||
except Exception as e:
|
||||
pytest.fail(f"This should not fail - {str(e)}")
|
||||
else:
|
||||
try:
|
||||
bearer_token = "Bearer " + admin_token
|
||||
|
||||
request._url = URL(url="/team/new")
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
await new_user(
|
||||
data=NewUserRequest(
|
||||
user_id=user_id,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"This should not fail - {str(e)}")
|
||||
|
||||
## 5. 3rd call w/ same team, same user -> call should succeed
|
||||
bearer_token = "Bearer " + token
|
||||
|
@ -543,7 +597,8 @@ async def test_user_token_output(prisma_client, audience):
|
|||
|
||||
## 6. ASSERT USER_API_KEY_AUTH format (used for tpm/rpm limiting in parallel_request_limiter.py AND cost tracking)
|
||||
|
||||
assert team_result.team_tpm_limit == 100
|
||||
assert team_result.team_rpm_limit == 99
|
||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
if team_id_set or default_team_id is not None:
|
||||
assert team_result.team_tpm_limit == 100
|
||||
assert team_result.team_rpm_limit == 99
|
||||
assert team_result.team_models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
assert team_result.user_id == user_id
|
||||
|
|
|
@ -23,6 +23,7 @@ import sys, os
|
|||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Request
|
||||
from fastapi.routing import APIRoute
|
||||
from datetime import datetime
|
||||
|
||||
load_dotenv()
|
||||
|
@ -51,6 +52,13 @@ from litellm.proxy.proxy_server import (
|
|||
user_info,
|
||||
info_key_fn,
|
||||
new_team,
|
||||
chat_completion,
|
||||
completion,
|
||||
embeddings,
|
||||
image_generation,
|
||||
audio_transcriptions,
|
||||
moderations,
|
||||
model_list,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
@ -146,7 +154,38 @@ async def test_new_user_response(prisma_client):
|
|||
pytest.fail(f"Got exception {e}")
|
||||
|
||||
|
||||
def test_generate_and_call_with_valid_key(prisma_client):
|
||||
@pytest.mark.parametrize(
|
||||
"api_route", [
|
||||
# chat_completion
|
||||
APIRoute(path="/engines/{model}/chat/completions", endpoint=chat_completion),
|
||||
APIRoute(path="/openai/deployments/{model}/chat/completions", endpoint=chat_completion),
|
||||
APIRoute(path="/chat/completions", endpoint=chat_completion),
|
||||
APIRoute(path="/v1/chat/completions", endpoint=chat_completion),
|
||||
# completion
|
||||
APIRoute(path="/completions", endpoint=completion),
|
||||
APIRoute(path="/v1/completions", endpoint=completion),
|
||||
APIRoute(path="/engines/{model}/completions", endpoint=completion),
|
||||
APIRoute(path="/openai/deployments/{model}/completions", endpoint=completion),
|
||||
# embeddings
|
||||
APIRoute(path="/v1/embeddings", endpoint=embeddings),
|
||||
APIRoute(path="/embeddings", endpoint=embeddings),
|
||||
APIRoute(path="/openai/deployments/{model}/embeddings", endpoint=embeddings),
|
||||
# image generation
|
||||
APIRoute(path="/v1/images/generations", endpoint=image_generation),
|
||||
APIRoute(path="/images/generations", endpoint=image_generation),
|
||||
# audio transcriptions
|
||||
APIRoute(path="/v1/audio/transcriptions", endpoint=audio_transcriptions),
|
||||
APIRoute(path="/audio/transcriptions", endpoint=audio_transcriptions),
|
||||
# moderations
|
||||
APIRoute(path="/v1/moderations", endpoint=moderations),
|
||||
APIRoute(path="/moderations", endpoint=moderations),
|
||||
# model_list
|
||||
APIRoute(path= "/v1/models", endpoint=model_list),
|
||||
APIRoute(path= "/models", endpoint=model_list),
|
||||
],
|
||||
ids=lambda route: str(dict(route=route.endpoint.__name__, path=route.path)),
|
||||
)
|
||||
def test_generate_and_call_with_valid_key(prisma_client, api_route):
|
||||
# 1. Generate a Key, and use it to make a call
|
||||
|
||||
print("prisma client=", prisma_client)
|
||||
|
@ -181,8 +220,12 @@ def test_generate_and_call_with_valid_key(prisma_client):
|
|||
)
|
||||
print("token from prisma", value_from_prisma)
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url="/chat/completions")
|
||||
request = Request({
|
||||
"type": "http",
|
||||
"route": api_route,
|
||||
"path": api_route.path,
|
||||
"headers": [("Authorization", bearer_token)]
|
||||
})
|
||||
|
||||
# use generated key to auth in
|
||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||
|
|
|
@ -705,7 +705,7 @@ async def test_lowest_latency_routing_first_pick():
|
|||
) # type: ignore
|
||||
|
||||
deployments = {}
|
||||
for _ in range(5):
|
||||
for _ in range(10):
|
||||
response = await router.acompletion(
|
||||
model="azure-model", messages=[{"role": "user", "content": "hello"}]
|
||||
)
|
||||
|
|
|
@ -28,6 +28,37 @@ from datetime import datetime
|
|||
## On Request failure
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_max_parallel_requests():
|
||||
"""
|
||||
Test if ParallelRequestHandler respects 'global_max_parallel_requests'
|
||||
|
||||
data["metadata"]["global_max_parallel_requests"]
|
||||
"""
|
||||
global_max_parallel_requests = 0
|
||||
_api_key = "sk-12345"
|
||||
_api_key = hash_token("sk-12345")
|
||||
user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=100)
|
||||
local_cache = DualCache()
|
||||
parallel_request_handler = MaxParallelRequestsHandler()
|
||||
|
||||
for _ in range(3):
|
||||
try:
|
||||
await parallel_request_handler.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=local_cache,
|
||||
data={
|
||||
"metadata": {
|
||||
"global_max_parallel_requests": global_max_parallel_requests
|
||||
}
|
||||
},
|
||||
call_type="",
|
||||
)
|
||||
pytest.fail("Expected call to fail")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pre_call_hook():
|
||||
"""
|
||||
|
|
138
litellm/tests/test_proxy_token_counter.py
Normal file
138
litellm/tests/test_proxy_token_counter.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
# Test the following scenarios:
|
||||
# 1. Generate a Key, and use it to make a call
|
||||
|
||||
|
||||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Request
|
||||
from datetime import datetime
|
||||
|
||||
load_dotenv()
|
||||
import os, io, time
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest, logging, asyncio
|
||||
import litellm, asyncio
|
||||
from litellm.proxy.proxy_server import token_counter
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
verbose_proxy_logger.setLevel(level=logging.DEBUG)
|
||||
|
||||
from litellm.proxy._types import TokenCountRequest, TokenCountResponse
|
||||
|
||||
|
||||
from litellm import Router
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vLLM_token_counting():
|
||||
"""
|
||||
Test Token counter for vLLM models
|
||||
- User passes model="special-alias"
|
||||
- token_counter should infer that special_alias -> maps to wolfram/miquliz-120b-v2.0
|
||||
-> token counter should use hugging face tokenizer
|
||||
"""
|
||||
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "special-alias",
|
||||
"litellm_params": {
|
||||
"model": "openai/wolfram/miquliz-120b-v2.0",
|
||||
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
|
||||
|
||||
response = await token_counter(
|
||||
request=TokenCountRequest(
|
||||
model="special-alias",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
)
|
||||
|
||||
print("response: ", response)
|
||||
|
||||
assert (
|
||||
response.tokenizer_type == "huggingface_tokenizer"
|
||||
) # SHOULD use the hugging face tokenizer
|
||||
assert response.model_used == "wolfram/miquliz-120b-v2.0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_counting_model_not_in_model_list():
|
||||
"""
|
||||
Test Token counter - when a model is not in model_list
|
||||
-> should use the default OpenAI tokenizer
|
||||
"""
|
||||
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
|
||||
|
||||
response = await token_counter(
|
||||
request=TokenCountRequest(
|
||||
model="special-alias",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
)
|
||||
|
||||
print("response: ", response)
|
||||
|
||||
assert (
|
||||
response.tokenizer_type == "openai_tokenizer"
|
||||
) # SHOULD use the OpenAI tokenizer
|
||||
assert response.model_used == "special-alias"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_token_counting():
|
||||
"""
|
||||
Test Token counter
|
||||
-> should work for gpt-4
|
||||
"""
|
||||
|
||||
llm_router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
|
||||
|
||||
response = await token_counter(
|
||||
request=TokenCountRequest(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
)
|
||||
|
||||
print("response: ", response)
|
||||
|
||||
assert (
|
||||
response.tokenizer_type == "openai_tokenizer"
|
||||
) # SHOULD use the OpenAI tokenizer
|
||||
assert response.request_model == "gpt-4"
|
64
litellm/tests/test_router_cooldowns.py
Normal file
64
litellm/tests/test_router_cooldowns.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
#### What this tests ####
|
||||
# This tests calling router with fallback models
|
||||
|
||||
import sys, os, time
|
||||
import traceback, asyncio
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import openai, httpx
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_badrequest_error():
|
||||
"""
|
||||
Test 1. It SHOULD NOT cooldown a deployment on a BadRequestError
|
||||
"""
|
||||
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
}
|
||||
],
|
||||
debug_level="DEBUG",
|
||||
set_verbose=True,
|
||||
cooldown_time=300,
|
||||
num_retries=0,
|
||||
allowed_fails=0,
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
try:
|
||||
|
||||
response = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "gm"}],
|
||||
bad_param=200,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(3) # wait for deployment to get cooled-down
|
||||
|
||||
response = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "gm"}],
|
||||
mock_response="hello",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
|
||||
print(response)
|
|
@ -82,7 +82,7 @@ def test_async_fallbacks(caplog):
|
|||
# Define the expected log messages
|
||||
# - error request, falling back notice, success notice
|
||||
expected_logs = [
|
||||
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}} \nModel: gpt-3.5-turbo\nAPI Base: https://api.openai.com\nMessages: [{'content': 'Hello, how are you?', 'role': 'user'}]\nmodel_group: gpt-3.5-turbo\n\ndeployment: gpt-3.5-turbo\n\x1b[0m",
|
||||
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
|
||||
"Falling back to model_group = azure/gpt-3.5-turbo",
|
||||
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
|
||||
"Successful fallback b/w models.",
|
||||
|
|
|
@ -950,7 +950,63 @@ def test_vertex_ai_stream():
|
|||
|
||||
# test_completion_vertexai_stream_bad_key()
|
||||
|
||||
# def test_completion_replicate_stream():
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [False, True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_replicate_llama3_streaming(sync_mode):
|
||||
litellm.set_verbose = True
|
||||
model_name = "replicate/meta/meta-llama-3-8b-instruct"
|
||||
try:
|
||||
if sync_mode:
|
||||
final_chunk: Optional[litellm.ModelResponse] = None
|
||||
response: litellm.CustomStreamWrapper = completion( # type: ignore
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10, # type: ignore
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
# Add any assertions here to check the response
|
||||
has_finish_reason = False
|
||||
for idx, chunk in enumerate(response):
|
||||
final_chunk = chunk
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if finished:
|
||||
has_finish_reason = True
|
||||
break
|
||||
complete_response += chunk
|
||||
if has_finish_reason == False:
|
||||
raise Exception("finish reason not set")
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
else:
|
||||
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=100, # type: ignore
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
# Add any assertions here to check the response
|
||||
has_finish_reason = False
|
||||
idx = 0
|
||||
final_chunk: Optional[litellm.ModelResponse] = None
|
||||
async for chunk in response:
|
||||
final_chunk = chunk
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if finished:
|
||||
has_finish_reason = True
|
||||
break
|
||||
complete_response += chunk
|
||||
idx += 1
|
||||
if has_finish_reason == False:
|
||||
raise Exception("finish reason not set")
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# TEMP Commented out - replicate throwing an auth error
|
||||
# try:
|
||||
# litellm.set_verbose = True
|
||||
|
@ -984,15 +1040,28 @@ def test_vertex_ai_stream():
|
|||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True])
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
# "bedrock/cohere.command-r-plus-v1:0",
|
||||
# "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
# "anthropic.claude-instant-v1",
|
||||
# "bedrock/ai21.j2-mid",
|
||||
# "mistral.mistral-7b-instruct-v0:2",
|
||||
# "bedrock/amazon.titan-tg1-large",
|
||||
# "meta.llama3-8b-instruct-v1:0",
|
||||
"cohere.command-text-v14"
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_cohere_command_r_streaming(sync_mode):
|
||||
async def test_bedrock_httpx_streaming(sync_mode, model):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
if sync_mode:
|
||||
final_chunk: Optional[litellm.ModelResponse] = None
|
||||
response: litellm.CustomStreamWrapper = completion( # type: ignore
|
||||
model="bedrock/cohere.command-r-plus-v1:0",
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=10, # type: ignore
|
||||
stream=True,
|
||||
|
@ -1013,7 +1082,7 @@ async def test_bedrock_cohere_command_r_streaming(sync_mode):
|
|||
raise Exception("Empty response received")
|
||||
else:
|
||||
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
|
||||
model="bedrock/cohere.command-r-plus-v1:0",
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=100, # type: ignore
|
||||
stream=True,
|
||||
|
|
|
@ -174,7 +174,6 @@ def test_load_test_token_counter(model):
|
|||
"""
|
||||
import tiktoken
|
||||
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
messages = [{"role": "user", "content": text}] * 10
|
||||
|
||||
start_time = time.time()
|
||||
|
@ -186,4 +185,4 @@ def test_load_test_token_counter(model):
|
|||
|
||||
total_time = end_time - start_time
|
||||
print("model={}, total test time={}".format(model, total_time))
|
||||
assert total_time < 2, f"Total encoding time > 1.5s, {total_time}"
|
||||
assert total_time < 10, f"Total encoding time > 10s, {total_time}"
|
||||
|
|
|
@ -1,27 +1,10 @@
|
|||
from typing import List, Optional, Union, Iterable, cast
|
||||
from typing import List, Optional, Union, Iterable
|
||||
|
||||
from pydantic import ConfigDict, BaseModel, validator, VERSION
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from typing_extensions import Literal, Required, TypedDict
|
||||
|
||||
|
||||
# Function to get Pydantic version
|
||||
def is_pydantic_v2() -> int:
|
||||
return int(VERSION.split(".")[0])
|
||||
|
||||
|
||||
def get_model_config() -> ConfigDict:
|
||||
# Version-specific configuration
|
||||
if is_pydantic_v2() >= 2:
|
||||
model_config = ConfigDict(extra="allow", protected_namespaces=()) # type: ignore
|
||||
else:
|
||||
from pydantic import Extra
|
||||
|
||||
model_config = ConfigDict(extra=Extra.allow) # type: ignore
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
class ChatCompletionSystemMessageParam(TypedDict, total=False):
|
||||
content: Required[str]
|
||||
"""The contents of the system message."""
|
||||
|
@ -208,4 +191,6 @@ class CompletionRequest(BaseModel):
|
|||
api_key: Optional[str] = None
|
||||
model_list: Optional[List[str]] = None
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
extra = "allow"
|
||||
protected_namespaces = ()
|
||||
|
|
|
@ -1,23 +1,6 @@
|
|||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, BaseModel, validator, VERSION
|
||||
|
||||
|
||||
# Function to get Pydantic version
|
||||
def is_pydantic_v2() -> int:
|
||||
return int(VERSION.split(".")[0])
|
||||
|
||||
|
||||
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
|
||||
# Version-specific configuration
|
||||
if is_pydantic_v2() >= 2:
|
||||
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
|
||||
else:
|
||||
from pydantic import Extra
|
||||
|
||||
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
|
||||
|
||||
return model_config
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
|
@ -34,4 +17,7 @@ class EmbeddingRequest(BaseModel):
|
|||
litellm_call_id: Optional[str] = None
|
||||
litellm_logging_obj: Optional[dict] = None
|
||||
logger_fn: Optional[str] = None
|
||||
model_config = get_model_config()
|
||||
|
||||
class Config:
|
||||
# allow kwargs
|
||||
extra = "allow"
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
__all__ = ["openai"]
|
||||
|
||||
from . import openai
|
53
litellm/types/llms/vertex_ai.py
Normal file
53
litellm/types/llms/vertex_ai.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
from typing import TypedDict, Any, Union, Optional, List, Literal, Dict
|
||||
import json
|
||||
from typing_extensions import (
|
||||
Self,
|
||||
Protocol,
|
||||
TypeGuard,
|
||||
override,
|
||||
get_origin,
|
||||
runtime_checkable,
|
||||
Required,
|
||||
)
|
||||
|
||||
|
||||
class Field(TypedDict):
|
||||
key: str
|
||||
value: Dict[str, Any]
|
||||
|
||||
|
||||
class FunctionCallArgs(TypedDict):
|
||||
fields: Field
|
||||
|
||||
|
||||
class FunctionResponse(TypedDict):
|
||||
name: str
|
||||
response: FunctionCallArgs
|
||||
|
||||
|
||||
class FunctionCall(TypedDict):
|
||||
name: str
|
||||
args: FunctionCallArgs
|
||||
|
||||
|
||||
class FileDataType(TypedDict):
|
||||
mime_type: str
|
||||
file_uri: str # the cloud storage uri of storing this file
|
||||
|
||||
|
||||
class BlobType(TypedDict):
|
||||
mime_type: Required[str]
|
||||
data: Required[bytes]
|
||||
|
||||
|
||||
class PartType(TypedDict, total=False):
|
||||
text: str
|
||||
inline_data: BlobType
|
||||
file_data: FileDataType
|
||||
function_call: FunctionCall
|
||||
function_response: FunctionResponse
|
||||
|
||||
|
||||
class ContentType(TypedDict, total=False):
|
||||
role: Literal["user", "model"]
|
||||
parts: Required[List[PartType]]
|
|
@ -1,42 +1,19 @@
|
|||
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
|
||||
import httpx
|
||||
from pydantic import (
|
||||
ConfigDict,
|
||||
BaseModel,
|
||||
validator,
|
||||
Field,
|
||||
__version__ as pydantic_version,
|
||||
VERSION,
|
||||
)
|
||||
from pydantic import BaseModel, validator, Field
|
||||
from .completion import CompletionRequest
|
||||
from .embedding import EmbeddingRequest
|
||||
import uuid, enum
|
||||
|
||||
|
||||
# Function to get Pydantic version
|
||||
def is_pydantic_v2() -> int:
|
||||
return int(VERSION.split(".")[0])
|
||||
|
||||
|
||||
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
|
||||
# Version-specific configuration
|
||||
if is_pydantic_v2() >= 2:
|
||||
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
|
||||
else:
|
||||
from pydantic import Extra
|
||||
|
||||
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_name: str
|
||||
litellm_params: Union[CompletionRequest, EmbeddingRequest]
|
||||
tpm: int
|
||||
rpm: int
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class RouterConfig(BaseModel):
|
||||
|
@ -67,7 +44,8 @@ class RouterConfig(BaseModel):
|
|||
"latency-based-routing",
|
||||
] = "simple-shuffle"
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class UpdateRouterConfig(BaseModel):
|
||||
|
@ -87,7 +65,8 @@ class UpdateRouterConfig(BaseModel):
|
|||
fallbacks: Optional[List[dict]] = None
|
||||
context_window_fallbacks: Optional[List[dict]] = None
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
|
@ -97,6 +76,9 @@ class ModelInfo(BaseModel):
|
|||
db_model: bool = (
|
||||
False # used for proxy - to separate models which are stored in the db vs. config.
|
||||
)
|
||||
base_model: Optional[str] = (
|
||||
None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking
|
||||
)
|
||||
|
||||
def __init__(self, id: Optional[Union[str, int]] = None, **params):
|
||||
if id is None:
|
||||
|
@ -105,7 +87,8 @@ class ModelInfo(BaseModel):
|
|||
id = str(id)
|
||||
super().__init__(id=id, **params)
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def __contains__(self, key):
|
||||
# Define custom behavior for the 'in' operator
|
||||
|
@ -200,15 +183,9 @@ class GenericLiteLLMParams(BaseModel):
|
|||
max_retries = int(max_retries) # cast to int
|
||||
super().__init__(max_retries=max_retries, **args, **params)
|
||||
|
||||
model_config = get_model_config(arbitrary_types_allowed=True)
|
||||
if pydantic_version.startswith("1"):
|
||||
# pydantic v2 warns about using a Config class.
|
||||
# But without this, pydantic v1 will raise an error:
|
||||
# RuntimeError: no validator found for <class 'openai.Timeout'>,
|
||||
# see `arbitrary_types_allowed` in Config
|
||||
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __contains__(self, key):
|
||||
# Define custom behavior for the 'in' operator
|
||||
|
@ -267,16 +244,9 @@ class LiteLLM_Params(GenericLiteLLMParams):
|
|||
max_retries = int(max_retries) # cast to int
|
||||
super().__init__(max_retries=max_retries, **args, **params)
|
||||
|
||||
model_config = get_model_config(arbitrary_types_allowed=True)
|
||||
|
||||
if pydantic_version.startswith("1"):
|
||||
# pydantic v2 warns about using a Config class.
|
||||
# But without this, pydantic v1 will raise an error:
|
||||
# RuntimeError: no validator found for <class 'openai.Timeout'>,
|
||||
# see `arbitrary_types_allowed` in Config
|
||||
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __contains__(self, key):
|
||||
# Define custom behavior for the 'in' operator
|
||||
|
@ -306,7 +276,8 @@ class updateDeployment(BaseModel):
|
|||
litellm_params: Optional[updateLiteLLMParams] = None
|
||||
model_info: Optional[ModelInfo] = None
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLMParamsTypedDict(TypedDict, total=False):
|
||||
|
@ -380,7 +351,9 @@ class Deployment(BaseModel):
|
|||
# if using pydantic v1
|
||||
return self.dict(**kwargs)
|
||||
|
||||
model_config = get_model_config()
|
||||
class Config:
|
||||
extra = "allow"
|
||||
protected_namespaces = ()
|
||||
|
||||
def __contains__(self, key):
|
||||
# Define custom behavior for the 'in' operator
|
||||
|
|
6
litellm/types/utils.py
Normal file
6
litellm/types/utils.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
|
||||
|
||||
|
||||
class CostPerToken(TypedDict):
|
||||
input_cost_per_token: float
|
||||
output_cost_per_token: float
|
622
litellm/utils.py
622
litellm/utils.py
File diff suppressed because it is too large
Load diff
|
@ -234,6 +234,24 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ft:davinci-002": {
|
||||
"max_tokens": 16384,
|
||||
"max_input_tokens": 16384,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000002,
|
||||
"output_cost_per_token": 0.000002,
|
||||
"litellm_provider": "text-completion-openai",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ft:babbage-002": {
|
||||
"max_tokens": 16384,
|
||||
"max_input_tokens": 16384,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000004,
|
||||
"output_cost_per_token": 0.0000004,
|
||||
"litellm_provider": "text-completion-openai",
|
||||
"mode": "completion"
|
||||
},
|
||||
"text-embedding-3-large": {
|
||||
"max_tokens": 8191,
|
||||
"max_input_tokens": 8191,
|
||||
|
@ -1385,6 +1403,24 @@
|
|||
"mode": "completion",
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-latest": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_videos_per_prompt": 10,
|
||||
"max_video_length": 1,
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-pro": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 32760,
|
||||
|
@ -1744,6 +1780,30 @@
|
|||
"litellm_provider": "openrouter",
|
||||
"mode": "chat"
|
||||
},
|
||||
"openrouter/openai/gpt-4o": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/openai/gpt-4o-2024-05-13": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/openai/gpt-4-vision-preview": {
|
||||
"max_tokens": 130000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
|
@ -2943,6 +3003,24 @@
|
|||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/llama3": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/llama3:70b": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mistral": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
|
@ -2952,6 +3030,42 @@
|
|||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/mistral-7B-Instruct-v0.1": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mistral-7B-Instruct-v0.2": {
|
||||
"max_tokens": 32768,
|
||||
"max_input_tokens": 32768,
|
||||
"max_output_tokens": 32768,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mixtral-8x7B-Instruct-v0.1": {
|
||||
"max_tokens": 32768,
|
||||
"max_input_tokens": 32768,
|
||||
"max_output_tokens": 32768,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/mixtral-8x22B-Instruct-v0.1": {
|
||||
"max_tokens": 65536,
|
||||
"max_input_tokens": 65536,
|
||||
"max_output_tokens": 65536,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ollama/codellama": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
|
|
@ -89,6 +89,7 @@ model_list:
|
|||
litellm_params:
|
||||
model: text-completion-openai/gpt-3.5-turbo-instruct
|
||||
litellm_settings:
|
||||
# set_verbose: True # Uncomment this if you want to see verbose logs; not recommended in production
|
||||
drop_params: True
|
||||
# max_budget: 100
|
||||
# budget_duration: 30d
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "litellm"
|
||||
version = "1.37.10"
|
||||
version = "1.37.19"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
authors = ["BerriAI"]
|
||||
license = "MIT"
|
||||
|
@ -79,7 +79,7 @@ requires = ["poetry-core", "wheel"]
|
|||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.commitizen]
|
||||
version = "1.37.10"
|
||||
version = "1.37.19"
|
||||
version_files = [
|
||||
"pyproject.toml:^version"
|
||||
]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# LITELLM PROXY DEPENDENCIES #
|
||||
anyio==4.2.0 # openai + http req.
|
||||
openai==1.14.3 # openai req.
|
||||
openai==1.27.0 # openai req.
|
||||
fastapi==0.111.0 # server dep
|
||||
backoff==2.2.1 # server dep
|
||||
pyyaml==6.0.0 # server dep
|
||||
|
|
|
@ -129,7 +129,7 @@ async def test_check_num_callbacks():
|
|||
set(all_litellm_callbacks_1) - set(all_litellm_callbacks_2),
|
||||
)
|
||||
|
||||
assert num_callbacks_1 == num_callbacks_2
|
||||
assert abs(num_callbacks_1 - num_callbacks_2) <= 4
|
||||
|
||||
await asyncio.sleep(30)
|
||||
|
||||
|
@ -142,7 +142,7 @@ async def test_check_num_callbacks():
|
|||
set(all_litellm_callbacks_3) - set(all_litellm_callbacks_2),
|
||||
)
|
||||
|
||||
assert num_callbacks_1 == num_callbacks_2 == num_callbacks_3
|
||||
assert abs(num_callbacks_3 - num_callbacks_2) <= 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -183,7 +183,7 @@ async def test_check_num_callbacks_on_lowest_latency():
|
|||
set(all_litellm_callbacks_2) - set(all_litellm_callbacks_1),
|
||||
)
|
||||
|
||||
assert num_callbacks_1 == num_callbacks_2
|
||||
assert abs(num_callbacks_1 - num_callbacks_2) <= 4
|
||||
|
||||
await asyncio.sleep(30)
|
||||
|
||||
|
@ -196,7 +196,7 @@ async def test_check_num_callbacks_on_lowest_latency():
|
|||
set(all_litellm_callbacks_3) - set(all_litellm_callbacks_2),
|
||||
)
|
||||
|
||||
assert num_callbacks_1 == num_callbacks_2 == num_callbacks_3
|
||||
assert abs(num_callbacks_2 - num_callbacks_3) <= 4
|
||||
|
||||
assert num_alerts_1 == num_alerts_2 == num_alerts_3
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1 +1 @@
|
|||
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[7926,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-6a39771cacf75ea6.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"obp5wqVSVDMiDTC414cR8\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
|
||||
<!DOCTYPE html><html id="__next_error__"><head><meta charSet="utf-8"/><meta name="viewport" content="width=device-width, initial-scale=1"/><link rel="preload" as="script" fetchPriority="low" href="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin=""/><script src="/ui/_next/static/chunks/fd9d1056-f960ab1e6d32b002.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/69-04708d7d4a17c1ee.js" async="" crossorigin=""></script><script src="/ui/_next/static/chunks/main-app-9b4fb13a7db53edf.js" async="" crossorigin=""></script><title>LiteLLM Dashboard</title><meta name="description" content="LiteLLM Proxy Admin UI"/><link rel="icon" href="/ui/favicon.ico" type="image/x-icon" sizes="16x16"/><meta name="next-size-adjust"/><script src="/ui/_next/static/chunks/polyfills-c67a75d1b6f99dc8.js" crossorigin="" noModule=""></script></head><body><script src="/ui/_next/static/chunks/webpack-de9c0fadf6a94b3b.js" crossorigin="" async=""></script><script>(self.__next_f=self.__next_f||[]).push([0]);self.__next_f.push([2,null])</script><script>self.__next_f.push([1,"1:HL[\"/ui/_next/static/media/c9a5bc6a7c948fb0-s.p.woff2\",\"font\",{\"crossOrigin\":\"\",\"type\":\"font/woff2\"}]\n2:HL[\"/ui/_next/static/css/f04e46b02318b660.css\",\"style\",{\"crossOrigin\":\"\"}]\n0:\"$L3\"\n"])</script><script>self.__next_f.push([1,"4:I[47690,[],\"\"]\n6:I[77831,[],\"\"]\n7:I[4858,[\"936\",\"static/chunks/2f6dbc85-052c4579f80d66ae.js\",\"884\",\"static/chunks/884-7576ee407a2ecbe6.js\",\"931\",\"static/chunks/app/page-f20fdea77aed85ba.js\"],\"\"]\n8:I[5613,[],\"\"]\n9:I[31778,[],\"\"]\nb:I[48955,[],\"\"]\nc:[]\n"])</script><script>self.__next_f.push([1,"3:[[[\"$\",\"link\",\"0\",{\"rel\":\"stylesheet\",\"href\":\"/ui/_next/static/css/f04e46b02318b660.css\",\"precedence\":\"next\",\"crossOrigin\":\"\"}]],[\"$\",\"$L4\",null,{\"buildId\":\"l-0LDfSCdaUCAbcLIx_QC\",\"assetPrefix\":\"/ui\",\"initialCanonicalUrl\":\"/\",\"initialTree\":[\"\",{\"children\":[\"__PAGE__\",{}]},\"$undefined\",\"$undefined\",true],\"initialSeedData\":[\"\",{\"children\":[\"__PAGE__\",{},[\"$L5\",[\"$\",\"$L6\",null,{\"propsForComponent\":{\"params\":{}},\"Component\":\"$7\",\"isStaticGeneration\":true}],null]]},[null,[\"$\",\"html\",null,{\"lang\":\"en\",\"children\":[\"$\",\"body\",null,{\"className\":\"__className_c23dc8\",\"children\":[\"$\",\"$L8\",null,{\"parallelRouterKey\":\"children\",\"segmentPath\":[\"children\"],\"loading\":\"$undefined\",\"loadingStyles\":\"$undefined\",\"loadingScripts\":\"$undefined\",\"hasLoading\":false,\"error\":\"$undefined\",\"errorStyles\":\"$undefined\",\"errorScripts\":\"$undefined\",\"template\":[\"$\",\"$L9\",null,{}],\"templateStyles\":\"$undefined\",\"templateScripts\":\"$undefined\",\"notFound\":[[\"$\",\"title\",null,{\"children\":\"404: This page could not be found.\"}],[\"$\",\"div\",null,{\"style\":{\"fontFamily\":\"system-ui,\\\"Segoe UI\\\",Roboto,Helvetica,Arial,sans-serif,\\\"Apple Color Emoji\\\",\\\"Segoe UI Emoji\\\"\",\"height\":\"100vh\",\"textAlign\":\"center\",\"display\":\"flex\",\"flexDirection\":\"column\",\"alignItems\":\"center\",\"justifyContent\":\"center\"},\"children\":[\"$\",\"div\",null,{\"children\":[[\"$\",\"style\",null,{\"dangerouslySetInnerHTML\":{\"__html\":\"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}\"}}],[\"$\",\"h1\",null,{\"className\":\"next-error-h1\",\"style\":{\"display\":\"inline-block\",\"margin\":\"0 20px 0 0\",\"padding\":\"0 23px 0 0\",\"fontSize\":24,\"fontWeight\":500,\"verticalAlign\":\"top\",\"lineHeight\":\"49px\"},\"children\":\"404\"}],[\"$\",\"div\",null,{\"style\":{\"display\":\"inline-block\"},\"children\":[\"$\",\"h2\",null,{\"style\":{\"fontSize\":14,\"fontWeight\":400,\"lineHeight\":\"49px\",\"margin\":0},\"children\":\"This page could not be found.\"}]}]]}]}]],\"notFoundStyles\":[],\"styles\":null}]}]}],null]],\"initialHead\":[false,\"$La\"],\"globalErrorComponent\":\"$b\",\"missingSlots\":\"$Wc\"}]]\n"])</script><script>self.__next_f.push([1,"a:[[\"$\",\"meta\",\"0\",{\"name\":\"viewport\",\"content\":\"width=device-width, initial-scale=1\"}],[\"$\",\"meta\",\"1\",{\"charSet\":\"utf-8\"}],[\"$\",\"title\",\"2\",{\"children\":\"LiteLLM Dashboard\"}],[\"$\",\"meta\",\"3\",{\"name\":\"description\",\"content\":\"LiteLLM Proxy Admin UI\"}],[\"$\",\"link\",\"4\",{\"rel\":\"icon\",\"href\":\"/ui/favicon.ico\",\"type\":\"image/x-icon\",\"sizes\":\"16x16\"}],[\"$\",\"meta\",\"5\",{\"name\":\"next-size-adjust\"}]]\n5:null\n"])</script><script>self.__next_f.push([1,""])</script></body></html>
|
|
@ -1,7 +1,7 @@
|
|||
2:I[77831,[],""]
|
||||
3:I[7926,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-6a39771cacf75ea6.js"],""]
|
||||
3:I[4858,["936","static/chunks/2f6dbc85-052c4579f80d66ae.js","884","static/chunks/884-7576ee407a2ecbe6.js","931","static/chunks/app/page-f20fdea77aed85ba.js"],""]
|
||||
4:I[5613,[],""]
|
||||
5:I[31778,[],""]
|
||||
0:["obp5wqVSVDMiDTC414cR8",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
|
||||
0:["l-0LDfSCdaUCAbcLIx_QC",[[["",{"children":["__PAGE__",{}]},"$undefined","$undefined",true],["",{"children":["__PAGE__",{},["$L1",["$","$L2",null,{"propsForComponent":{"params":{}},"Component":"$3","isStaticGeneration":true}],null]]},[null,["$","html",null,{"lang":"en","children":["$","body",null,{"className":"__className_c23dc8","children":["$","$L4",null,{"parallelRouterKey":"children","segmentPath":["children"],"loading":"$undefined","loadingStyles":"$undefined","loadingScripts":"$undefined","hasLoading":false,"error":"$undefined","errorStyles":"$undefined","errorScripts":"$undefined","template":["$","$L5",null,{}],"templateStyles":"$undefined","templateScripts":"$undefined","notFound":[["$","title",null,{"children":"404: This page could not be found."}],["$","div",null,{"style":{"fontFamily":"system-ui,\"Segoe UI\",Roboto,Helvetica,Arial,sans-serif,\"Apple Color Emoji\",\"Segoe UI Emoji\"","height":"100vh","textAlign":"center","display":"flex","flexDirection":"column","alignItems":"center","justifyContent":"center"},"children":["$","div",null,{"children":[["$","style",null,{"dangerouslySetInnerHTML":{"__html":"body{color:#000;background:#fff;margin:0}.next-error-h1{border-right:1px solid rgba(0,0,0,.3)}@media (prefers-color-scheme:dark){body{color:#fff;background:#000}.next-error-h1{border-right:1px solid rgba(255,255,255,.3)}}"}}],["$","h1",null,{"className":"next-error-h1","style":{"display":"inline-block","margin":"0 20px 0 0","padding":"0 23px 0 0","fontSize":24,"fontWeight":500,"verticalAlign":"top","lineHeight":"49px"},"children":"404"}],["$","div",null,{"style":{"display":"inline-block"},"children":["$","h2",null,{"style":{"fontSize":14,"fontWeight":400,"lineHeight":"49px","margin":0},"children":"This page could not be found."}]}]]}]}]],"notFoundStyles":[],"styles":null}]}]}],null]],[[["$","link","0",{"rel":"stylesheet","href":"/ui/_next/static/css/f04e46b02318b660.css","precedence":"next","crossOrigin":""}]],"$L6"]]]]
|
||||
6:[["$","meta","0",{"name":"viewport","content":"width=device-width, initial-scale=1"}],["$","meta","1",{"charSet":"utf-8"}],["$","title","2",{"children":"LiteLLM Dashboard"}],["$","meta","3",{"name":"description","content":"LiteLLM Proxy Admin UI"}],["$","link","4",{"rel":"icon","href":"/ui/favicon.ico","type":"image/x-icon","sizes":"16x16"}],["$","meta","5",{"name":"next-size-adjust"}]]
|
||||
1:null
|
||||
|
|
|
@ -23,12 +23,44 @@ import {
|
|||
AccordionHeader,
|
||||
AccordionList,
|
||||
} from "@tremor/react";
|
||||
import { TabPanel, TabPanels, TabGroup, TabList, Tab, Icon } from "@tremor/react";
|
||||
import { getCallbacksCall, setCallbacksCall, serviceHealthCheck } from "./networking";
|
||||
import { Modal, Form, Input, Select, Button as Button2, message } from "antd";
|
||||
import { InformationCircleIcon, PencilAltIcon, PencilIcon, StatusOnlineIcon, TrashIcon, RefreshIcon } from "@heroicons/react/outline";
|
||||
import {
|
||||
TabPanel,
|
||||
TabPanels,
|
||||
TabGroup,
|
||||
TabList,
|
||||
Tab,
|
||||
Icon,
|
||||
} from "@tremor/react";
|
||||
import {
|
||||
getCallbacksCall,
|
||||
setCallbacksCall,
|
||||
getGeneralSettingsCall,
|
||||
serviceHealthCheck,
|
||||
updateConfigFieldSetting,
|
||||
deleteConfigFieldSetting,
|
||||
} from "./networking";
|
||||
import {
|
||||
Modal,
|
||||
Form,
|
||||
Input,
|
||||
Select,
|
||||
Button as Button2,
|
||||
message,
|
||||
InputNumber,
|
||||
} from "antd";
|
||||
import {
|
||||
InformationCircleIcon,
|
||||
PencilAltIcon,
|
||||
PencilIcon,
|
||||
StatusOnlineIcon,
|
||||
TrashIcon,
|
||||
RefreshIcon,
|
||||
CheckCircleIcon,
|
||||
XCircleIcon,
|
||||
QuestionMarkCircleIcon,
|
||||
} from "@heroicons/react/outline";
|
||||
import StaticGenerationSearchParamsBailoutProvider from "next/dist/client/components/static-generation-searchparams-bailout-provider";
|
||||
import AddFallbacks from "./add_fallbacks"
|
||||
import AddFallbacks from "./add_fallbacks";
|
||||
import openai from "openai";
|
||||
import Paragraph from "antd/es/skeleton/Paragraph";
|
||||
|
||||
|
@ -36,7 +68,7 @@ interface GeneralSettingsPageProps {
|
|||
accessToken: string | null;
|
||||
userRole: string | null;
|
||||
userID: string | null;
|
||||
modelData: any
|
||||
modelData: any;
|
||||
}
|
||||
|
||||
async function testFallbackModelResponse(
|
||||
|
@ -65,43 +97,71 @@ async function testFallbackModelResponse(
|
|||
},
|
||||
],
|
||||
// @ts-ignore
|
||||
mock_testing_fallbacks: true
|
||||
mock_testing_fallbacks: true,
|
||||
});
|
||||
|
||||
message.success(
|
||||
<span>
|
||||
Test model=<strong>{selectedModel}</strong>, received model=<strong>{response.model}</strong>.
|
||||
See <a href="#" onClick={() => window.open('https://docs.litellm.ai/docs/proxy/reliability', '_blank')} style={{ textDecoration: 'underline', color: 'blue' }}>curl</a>
|
||||
Test model=<strong>{selectedModel}</strong>, received model=
|
||||
<strong>{response.model}</strong>. See{" "}
|
||||
<a
|
||||
href="#"
|
||||
onClick={() =>
|
||||
window.open(
|
||||
"https://docs.litellm.ai/docs/proxy/reliability",
|
||||
"_blank"
|
||||
)
|
||||
}
|
||||
style={{ textDecoration: "underline", color: "blue" }}
|
||||
>
|
||||
curl
|
||||
</a>
|
||||
</span>
|
||||
);
|
||||
} catch (error) {
|
||||
message.error(`Error occurred while generating model response. Please try again. Error: ${error}`, 20);
|
||||
message.error(
|
||||
`Error occurred while generating model response. Please try again. Error: ${error}`,
|
||||
20
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
interface AccordionHeroProps {
|
||||
selectedStrategy: string | null;
|
||||
strategyArgs: routingStrategyArgs;
|
||||
paramExplanation: { [key: string]: string }
|
||||
paramExplanation: { [key: string]: string };
|
||||
}
|
||||
|
||||
interface routingStrategyArgs {
|
||||
ttl?: number;
|
||||
lowest_latency_buffer?: number;
|
||||
ttl?: number;
|
||||
lowest_latency_buffer?: number;
|
||||
}
|
||||
|
||||
interface generalSettingsItem {
|
||||
field_name: string;
|
||||
field_type: string;
|
||||
field_value: any;
|
||||
field_description: string;
|
||||
stored_in_db: boolean | null;
|
||||
}
|
||||
|
||||
const defaultLowestLatencyArgs: routingStrategyArgs = {
|
||||
"ttl": 3600,
|
||||
"lowest_latency_buffer": 0
|
||||
}
|
||||
ttl: 3600,
|
||||
lowest_latency_buffer: 0,
|
||||
};
|
||||
|
||||
export const AccordionHero: React.FC<AccordionHeroProps> = ({ selectedStrategy, strategyArgs, paramExplanation }) => (
|
||||
export const AccordionHero: React.FC<AccordionHeroProps> = ({
|
||||
selectedStrategy,
|
||||
strategyArgs,
|
||||
paramExplanation,
|
||||
}) => (
|
||||
<Accordion>
|
||||
<AccordionHeader className="text-sm font-medium text-tremor-content-strong dark:text-dark-tremor-content-strong">Routing Strategy Specific Args</AccordionHeader>
|
||||
<AccordionBody>
|
||||
{
|
||||
selectedStrategy == "latency-based-routing" ?
|
||||
<Card>
|
||||
<AccordionHeader className="text-sm font-medium text-tremor-content-strong dark:text-dark-tremor-content-strong">
|
||||
Routing Strategy Specific Args
|
||||
</AccordionHeader>
|
||||
<AccordionBody>
|
||||
{selectedStrategy == "latency-based-routing" ? (
|
||||
<Card>
|
||||
<Table>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
|
@ -114,51 +174,75 @@ export const AccordionHero: React.FC<AccordionHeroProps> = ({ selectedStrategy,
|
|||
<TableRow key={param}>
|
||||
<TableCell>
|
||||
<Text>{param}</Text>
|
||||
<p style={{fontSize: '0.65rem', color: '#808080', fontStyle: 'italic'}} className="mt-1">{paramExplanation[param]}</p>
|
||||
<p
|
||||
style={{
|
||||
fontSize: "0.65rem",
|
||||
color: "#808080",
|
||||
fontStyle: "italic",
|
||||
}}
|
||||
className="mt-1"
|
||||
>
|
||||
{paramExplanation[param]}
|
||||
</p>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<TextInput
|
||||
name={param}
|
||||
defaultValue={
|
||||
typeof value === 'object' ? JSON.stringify(value, null, 2) : value.toString()
|
||||
}
|
||||
/>
|
||||
name={param}
|
||||
defaultValue={
|
||||
typeof value === "object"
|
||||
? JSON.stringify(value, null, 2)
|
||||
: value.toString()
|
||||
}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</Card>
|
||||
: <Text>No specific settings</Text>
|
||||
}
|
||||
</AccordionBody>
|
||||
</Accordion>
|
||||
</Card>
|
||||
) : (
|
||||
<Text>No specific settings</Text>
|
||||
)}
|
||||
</AccordionBody>
|
||||
</Accordion>
|
||||
);
|
||||
|
||||
const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
|
||||
accessToken,
|
||||
userRole,
|
||||
userID,
|
||||
modelData
|
||||
modelData,
|
||||
}) => {
|
||||
const [routerSettings, setRouterSettings] = useState<{ [key: string]: any }>({});
|
||||
const [routerSettings, setRouterSettings] = useState<{ [key: string]: any }>(
|
||||
{}
|
||||
);
|
||||
const [generalSettingsDict, setGeneralSettingsDict] = useState<{
|
||||
[key: string]: any;
|
||||
}>({});
|
||||
const [generalSettings, setGeneralSettings] = useState<generalSettingsItem[]>(
|
||||
[]
|
||||
);
|
||||
const [isModalVisible, setIsModalVisible] = useState(false);
|
||||
const [form] = Form.useForm();
|
||||
const [selectedCallback, setSelectedCallback] = useState<string | null>(null);
|
||||
const [selectedStrategy, setSelectedStrategy] = useState<string | null>(null)
|
||||
const [strategySettings, setStrategySettings] = useState<routingStrategyArgs | null>(null);
|
||||
const [selectedStrategy, setSelectedStrategy] = useState<string | null>(null);
|
||||
const [strategySettings, setStrategySettings] =
|
||||
useState<routingStrategyArgs | null>(null);
|
||||
|
||||
let paramExplanation: { [key: string]: string } = {
|
||||
"routing_strategy_args": "(dict) Arguments to pass to the routing strategy",
|
||||
"routing_strategy": "(string) Routing strategy to use",
|
||||
"allowed_fails": "(int) Number of times a deployment can fail before being added to cooldown",
|
||||
"cooldown_time": "(int) time in seconds to cooldown a deployment after failure",
|
||||
"num_retries": "(int) Number of retries for failed requests. Defaults to 0.",
|
||||
"timeout": "(float) Timeout for requests. Defaults to None.",
|
||||
"retry_after": "(int) Minimum time to wait before retrying a failed request",
|
||||
"ttl": "(int) Sliding window to look back over when calculating the average latency of a deployment. Default - 1 hour (in seconds).",
|
||||
"lowest_latency_buffer": "(float) Shuffle between deployments within this % of the lowest latency. Default - 0 (i.e. always pick lowest latency)."
|
||||
}
|
||||
routing_strategy_args: "(dict) Arguments to pass to the routing strategy",
|
||||
routing_strategy: "(string) Routing strategy to use",
|
||||
allowed_fails:
|
||||
"(int) Number of times a deployment can fail before being added to cooldown",
|
||||
cooldown_time:
|
||||
"(int) time in seconds to cooldown a deployment after failure",
|
||||
num_retries: "(int) Number of retries for failed requests. Defaults to 0.",
|
||||
timeout: "(float) Timeout for requests. Defaults to None.",
|
||||
retry_after: "(int) Minimum time to wait before retrying a failed request",
|
||||
ttl: "(int) Sliding window to look back over when calculating the average latency of a deployment. Default - 1 hour (in seconds).",
|
||||
lowest_latency_buffer:
|
||||
"(float) Shuffle between deployments within this % of the lowest latency. Default - 0 (i.e. always pick lowest latency).",
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (!accessToken || !userRole || !userID) {
|
||||
|
@ -169,6 +253,10 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
|
|||
let router_settings = data.router_settings;
|
||||
setRouterSettings(router_settings);
|
||||
});
|
||||
getGeneralSettingsCall(accessToken).then((data) => {
|
||||
let general_settings = data;
|
||||
setGeneralSettings(general_settings);
|
||||
});
|
||||
}, [accessToken, userRole, userID]);
|
||||
|
||||
const handleAddCallback = () => {
|
||||
|
@ -190,8 +278,8 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
|
|||
return;
|
||||
}
|
||||
|
||||
console.log(`received key: ${key}`)
|
||||
console.log(`routerSettings['fallbacks']: ${routerSettings['fallbacks']}`)
|
||||
console.log(`received key: ${key}`);
|
||||
console.log(`routerSettings['fallbacks']: ${routerSettings["fallbacks"]}`);
|
||||
|
||||
routerSettings["fallbacks"].map((dict: { [key: string]: any }) => {
|
||||
// Check if the dictionary has the specified key and delete it if present
|
||||
|
@ -202,18 +290,73 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
|
|||
});
|
||||
|
||||
const payload = {
|
||||
router_settings: routerSettings
|
||||
router_settings: routerSettings,
|
||||
};
|
||||
|
||||
try {
|
||||
await setCallbacksCall(accessToken, payload);
|
||||
setRouterSettings({ ...routerSettings });
|
||||
setSelectedStrategy(routerSettings["routing_strategy"])
|
||||
setSelectedStrategy(routerSettings["routing_strategy"]);
|
||||
message.success("Router settings updated successfully");
|
||||
} catch (error) {
|
||||
message.error("Failed to update router settings: " + error, 20);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleInputChange = (fieldName: string, newValue: any) => {
|
||||
// Update the value in the state
|
||||
const updatedSettings = generalSettings.map((setting) =>
|
||||
setting.field_name === fieldName
|
||||
? { ...setting, field_value: newValue }
|
||||
: setting
|
||||
);
|
||||
setGeneralSettings(updatedSettings);
|
||||
};
|
||||
|
||||
const handleUpdateField = (fieldName: string, idx: number) => {
|
||||
if (!accessToken) {
|
||||
return;
|
||||
}
|
||||
|
||||
let fieldValue = generalSettings[idx].field_value;
|
||||
|
||||
if (fieldValue == null || fieldValue == undefined) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
updateConfigFieldSetting(accessToken, fieldName, fieldValue);
|
||||
// update value in state
|
||||
|
||||
const updatedSettings = generalSettings.map((setting) =>
|
||||
setting.field_name === fieldName
|
||||
? { ...setting, stored_in_db: true }
|
||||
: setting
|
||||
);
|
||||
setGeneralSettings(updatedSettings);
|
||||
} catch (error) {
|
||||
// do something
|
||||
}
|
||||
};
|
||||
|
||||
const handleResetField = (fieldName: string, idx: number) => {
|
||||
if (!accessToken) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
deleteConfigFieldSetting(accessToken, fieldName);
|
||||
// update value in state
|
||||
|
||||
const updatedSettings = generalSettings.map((setting) =>
|
||||
setting.field_name === fieldName
|
||||
? { ...setting, stored_in_db: null, field_value: null }
|
||||
: setting
|
||||
);
|
||||
setGeneralSettings(updatedSettings);
|
||||
} catch (error) {
|
||||
// do something
|
||||
}
|
||||
};
|
||||
|
||||
const handleSaveChanges = (router_settings: any) => {
|
||||
if (!accessToken) {
|
||||
|
@ -223,39 +366,55 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
|
|||
console.log("router_settings", router_settings);
|
||||
|
||||
const updatedVariables = Object.fromEntries(
|
||||
Object.entries(router_settings).map(([key, value]) => {
|
||||
if (key !== 'routing_strategy_args' && key !== "routing_strategy") {
|
||||
return [key, (document.querySelector(`input[name="${key}"]`) as HTMLInputElement)?.value || value];
|
||||
}
|
||||
else if (key == "routing_strategy") {
|
||||
return [key, selectedStrategy]
|
||||
}
|
||||
else if (key == "routing_strategy_args" && selectedStrategy == "latency-based-routing") {
|
||||
let setRoutingStrategyArgs: routingStrategyArgs = {}
|
||||
Object.entries(router_settings)
|
||||
.map(([key, value]) => {
|
||||
if (key !== "routing_strategy_args" && key !== "routing_strategy") {
|
||||
return [
|
||||
key,
|
||||
(
|
||||
document.querySelector(
|
||||
`input[name="${key}"]`
|
||||
) as HTMLInputElement
|
||||
)?.value || value,
|
||||
];
|
||||
} else if (key == "routing_strategy") {
|
||||
return [key, selectedStrategy];
|
||||
} else if (
|
||||
key == "routing_strategy_args" &&
|
||||
selectedStrategy == "latency-based-routing"
|
||||
) {
|
||||
let setRoutingStrategyArgs: routingStrategyArgs = {};
|
||||
|
||||
const lowestLatencyBufferElement = document.querySelector(`input[name="lowest_latency_buffer"]`) as HTMLInputElement;
|
||||
const ttlElement = document.querySelector(`input[name="ttl"]`) as HTMLInputElement;
|
||||
const lowestLatencyBufferElement = document.querySelector(
|
||||
`input[name="lowest_latency_buffer"]`
|
||||
) as HTMLInputElement;
|
||||
const ttlElement = document.querySelector(
|
||||
`input[name="ttl"]`
|
||||
) as HTMLInputElement;
|
||||
|
||||
if (lowestLatencyBufferElement?.value) {
|
||||
setRoutingStrategyArgs["lowest_latency_buffer"] = Number(lowestLatencyBufferElement.value)
|
||||
if (lowestLatencyBufferElement?.value) {
|
||||
setRoutingStrategyArgs["lowest_latency_buffer"] = Number(
|
||||
lowestLatencyBufferElement.value
|
||||
);
|
||||
}
|
||||
|
||||
if (ttlElement?.value) {
|
||||
setRoutingStrategyArgs["ttl"] = Number(ttlElement.value);
|
||||
}
|
||||
|
||||
console.log(`setRoutingStrategyArgs: ${setRoutingStrategyArgs}`);
|
||||
return ["routing_strategy_args", setRoutingStrategyArgs];
|
||||
}
|
||||
|
||||
if (ttlElement?.value) {
|
||||
setRoutingStrategyArgs["ttl"] = Number(ttlElement.value)
|
||||
}
|
||||
|
||||
console.log(`setRoutingStrategyArgs: ${setRoutingStrategyArgs}`)
|
||||
return [
|
||||
"routing_strategy_args", setRoutingStrategyArgs
|
||||
]
|
||||
}
|
||||
return null;
|
||||
}).filter(entry => entry !== null && entry !== undefined) as Iterable<[string, unknown]>
|
||||
return null;
|
||||
})
|
||||
.filter((entry) => entry !== null && entry !== undefined) as Iterable<
|
||||
[string, unknown]
|
||||
>
|
||||
);
|
||||
console.log("updatedVariables", updatedVariables);
|
||||
|
||||
const payload = {
|
||||
router_settings: updatedVariables
|
||||
router_settings: updatedVariables,
|
||||
};
|
||||
|
||||
try {
|
||||
|
@ -267,115 +426,238 @@ const GeneralSettings: React.FC<GeneralSettingsPageProps> = ({
|
|||
message.success("router settings updated successfully");
|
||||
};
|
||||
|
||||
|
||||
|
||||
if (!accessToken) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
return (
|
||||
<div className="w-full mx-4">
|
||||
<TabGroup className="gap-2 p-8 h-[75vh] w-full mt-2">
|
||||
<TabList variant="line" defaultValue="1">
|
||||
<Tab value="1">General Settings</Tab>
|
||||
<Tab value="1">Loadbalancing</Tab>
|
||||
<Tab value="2">Fallbacks</Tab>
|
||||
<Tab value="3">General</Tab>
|
||||
</TabList>
|
||||
<TabPanels>
|
||||
<TabPanel>
|
||||
<Grid numItems={1} className="gap-2 p-8 w-full mt-2">
|
||||
<Title>Router Settings</Title>
|
||||
<Card >
|
||||
<Table>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableHeaderCell>Setting</TableHeaderCell>
|
||||
<TableHeaderCell>Value</TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{Object.entries(routerSettings).filter(([param, value]) => param != "fallbacks" && param != "context_window_fallbacks" && param != "routing_strategy_args").map(([param, value]) => (
|
||||
<TableRow key={param}>
|
||||
<TableCell>
|
||||
<Text>{param}</Text>
|
||||
<p style={{fontSize: '0.65rem', color: '#808080', fontStyle: 'italic'}} className="mt-1">{paramExplanation[param]}</p>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{
|
||||
param == "routing_strategy" ?
|
||||
<Select2 defaultValue={value} className="w-full max-w-md" onValueChange={setSelectedStrategy}>
|
||||
<SelectItem value="usage-based-routing">usage-based-routing</SelectItem>
|
||||
<SelectItem value="latency-based-routing">latency-based-routing</SelectItem>
|
||||
<SelectItem value="simple-shuffle">simple-shuffle</SelectItem>
|
||||
</Select2> :
|
||||
<TextInput
|
||||
name={param}
|
||||
defaultValue={
|
||||
typeof value === 'object' ? JSON.stringify(value, null, 2) : value.toString()
|
||||
}
|
||||
/>
|
||||
}
|
||||
</TableCell>
|
||||
<Grid numItems={1} className="gap-2 p-8 w-full mt-2">
|
||||
<Title>Router Settings</Title>
|
||||
<Card>
|
||||
<Table>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableHeaderCell>Setting</TableHeaderCell>
|
||||
<TableHeaderCell>Value</TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{Object.entries(routerSettings)
|
||||
.filter(
|
||||
([param, value]) =>
|
||||
param != "fallbacks" &&
|
||||
param != "context_window_fallbacks" &&
|
||||
param != "routing_strategy_args"
|
||||
)
|
||||
.map(([param, value]) => (
|
||||
<TableRow key={param}>
|
||||
<TableCell>
|
||||
<Text>{param}</Text>
|
||||
<p
|
||||
style={{
|
||||
fontSize: "0.65rem",
|
||||
color: "#808080",
|
||||
fontStyle: "italic",
|
||||
}}
|
||||
className="mt-1"
|
||||
>
|
||||
{paramExplanation[param]}
|
||||
</p>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{param == "routing_strategy" ? (
|
||||
<Select2
|
||||
defaultValue={value}
|
||||
className="w-full max-w-md"
|
||||
onValueChange={setSelectedStrategy}
|
||||
>
|
||||
<SelectItem value="usage-based-routing">
|
||||
usage-based-routing
|
||||
</SelectItem>
|
||||
<SelectItem value="latency-based-routing">
|
||||
latency-based-routing
|
||||
</SelectItem>
|
||||
<SelectItem value="simple-shuffle">
|
||||
simple-shuffle
|
||||
</SelectItem>
|
||||
</Select2>
|
||||
) : (
|
||||
<TextInput
|
||||
name={param}
|
||||
defaultValue={
|
||||
typeof value === "object"
|
||||
? JSON.stringify(value, null, 2)
|
||||
: value.toString()
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
<AccordionHero
|
||||
selectedStrategy={selectedStrategy}
|
||||
strategyArgs={
|
||||
routerSettings &&
|
||||
routerSettings["routing_strategy_args"] &&
|
||||
Object.keys(routerSettings["routing_strategy_args"])
|
||||
.length > 0
|
||||
? routerSettings["routing_strategy_args"]
|
||||
: defaultLowestLatencyArgs // default value when keys length is 0
|
||||
}
|
||||
paramExplanation={paramExplanation}
|
||||
/>
|
||||
</Card>
|
||||
<Col>
|
||||
<Button
|
||||
className="mt-2"
|
||||
onClick={() => handleSaveChanges(routerSettings)}
|
||||
>
|
||||
Save Changes
|
||||
</Button>
|
||||
</Col>
|
||||
</Grid>
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<Table>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableHeaderCell>Model Name</TableHeaderCell>
|
||||
<TableHeaderCell>Fallbacks</TableHeaderCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
<AccordionHero
|
||||
selectedStrategy={selectedStrategy}
|
||||
strategyArgs={
|
||||
routerSettings && routerSettings['routing_strategy_args'] && Object.keys(routerSettings['routing_strategy_args']).length > 0
|
||||
? routerSettings['routing_strategy_args']
|
||||
: defaultLowestLatencyArgs // default value when keys length is 0
|
||||
}
|
||||
paramExplanation={paramExplanation}
|
||||
/>
|
||||
</Card>
|
||||
<Col>
|
||||
<Button className="mt-2" onClick={() => handleSaveChanges(routerSettings)}>
|
||||
Save Changes
|
||||
</Button>
|
||||
</Col>
|
||||
</Grid>
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<Table>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableHeaderCell>Model Name</TableHeaderCell>
|
||||
<TableHeaderCell>Fallbacks</TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
</TableHead>
|
||||
|
||||
<TableBody>
|
||||
{
|
||||
routerSettings["fallbacks"] &&
|
||||
routerSettings["fallbacks"].map((item: Object, index: number) =>
|
||||
Object.entries(item).map(([key, value]) => (
|
||||
<TableRow key={index.toString() + key}>
|
||||
<TableCell>{key}</TableCell>
|
||||
<TableCell>{Array.isArray(value) ? value.join(', ') : value}</TableCell>
|
||||
<TableCell>
|
||||
<Button onClick={() => testFallbackModelResponse(key, accessToken)}>
|
||||
Test Fallback
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Icon
|
||||
icon={TrashIcon}
|
||||
size="sm"
|
||||
onClick={() => deleteFallbacks(key)}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)
|
||||
}
|
||||
</TableBody>
|
||||
</Table>
|
||||
<AddFallbacks models={modelData?.data ? modelData.data.map((data: any) => data.model_name) : []} accessToken={accessToken} routerSettings={routerSettings} setRouterSettings={setRouterSettings}/>
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</TabGroup>
|
||||
<TableBody>
|
||||
{routerSettings["fallbacks"] &&
|
||||
routerSettings["fallbacks"].map(
|
||||
(item: Object, index: number) =>
|
||||
Object.entries(item).map(([key, value]) => (
|
||||
<TableRow key={index.toString() + key}>
|
||||
<TableCell>{key}</TableCell>
|
||||
<TableCell>
|
||||
{Array.isArray(value) ? value.join(", ") : value}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Button
|
||||
onClick={() =>
|
||||
testFallbackModelResponse(key, accessToken)
|
||||
}
|
||||
>
|
||||
Test Fallback
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Icon
|
||||
icon={TrashIcon}
|
||||
size="sm"
|
||||
onClick={() => deleteFallbacks(key)}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
<AddFallbacks
|
||||
models={
|
||||
modelData?.data
|
||||
? modelData.data.map((data: any) => data.model_name)
|
||||
: []
|
||||
}
|
||||
accessToken={accessToken}
|
||||
routerSettings={routerSettings}
|
||||
setRouterSettings={setRouterSettings}
|
||||
/>
|
||||
</TabPanel>
|
||||
<TabPanel>
|
||||
<Card>
|
||||
<Table>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableHeaderCell>Setting</TableHeaderCell>
|
||||
<TableHeaderCell>Value</TableHeaderCell>
|
||||
<TableHeaderCell>Status</TableHeaderCell>
|
||||
<TableHeaderCell>Action</TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{generalSettings.map((value, index) => (
|
||||
<TableRow key={index}>
|
||||
<TableCell>
|
||||
<Text>{value.field_name}</Text>
|
||||
<p
|
||||
style={{
|
||||
fontSize: "0.65rem",
|
||||
color: "#808080",
|
||||
fontStyle: "italic",
|
||||
}}
|
||||
className="mt-1"
|
||||
>
|
||||
{value.field_description}
|
||||
</p>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{value.field_type == "Integer" ? (
|
||||
<InputNumber
|
||||
step={1}
|
||||
value={value.field_value}
|
||||
onChange={(newValue) =>
|
||||
handleInputChange(value.field_name, newValue)
|
||||
} // Handle value change
|
||||
/>
|
||||
) : null}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{value.stored_in_db == true ? (
|
||||
<Badge icon={CheckCircleIcon} className="text-white">
|
||||
In DB
|
||||
</Badge>
|
||||
) : value.stored_in_db == false ? (
|
||||
<Badge className="text-gray bg-white outline">
|
||||
In Config
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge className="text-gray bg-white outline">
|
||||
Not Set
|
||||
</Badge>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Button
|
||||
onClick={() =>
|
||||
handleUpdateField(value.field_name, index)
|
||||
}
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
<Icon
|
||||
icon={TrashIcon}
|
||||
color="red"
|
||||
onClick={() =>
|
||||
handleResetField(value.field_name, index)
|
||||
}
|
||||
>
|
||||
Reset
|
||||
</Icon>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</Card>
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</TabGroup>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
|
@ -121,6 +121,7 @@ const handleSubmit = async (formValues: Record<string, any>, accessToken: string
|
|||
// Iterate through the key-value pairs in formValues
|
||||
litellmParamsObj["model"] = litellm_model
|
||||
let modelName: string = "";
|
||||
console.log("formValues add deployment:", formValues);
|
||||
for (const [key, value] of Object.entries(formValues)) {
|
||||
if (value === '') {
|
||||
continue;
|
||||
|
@ -628,6 +629,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|||
let input_cost = "Undefined";
|
||||
let output_cost = "Undefined";
|
||||
let max_tokens = "Undefined";
|
||||
let max_input_tokens = "Undefined";
|
||||
let cleanedLitellmParams = {};
|
||||
|
||||
const getProviderFromModel = (model: string) => {
|
||||
|
@ -664,6 +666,7 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|||
input_cost = model_info?.input_cost_per_token;
|
||||
output_cost = model_info?.output_cost_per_token;
|
||||
max_tokens = model_info?.max_tokens;
|
||||
max_input_tokens = model_info?.max_input_tokens;
|
||||
}
|
||||
|
||||
if (curr_model?.litellm_params) {
|
||||
|
@ -677,7 +680,19 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|||
modelData.data[i].provider = provider;
|
||||
modelData.data[i].input_cost = input_cost;
|
||||
modelData.data[i].output_cost = output_cost;
|
||||
|
||||
|
||||
// Convert Cost in terms of Cost per 1M tokens
|
||||
if (modelData.data[i].input_cost) {
|
||||
modelData.data[i].input_cost = (Number(modelData.data[i].input_cost) * 1000000).toFixed(2);
|
||||
}
|
||||
|
||||
if (modelData.data[i].output_cost) {
|
||||
modelData.data[i].output_cost = (Number(modelData.data[i].output_cost) * 1000000).toFixed(2);
|
||||
}
|
||||
|
||||
modelData.data[i].max_tokens = max_tokens;
|
||||
modelData.data[i].max_input_tokens = max_input_tokens;
|
||||
modelData.data[i].api_base = curr_model?.litellm_params?.api_base;
|
||||
modelData.data[i].cleanedLitellmParams = cleanedLitellmParams;
|
||||
|
||||
|
@ -893,8 +908,9 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|||
<Text>Filter by Public Model Name</Text>
|
||||
<Select
|
||||
className="mb-4 mt-2 ml-2 w-50"
|
||||
defaultValue="all"
|
||||
defaultValue={selectedModelGroup? selectedModelGroup : availableModelGroups[0]}
|
||||
onValueChange={(value) => setSelectedModelGroup(value === "all" ? "all" : value)}
|
||||
value={selectedModelGroup ? selectedModelGroup : availableModelGroups[0]}
|
||||
>
|
||||
<SelectItem
|
||||
value={"all"}
|
||||
|
@ -913,85 +929,76 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|||
</Select>
|
||||
</div>
|
||||
<Card>
|
||||
<Table className="mt-5">
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
|
||||
<TableHeaderCell>Public Model Name </TableHeaderCell>
|
||||
|
||||
<TableHeaderCell>
|
||||
Provider
|
||||
</TableHeaderCell>
|
||||
{
|
||||
userRole === "Admin" && (
|
||||
<TableHeaderCell>
|
||||
API Base
|
||||
</TableHeaderCell>
|
||||
)
|
||||
}
|
||||
<TableHeaderCell>
|
||||
Extra litellm Params
|
||||
</TableHeaderCell>
|
||||
<TableHeaderCell>Input Price per token ($)</TableHeaderCell>
|
||||
<TableHeaderCell>Output Price per token ($)</TableHeaderCell>
|
||||
<TableHeaderCell>Max Tokens</TableHeaderCell>
|
||||
<TableHeaderCell>Status</TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{ modelData.data
|
||||
.filter((model: any) =>
|
||||
selectedModelGroup === "all" || model.model_name === selectedModelGroup || selectedModelGroup === null || selectedModelGroup === undefined || selectedModelGroup === ""
|
||||
)
|
||||
.map((model: any, index: number) => (
|
||||
|
||||
<TableRow key={index}>
|
||||
<TableCell>
|
||||
<Text>{model.model_name}</Text>
|
||||
</TableCell>
|
||||
<TableCell>{model.provider}</TableCell>
|
||||
{
|
||||
userRole === "Admin" && (
|
||||
<TableCell>{model.api_base}</TableCell>
|
||||
)
|
||||
}
|
||||
|
||||
<TableCell>
|
||||
|
||||
<Accordion>
|
||||
<AccordionHeader>
|
||||
<Text>Litellm params</Text>
|
||||
</AccordionHeader>
|
||||
<AccordionBody>
|
||||
<pre>
|
||||
{JSON.stringify(model.cleanedLitellmParams, null, 2)}
|
||||
</pre>
|
||||
</AccordionBody>
|
||||
</Accordion>
|
||||
|
||||
</TableCell>
|
||||
<TableCell>{model.input_cost || model.litellm_params.input_cost_per_token || null}</TableCell>
|
||||
<TableCell>{model.output_cost || model.litellm_params.output_cost_per_token || null}</TableCell>
|
||||
<TableCell>{model.max_tokens}</TableCell>
|
||||
<TableCell>
|
||||
{
|
||||
model.model_info.db_model ? <Badge icon={CheckCircleIcon} className="text-white">DB Model</Badge> : <Badge icon={XCircleIcon} className="text-black">Config Model</Badge>
|
||||
}
|
||||
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Icon
|
||||
icon={PencilAltIcon}
|
||||
size="sm"
|
||||
onClick={() => handleEditClick(model)}
|
||||
/>
|
||||
<DeleteModelButton modelID={model.model_info.id} accessToken={accessToken} />
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<Table className="mt-5" style={{ maxWidth: '1500px', width: '100%' }}>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableHeaderCell style={{ maxWidth: '150px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Public Model Name</TableHeaderCell>
|
||||
<TableHeaderCell style={{ maxWidth: '100px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Provider</TableHeaderCell>
|
||||
{userRole === "Admin" && (
|
||||
<TableHeaderCell style={{ maxWidth: '150px', whiteSpace: 'normal', wordBreak: 'break-word' }}>API Base</TableHeaderCell>
|
||||
)}
|
||||
<TableHeaderCell style={{ maxWidth: '200px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Extra litellm Params</TableHeaderCell>
|
||||
<TableHeaderCell style={{ maxWidth: '85px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Input Price <p style={{ fontSize: '10px', color: 'gray' }}>/1M Tokens ($)</p></TableHeaderCell>
|
||||
<TableHeaderCell style={{ maxWidth: '85px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Output Price <p style={{ fontSize: '10px', color: 'gray' }}>/1M Tokens ($)</p></TableHeaderCell>
|
||||
<TableHeaderCell style={{ maxWidth: '120px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Max Tokens</TableHeaderCell>
|
||||
<TableHeaderCell style={{ maxWidth: '50px', whiteSpace: 'normal', wordBreak: 'break-word' }}>Status</TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{modelData.data
|
||||
.filter((model: any) =>
|
||||
selectedModelGroup === "all" ||
|
||||
model.model_name === selectedModelGroup ||
|
||||
selectedModelGroup === null ||
|
||||
selectedModelGroup === undefined ||
|
||||
selectedModelGroup === ""
|
||||
)
|
||||
.map((model: any, index: number) => (
|
||||
<TableRow key={index}>
|
||||
<TableCell style={{ maxWidth: '150px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
|
||||
<Text>{model.model_name}</Text>
|
||||
</TableCell>
|
||||
<TableCell style={{ maxWidth: '100px', whiteSpace: 'normal', wordBreak: 'break-word' }}>{model.provider}</TableCell>
|
||||
{userRole === "Admin" && (
|
||||
<TableCell style={{ maxWidth: '150px', whiteSpace: 'normal', wordBreak: 'break-word' }}>{model.api_base}</TableCell>
|
||||
)}
|
||||
<TableCell style={{ maxWidth: '200px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
|
||||
<Accordion>
|
||||
<AccordionHeader>
|
||||
<Text>Litellm params</Text>
|
||||
</AccordionHeader>
|
||||
<AccordionBody>
|
||||
<pre>{JSON.stringify(model.cleanedLitellmParams, null, 2)}</pre>
|
||||
</AccordionBody>
|
||||
</Accordion>
|
||||
</TableCell>
|
||||
<TableCell style={{ maxWidth: '80px', whiteSpace: 'normal', wordBreak: 'break-word' }}>{model.input_cost || model.litellm_params.input_cost_per_token || null}</TableCell>
|
||||
<TableCell style={{ maxWidth: '80px', whiteSpace: 'normal', wordBreak: 'break-word' }}>{model.output_cost || model.litellm_params.output_cost_per_token || null}</TableCell>
|
||||
<TableCell style={{ maxWidth: '120px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
|
||||
<p style={{ fontSize: '10px' }}>
|
||||
Max Tokens: {model.max_tokens} <br></br>
|
||||
Max Input Tokens: {model.max_input_tokens}
|
||||
</p>
|
||||
</TableCell>
|
||||
<TableCell style={{ maxWidth: '100px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
|
||||
{model.model_info.db_model ? (
|
||||
<Badge icon={CheckCircleIcon} size="xs" className="text-white">
|
||||
<p style={{ fontSize: '10px' }}>DB Model</p>
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge icon={XCircleIcon} size="xs" className="text-black">
|
||||
<p style={{ fontSize: '10px' }}>Config Model</p>
|
||||
</Badge>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell style={{ maxWidth: '100px', whiteSpace: 'normal', wordBreak: 'break-word' }}>
|
||||
<Icon icon={PencilAltIcon} size="sm" onClick={() => handleEditClick(model)} />
|
||||
<DeleteModelButton modelID={model.model_info.id} accessToken={accessToken} />
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</Card>
|
||||
|
||||
</Grid>
|
||||
|
@ -1116,13 +1123,22 @@ const handleEditSubmit = async (formValues: Record<string, any>) => {
|
|||
</Form.Item>
|
||||
}
|
||||
{
|
||||
selectedProvider == Providers.Azure && <Form.Item
|
||||
label="Base Model"
|
||||
name="base_model"
|
||||
>
|
||||
<TextInput placeholder="azure/gpt-3.5-turbo"/>
|
||||
<Text>The actual model your azure deployment uses. Used for accurate cost tracking. Select name from <Link href="https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" target="_blank">here</Link></Text>
|
||||
</Form.Item>
|
||||
selectedProvider == Providers.Azure &&
|
||||
|
||||
<div>
|
||||
<Form.Item
|
||||
label="Base Model"
|
||||
name="base_model"
|
||||
className="mb-0"
|
||||
>
|
||||
<TextInput placeholder="azure/gpt-3.5-turbo"/>
|
||||
</Form.Item>
|
||||
<Row>
|
||||
<Col span={10}></Col>
|
||||
<Col span={10}><Text className="mb-2">The actual model your azure deployment uses. Used for accurate cost tracking. Select name from <Link href="https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json" target="_blank">here</Link></Text></Col>
|
||||
</Row>
|
||||
|
||||
</div>
|
||||
}
|
||||
{
|
||||
selectedProvider == Providers.Bedrock && <Form.Item
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue