Merge remote-tracking branch 'origin/main' into feat--parse-user-from-headers

This commit is contained in:
Damian Gleumes 2025-04-11 07:32:46 +00:00
commit caabac22c1
199 changed files with 7731 additions and 1311 deletions

View file

@ -610,6 +610,8 @@ jobs:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
pip install wheel
pip install --upgrade pip wheel setuptools
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "respx==0.21.1"
@ -1125,6 +1127,7 @@ jobs:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
python -m pip install wheel setuptools
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"

View file

@ -12,8 +12,7 @@ WORKDIR /app
USER root
# Install build dependencies
RUN apk update && \
apk add --no-cache gcc python3-dev openssl openssl-dev
RUN apk add --no-cache gcc python3-dev openssl openssl-dev
RUN pip install --upgrade pip && \
@ -52,8 +51,7 @@ FROM $LITELLM_RUNTIME_IMAGE AS runtime
USER root
# Install runtime dependencies
RUN apk update && \
apk add --no-cache openssl
RUN apk add --no-cache openssl
WORKDIR /app
# Copy the current directory contents into the container at /app

View file

@ -2,6 +2,10 @@ apiVersion: v1
kind: Service
metadata:
name: {{ include "litellm.fullname" . }}
{{- with .Values.service.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
labels:
{{- include "litellm.labels" . | nindent 4 }}
spec:

View file

@ -35,7 +35,7 @@ RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
FROM $LITELLM_RUNTIME_IMAGE AS runtime
# Update dependencies and clean up
RUN apk update && apk upgrade && rm -rf /var/cache/apk/*
RUN apk upgrade --no-cache
WORKDIR /app

View file

@ -12,8 +12,7 @@ WORKDIR /app
USER root
# Install build dependencies
RUN apk update && \
apk add --no-cache gcc python3-dev openssl openssl-dev
RUN apk add --no-cache gcc python3-dev openssl openssl-dev
RUN pip install --upgrade pip && \
@ -44,8 +43,7 @@ FROM $LITELLM_RUNTIME_IMAGE AS runtime
USER root
# Install runtime dependencies
RUN apk update && \
apk add --no-cache openssl
RUN apk add --no-cache openssl
WORKDIR /app
# Copy the current directory contents into the container at /app

View file

@ -438,6 +438,179 @@ assert isinstance(
```
### Google Search Tool
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GEMINI_API_KEY"] = ".."
tools = [{"googleSearch": {}}] # 👈 ADD GOOGLE SEARCH
response = completion(
model="gemini/gemini-2.0-flash",
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
tools=tools,
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-2.0-flash
litellm_params:
model: gemini/gemini-2.0-flash
api_key: os.environ/GEMINI_API_KEY
```
2. Start Proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Make Request!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gemini-2.0-flash",
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
"tools": [{"googleSearch": {}}]
}
'
```
</TabItem>
</Tabs>
### Google Search Retrieval
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GEMINI_API_KEY"] = ".."
tools = [{"googleSearchRetrieval": {}}] # 👈 ADD GOOGLE SEARCH
response = completion(
model="gemini/gemini-2.0-flash",
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
tools=tools,
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-2.0-flash
litellm_params:
model: gemini/gemini-2.0-flash
api_key: os.environ/GEMINI_API_KEY
```
2. Start Proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Make Request!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gemini-2.0-flash",
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
"tools": [{"googleSearchRetrieval": {}}]
}
'
```
</TabItem>
</Tabs>
### Code Execution Tool
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GEMINI_API_KEY"] = ".."
tools = [{"codeExecution": {}}] # 👈 ADD GOOGLE SEARCH
response = completion(
model="gemini/gemini-2.0-flash",
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
tools=tools,
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-2.0-flash
litellm_params:
model: gemini/gemini-2.0-flash
api_key: os.environ/GEMINI_API_KEY
```
2. Start Proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Make Request!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gemini-2.0-flash",
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
"tools": [{"codeExecution": {}}]
}
'
```
</TabItem>
</Tabs>
## JSON Mode
<Tabs>

View file

@ -398,6 +398,8 @@ curl http://localhost:4000/v1/chat/completions \
</TabItem>
</Tabs>
You can also use the `enterpriseWebSearch` tool for an [enterprise compliant search](https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise).
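For reference, here is a minimal SDK sketch, assuming `enterpriseWebSearch` follows the same tool-list pattern as the Google Search examples above; the `vertex_ai/gemini-2.0-flash` model name is illustrative, not prescribed by this doc.
```python
from litellm import completion

# Assumption: enterpriseWebSearch is passed like the other built-in tools shown above.
tools = [{"enterpriseWebSearch": {}}]  # 👈 ADD ENTERPRISE WEB SEARCH

response = completion(
    model="vertex_ai/gemini-2.0-flash",  # illustrative model name
    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
    tools=tools,
)
print(response)
```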
#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**

View file

@ -449,6 +449,7 @@ router_settings:
| MICROSOFT_CLIENT_ID | Client ID for Microsoft services
| MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
| MICROSOFT_TENANT | Tenant ID for Microsoft Azure
| MICROSOFT_SERVICE_PRINCIPAL_ID | Service Principal ID for the Microsoft Enterprise Application. (Advanced: use this if you want LiteLLM to auto-assign members to LiteLLM Teams based on their Microsoft Entra ID groups)
| NO_DOCS | Flag to disable documentation generation
| NO_PROXY | List of addresses to bypass proxy
| OAUTH_TOKEN_INFO_ENDPOINT | Endpoint for OAuth token info retrieval

View file

@ -26,10 +26,12 @@ model_list:
- model_name: sagemaker-completion-model
litellm_params:
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
model_info:
input_cost_per_second: 0.000420
- model_name: sagemaker-embedding-model
litellm_params:
model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
model_info:
input_cost_per_second: 0.000420
```
@ -55,11 +57,33 @@ model_list:
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
api_version: os.environ/AZURE_API_VERSION
model_info:
input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
```
### Debugging
## Override Model Cost Map
You can override [our model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) with your own custom pricing for a mapped model.
Just add a `model_info` key to your model in the config, and override the desired keys.
Example: Override Anthropic's model cost map for the `prod/claude-3-5-sonnet-20241022` model.
```yaml
model_list:
- model_name: "prod/claude-3-5-sonnet-20241022"
litellm_params:
model: "anthropic/claude-3-5-sonnet-20241022"
api_key: os.environ/ANTHROPIC_PROD_API_KEY
model_info:
input_cost_per_token: 0.000006
output_cost_per_token: 0.00003
cache_creation_input_token_cost: 0.0000075
cache_read_input_token_cost: 0.0000006
```
## Debugging
If your custom pricing is not being used or you're seeing errors, please check the following:

View file

@ -161,6 +161,83 @@ Here's the available UI roles for a LiteLLM Internal User:
- `internal_user`: can login, view/create/delete their own keys, view their spend. **Cannot** add new users.
- `internal_user_viewer`: can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users.
## Auto-add SSO users to teams
This walks through setting up SSO auto-add for **Okta, Google SSO**
### Okta, Google SSO
1. Specify the JWT field that contains the team IDs that the user belongs to.
```yaml
general_settings:
master_key: sk-1234
litellm_jwtauth:
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
```
This assumes your SSO token looks like this. **If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow the instructions [here](#debugging-sso-jwt-fields)**
```
{
...,
"groups": ["team_id_1", "team_id_2"]
}
```
2. Create the teams on LiteLLM
```bash
curl -X POST '<PROXY_BASE_URL>/team/new' \
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
-H 'Content-Type: application/json' \
-d '{
"team_alias": "team_1",
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
}'
```
3. Test the SSO flow
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
### Microsoft Entra ID SSO group assignment
Follow this [tutorial for auto-adding sso users to teams with Microsoft Entra ID](https://docs.litellm.ai/docs/tutorials/msft_sso)
### Debugging SSO JWT fields
If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions. This guide walks you through setting up a debug callback to view the JWT data during the SSO process.
<Image img={require('../../img/debug_sso.png')} style={{ width: '500px', height: 'auto' }} />
<br />
1. Add `/sso/debug/callback` as a redirect URL in your SSO provider
In your SSO provider's settings, add the following URL as a new redirect (callback) URL:
```bash showLineNumbers title="Redirect URL"
http://<proxy_base_url>/sso/debug/callback
```
2. Navigate to the debug login page in your browser
Navigate to the following URL:
```bash showLineNumbers title="URL to navigate to"
https://<proxy_base_url>/sso/debug/login
```
This will initiate the standard SSO flow. You will be redirected to your SSO provider's login screen, and after successful authentication, you will be redirected back to LiteLLM's debug callback route.
3. View the JWT fields
Once redirected, you should see a page called "SSO Debug Information". This page displays the JWT fields received from your SSO provider (as shown in the image above)
## Advanced
### Setting custom logout URLs
@ -196,40 +273,26 @@ This budget does not apply to keys created under non-default teams.
[**Go Here**](./team_budgets.md)
### Auto-add SSO users to teams
### Set default params for new teams
1. Specify the JWT field that contains the team ids, that the user belongs to.
When you connect litellm to your SSO provider, litellm can auto-create teams. Use this to set the default `models`, `max_budget`, `budget_duration` for these auto-created teams.
```yaml
general_settings:
master_key: sk-1234
litellm_jwtauth:
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
**How it works**
1. When litellm fetches `groups` from your SSO provider, it will check if the corresponding group_id exists as a `team_id` in litellm.
2. If the team_id does not exist, litellm will auto-create a team with the default params you've set.
3. If the team_id already exists, litellm will not apply any settings to the team.
**Usage**
```yaml showLineNumbers title="Default Params for new teams"
litellm_settings:
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
max_budget: 100 # Optional[float], optional): $100 budget for the team
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
```
This is assuming your SSO token looks like this:
```
{
...,
"groups": ["team_id_1", "team_id_2"]
}
```
2. Create the teams on LiteLLM
```bash
curl -X POST '<PROXY_BASE_URL>/team/new' \
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
-H 'Content-Type: application/json' \
-D '{
"team_alias": "team_1",
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
}'
```
3. Test the SSO flow
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
### Restrict Users from creating personal keys
@ -241,7 +304,7 @@ This will also prevent users from using their session tokens on the test keys ch
## **All Settings for Self Serve / SSO Flow**
```yaml
```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow"
litellm_settings:
max_internal_user_budget: 10 # max budget for internal users
internal_user_budget_duration: "1mo" # reset every month
@ -251,6 +314,11 @@ litellm_settings:
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
max_budget: 100 # Optional[float], optional): $100 budget for the team
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on

View file

@ -0,0 +1,162 @@
import Image from '@theme/IdealImage';
# Microsoft SSO: Sync Groups, Members with LiteLLM
Sync Microsoft SSO Groups, Members with LiteLLM Teams.
<Image img={require('../../img/litellm_entra_id.png')} style={{ width: '800px', height: 'auto' }} />
<br />
<br />
## Prerequisites
- An Azure Entra ID account with administrative access
- A LiteLLM Enterprise App set up in your Azure Portal
- Access to Microsoft Entra ID (Azure AD)
## Overview of this tutorial
1. Auto-Create Entra ID Groups on LiteLLM Teams
2. Sync Entra ID Team Memberships
3. Set default params for new teams and users auto-created on LiteLLM
## 1. Auto-Create Entra ID Groups on LiteLLM Teams
In this step, our goal is to have LiteLLM automatically create a new team in the LiteLLM DB when a new group is added to the LiteLLM Enterprise App on Azure Entra ID.
### 1.1 Create a new group in Entra ID
Navigate to [your Azure Portal](https://portal.azure.com/) > Groups > New Group. Create a new group.
<Image img={require('../../img/entra_create_team.png')} style={{ width: '800px', height: 'auto' }} />
### 1.2 Assign the group to your LiteLLM Enterprise App
On your Azure Portal, navigate to `Enterprise Applications` > Select your litellm app
<Image img={require('../../img/msft_enterprise_app.png')} style={{ width: '800px', height: 'auto' }} />
<br />
<br />
Once you've selected your litellm app, click on `Users and Groups` > `Add user/group`
<Image img={require('../../img/msft_enterprise_assign_group.png')} style={{ width: '800px', height: 'auto' }} />
<br />
Now select the group you created in step 1.1 and add it to the LiteLLM Enterprise App. At this point we have added `Production LLM Evals Group` to the LiteLLM Enterprise App. The next step is having LiteLLM automatically create the `Production LLM Evals Group` team in the LiteLLM DB when a new user signs in.
<Image img={require('../../img/msft_enterprise_select_group.png')} style={{ width: '800px', height: 'auto' }} />
### 1.3 Sign in to LiteLLM UI via SSO
Sign into the LiteLLM UI via SSO. You should be redirected to the Entra ID SSO page. This SSO sign in flow will trigger LiteLLM to fetch the latest Groups and Members from Azure Entra ID.
<Image img={require('../../img/msft_sso_sign_in.png')} style={{ width: '800px', height: 'auto' }} />
### 1.4 Check the new team on LiteLLM UI
On the LiteLLM UI, navigate to `Teams`. You should see the new team `Production LLM Evals Group` auto-created on LiteLLM.
<Image img={require('../../img/msft_auto_team.png')} style={{ width: '900px', height: 'auto' }} />
#### How this works
When an SSO user signs in to LiteLLM (see the sketch after this list):
- LiteLLM automatically fetches the Groups under the LiteLLM Enterprise App
- It finds the Production LLM Evals Group assigned to the LiteLLM Enterprise App
- LiteLLM checks if this group's ID exists in the LiteLLM Teams Table
- Since the ID doesn't exist, LiteLLM automatically creates a new team with:
- Name: Production LLM Evals Group
- ID: Same as the Entra ID group's ID
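For intuition only, here is a short, hypothetical Python sketch of that check; the function and variable names are illustrative, not LiteLLM's actual internals.
```python showLineNumbers title="Illustrative sketch: Entra ID group -> LiteLLM team auto-create check"
from typing import List, Set


def groups_needing_teams(sso_group_ids: List[str], existing_team_ids: Set[str]) -> List[str]:
    """Return the Entra ID group IDs that do not yet exist as LiteLLM team IDs."""
    # Each group ID is checked against the Teams table; missing ones get a new
    # team created with the same ID (and the group name as the team name).
    return [group_id for group_id in sso_group_ids if group_id not in existing_team_ids]


# Example: only "group-b" needs a new team
print(groups_needing_teams(["group-a", "group-b"], {"group-a"}))  # -> ['group-b']
```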
## 2. Sync Entra ID Team Memberships
In this step, we will have LiteLLM automatically add a user to the `Production LLM Evals` Team on the LiteLLM DB when a new user is added to the `Production LLM Evals` Group in Entra ID.
### 2.1 Navigate to the `Production LLM Evals` Group in Entra ID
Navigate to the `Production LLM Evals` Group in Entra ID.
<Image img={require('../../img/msft_member_1.png')} style={{ width: '800px', height: 'auto' }} />
### 2.2 Add a member to the group in Entra ID
Select `Members` > `Add members`
Here, select the user you want to add to the `Production LLM Evals` Team.
<Image img={require('../../img/msft_member_2.png')} style={{ width: '800px', height: 'auto' }} />
### 2.3 Sign in as the new user on LiteLLM UI
Sign in as the new user on the LiteLLM UI. You should be redirected to the Entra ID SSO page. This SSO sign-in flow will trigger LiteLLM to fetch the latest Groups and Members from Azure Entra ID. During this step, LiteLLM syncs its teams and team members with what is available from Entra ID.
<Image img={require('../../img/msft_sso_sign_in.png')} style={{ width: '800px', height: 'auto' }} />
### 2.4 Check the team membership on LiteLLM UI
On the LiteLLM UI, navigate to `Teams`. Since you are now a member of the `Production LLM Evals Group` in Entra ID, you should see the `Production LLM Evals Group` team on the LiteLLM UI.
<Image img={require('../../img/msft_member_3.png')} style={{ width: '900px', height: 'auto' }} />
## 3. Set default params for new teams auto-created on LiteLLM
Since LiteLLM auto-creates a new team in the LiteLLM DB when a new group is added to the LiteLLM Enterprise App on Azure Entra ID, we can set default params for these auto-created teams.
This allows you to set a default budget, models, etc. for new teams.
### 3.1 Set `default_team_params` on litellm
Navigate to your litellm config file and set the following params
```yaml showLineNumbers title="litellm config with default_team_params"
litellm_settings:
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
max_budget: 100 # Optional[float], optional): $100 budget for the team
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
```
### 3.2 Auto-create a new team on LiteLLM
- In this step you should add a new group to the LiteLLM Enterprise App on Azure Entra ID (like we did in step 1.1). We will call this group `Default LiteLLM Prod Team` on Azure Entra ID.
- Start litellm proxy server with your config
- Sign into LiteLLM UI via SSO
- Navigate to `Teams` and you should see the new team `Default LiteLLM Prod Team` auto-created on LiteLLM
- Note LiteLLM will set the default params for this new team.
<Image img={require('../../img/msft_default_settings.png')} style={{ width: '900px', height: 'auto' }} />
## Video Walkthrough
Follow along with this video for a walkthrough of how to set up SSO auto-add with **Microsoft Entra ID**
<iframe width="840" height="500" src="https://www.loom.com/embed/ea711323aa9a496d84a01fd7b2a12f54?sid=c53e238c-5bfd-4135-b8fb-b5b1a08632cf" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>

[12 binary image files added (documentation screenshots, 35 KiB–818 KiB) — not shown]

View file

@ -435,6 +435,7 @@ const sidebars = {
label: "Tutorials",
items: [
"tutorials/openweb_ui",
"tutorials/msft_sso",
'tutorials/litellm_proxy_aporia',
{
type: "category",

View file

@ -0,0 +1,161 @@
import React, { useState } from 'react';
import styles from './transform_request.module.css';
const DEFAULT_REQUEST = {
"model": "bedrock/gpt-4",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Explain quantum computing in simple terms"
}
],
"temperature": 0.7,
"max_tokens": 500,
"stream": true
};
type ViewMode = 'split' | 'request' | 'transformed';
const TransformRequestPlayground: React.FC = () => {
const [request, setRequest] = useState(JSON.stringify(DEFAULT_REQUEST, null, 2));
const [transformedRequest, setTransformedRequest] = useState('');
const [viewMode, setViewMode] = useState<ViewMode>('split');
const handleTransform = async () => {
try {
// Here you would make the actual API call to transform the request
// For now, we'll just set a sample response
const sampleResponse = `curl -X POST \\
https://api.openai.com/v1/chat/completions \\
-H 'Authorization: Bearer sk-xxx' \\
-H 'Content-Type: application/json' \\
-d '{
"model": "gpt-4",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
}
],
"temperature": 0.7
}'`;
setTransformedRequest(sampleResponse);
} catch (error) {
console.error('Error transforming request:', error);
}
};
const handleCopy = () => {
navigator.clipboard.writeText(transformedRequest);
};
const renderContent = () => {
switch (viewMode) {
case 'request':
return (
<div className={styles.panel}>
<div className={styles['panel-header']}>
<h2>Original Request</h2>
<p>The request you would send to LiteLLM /chat/completions endpoint.</p>
</div>
<textarea
className={styles['code-input']}
value={request}
onChange={(e) => setRequest(e.target.value)}
spellCheck={false}
/>
<div className={styles['panel-footer']}>
<button className={styles['transform-button']} onClick={handleTransform}>
Transform
</button>
</div>
</div>
);
case 'transformed':
return (
<div className={styles.panel}>
<div className={styles['panel-header']}>
<h2>Transformed Request</h2>
<p>How LiteLLM transforms your request for the specified provider.</p>
<p className={styles.note}>Note: Sensitive headers are not shown.</p>
</div>
<div className={styles['code-output-container']}>
<pre className={styles['code-output']}>{transformedRequest}</pre>
<button className={styles['copy-button']} onClick={handleCopy}>
Copy
</button>
</div>
</div>
);
default:
return (
<>
<div className={styles.panel}>
<div className={styles['panel-header']}>
<h2>Original Request</h2>
<p>The request you would send to LiteLLM /chat/completions endpoint.</p>
</div>
<textarea
className={styles['code-input']}
value={request}
onChange={(e) => setRequest(e.target.value)}
spellCheck={false}
/>
<div className={styles['panel-footer']}>
<button className={styles['transform-button']} onClick={handleTransform}>
Transform
</button>
</div>
</div>
<div className={styles.panel}>
<div className={styles['panel-header']}>
<h2>Transformed Request</h2>
<p>How LiteLLM transforms your request for the specified provider.</p>
<p className={styles.note}>Note: Sensitive headers are not shown.</p>
</div>
<div className={styles['code-output-container']}>
<pre className={styles['code-output']}>{transformedRequest}</pre>
<button className={styles['copy-button']} onClick={handleCopy}>
Copy
</button>
</div>
</div>
</>
);
}
};
return (
<div className={styles['transform-playground']}>
<div className={styles['view-toggle']}>
<button
className={viewMode === 'split' ? styles.active : ''}
onClick={() => setViewMode('split')}
>
Split View
</button>
<button
className={viewMode === 'request' ? styles.active : ''}
onClick={() => setViewMode('request')}
>
Request
</button>
<button
className={viewMode === 'transformed' ? styles.active : ''}
onClick={() => setViewMode('transformed')}
>
Transformed
</button>
</div>
<div className={styles['playground-container']}>
{renderContent()}
</div>
</div>
);
};
export default TransformRequestPlayground;

View file

@ -65,6 +65,7 @@ from litellm.proxy._types import (
KeyManagementSystem,
KeyManagementSettings,
LiteLLM_UpperboundKeyGenerateParams,
NewTeamRequest,
)
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
from litellm.integrations.custom_logger import CustomLogger
@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
require_auth_for_metrics_endpoint: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
gcs_pub_sub_use_v1: Optional[
bool
] = False # if you want to use v1 gcs pubsub logged payload
gcs_pub_sub_use_v1: Optional[bool] = (
False # if you want to use v1 gcs pubsub logged payload
)
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
@ -146,18 +147,18 @@ log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
filter_invalid_headers: Optional[bool] = False
add_user_information_to_llm_headers: Optional[
bool
] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
add_user_information_to_llm_headers: Optional[bool] = (
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
)
store_audit_logs = False # Enterprise feature, allow users to see audit logs
### end of callbacks #############
email: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
token: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
email: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
token: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
telemetry = True
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
enable_caching_on_provider_specific_optional_params: bool = (
False # feature-flag for caching on optional params - e.g. 'top_k'
)
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[
Cache
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
caching: bool = (
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
caching_with_models: bool = (
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
cache: Optional[Cache] = (
None # cache object <- use this - https://docs.litellm.ai/docs/caching
)
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
budget_duration: Optional[
str
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
budget_duration: Optional[str] = (
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
)
default_soft_budget: float = (
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
)
@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False
_current_cost = 0.0 # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
add_function_to_prompt: bool = (
False # if function calling not supported by api, append function call details to system prompt
)
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
model_cost_map_url: str = (
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
)
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
default_internal_user_params: Optional[Dict] = None
default_team_params: Optional[Union[NewTeamRequest, Dict]] = None
default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
default_max_internal_user_budget: Optional[float] = None
@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
custom_prometheus_metadata_labels: List[str] = []
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
force_ipv4: bool = (
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
)
module_level_aclient = AsyncHTTPHandler(
timeout=request_timeout, client_alias="module level aclient"
)
@ -295,13 +307,13 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 3
num_retries_per_request: Optional[
int
] = None # for the request overall (incl. fallbacks + model retries)
num_retries_per_request: Optional[int] = (
None # for the request overall (incl. fallbacks + model retries)
)
####### SECRET MANAGERS #####################
secret_manager_client: Optional[
Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
secret_manager_client: Optional[Any] = (
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[
str
] = [] # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[
bool
] = None # disable huggingface tokenizer download. Defaults to openai clk100
_custom_providers: List[str] = (
[]
) # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[bool] = (
None # disable huggingface tokenizer download. Defaults to openai clk100
)
global_disable_no_log_param: bool = False

View file

@ -480,6 +480,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
########################### Logging Callback Constants ###########################
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
MCP_TOOL_NAME_PREFIX = "mcp_tool"
########################### LiteLLM Proxy Specific Constants ###########################
@ -514,6 +515,7 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
########################### DB CRON JOB NAMES ###########################
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605

View file

@ -16,7 +16,10 @@ from litellm.constants import (
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
StandardBuiltInToolCostTracking,
)
from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
from litellm.litellm_core_utils.llm_cost_calc.utils import (
_generic_cost_per_character,
generic_cost_per_token,
)
from litellm.llms.anthropic.cost_calculation import (
cost_per_token as anthropic_cost_per_token,
)
@ -54,12 +57,16 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
from litellm.responses.utils import ResponseAPILoggingUtils
from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
OpenAIRealtimeStreamList,
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
ResponseAPIUsage,
ResponsesAPIResponse,
)
from litellm.types.rerank import RerankBilledUnits, RerankResponse
from litellm.types.utils import (
CallTypesLiteral,
LiteLLMRealtimeStreamLoggingObject,
LlmProviders,
LlmProvidersSet,
ModelInfo,
@ -397,6 +404,7 @@ def _select_model_name_for_cost_calc(
base_model: Optional[str] = None,
custom_pricing: Optional[bool] = None,
custom_llm_provider: Optional[str] = None,
router_model_id: Optional[str] = None,
) -> Optional[str]:
"""
1. If custom pricing is true, return received model name
@ -411,12 +419,6 @@ def _select_model_name_for_cost_calc(
model=model, custom_llm_provider=custom_llm_provider
)
if custom_pricing is True:
return_model = model
if base_model is not None:
return_model = base_model
completion_response_model: Optional[str] = None
if completion_response is not None:
if isinstance(completion_response, BaseModel):
@ -424,6 +426,16 @@ def _select_model_name_for_cost_calc(
elif isinstance(completion_response, dict):
completion_response_model = completion_response.get("model", None)
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
if custom_pricing is True:
if router_model_id is not None and router_model_id in litellm.model_cost:
return_model = router_model_id
else:
return_model = model
if base_model is not None:
return_model = base_model
if completion_response_model is None and hidden_params is not None:
if (
hidden_params.get("model", None) is not None
@ -553,6 +565,7 @@ def completion_cost( # noqa: PLR0915
base_model: Optional[str] = None,
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
litellm_model_name: Optional[str] = None,
router_model_id: Optional[str] = None,
) -> float:
"""
Calculate the cost of a given completion call for GPT-3.5-turbo, llama2, any litellm supported llm.
@ -605,18 +618,19 @@ def completion_cost( # noqa: PLR0915
completion_response=completion_response
)
rerank_billed_units: Optional[RerankBilledUnits] = None
selected_model = _select_model_name_for_cost_calc(
model=model,
completion_response=completion_response,
custom_llm_provider=custom_llm_provider,
custom_pricing=custom_pricing,
base_model=base_model,
router_model_id=router_model_id,
)
potential_model_names = [selected_model]
if model is not None:
potential_model_names.append(model)
for idx, model in enumerate(potential_model_names):
try:
verbose_logger.info(
@ -780,6 +794,25 @@ def completion_cost( # noqa: PLR0915
billed_units.get("search_units") or 1
) # cohere charges per request by default.
completion_tokens = search_units
elif call_type == CallTypes.arealtime.value and isinstance(
completion_response, LiteLLMRealtimeStreamLoggingObject
):
if (
cost_per_token_usage_object is None
or custom_llm_provider is None
):
raise ValueError(
"usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
cost_per_token_usage_object,
custom_llm_provider,
)
)
return handle_realtime_stream_cost_calculation(
results=completion_response.results,
combined_usage_object=cost_per_token_usage_object,
custom_llm_provider=custom_llm_provider,
litellm_model_name=model,
)
# Calculate cost based on prompt_tokens, completion_tokens
if (
"togethercomputer" in model
@ -909,6 +942,7 @@ def response_cost_calculator(
HttpxBinaryResponseContent,
RerankResponse,
ResponsesAPIResponse,
LiteLLMRealtimeStreamLoggingObject,
],
model: str,
custom_llm_provider: Optional[str],
@ -937,6 +971,7 @@ def response_cost_calculator(
prompt: str = "",
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
litellm_model_name: Optional[str] = None,
router_model_id: Optional[str] = None,
) -> float:
"""
Returns
@ -967,6 +1002,8 @@ def response_cost_calculator(
base_model=base_model,
prompt=prompt,
standard_built_in_tools_params=standard_built_in_tools_params,
litellm_model_name=litellm_model_name,
router_model_id=router_model_id,
)
return response_cost
except Exception as e:
@ -1141,3 +1178,173 @@ def batch_cost_calculator(
) # batch cost is usually half of the regular token cost
return total_prompt_cost, total_completion_cost
class RealtimeAPITokenUsageProcessor:
@staticmethod
def collect_usage_from_realtime_stream_results(
results: OpenAIRealtimeStreamList,
) -> List[Usage]:
"""
Collect usage from realtime stream results
"""
response_done_events: List[OpenAIRealtimeStreamResponseBaseObject] = cast(
List[OpenAIRealtimeStreamResponseBaseObject],
[result for result in results if result["type"] == "response.done"],
)
usage_objects: List[Usage] = []
for result in response_done_events:
usage_object = (
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
result["response"].get("usage", {})
)
)
usage_objects.append(usage_object)
return usage_objects
@staticmethod
def combine_usage_objects(usage_objects: List[Usage]) -> Usage:
"""
Combine multiple Usage objects into a single Usage object, checking model keys for nested values.
"""
from litellm.types.utils import (
CompletionTokensDetails,
PromptTokensDetailsWrapper,
Usage,
)
combined = Usage()
# Sum basic token counts
for usage in usage_objects:
# Handle direct attributes by checking what exists in the model
for attr in dir(usage):
if not attr.startswith("_") and not callable(getattr(usage, attr)):
current_val = getattr(combined, attr, 0)
new_val = getattr(usage, attr, 0)
if (
new_val is not None
and isinstance(new_val, (int, float))
and isinstance(current_val, (int, float))
):
setattr(combined, attr, current_val + new_val)
# Handle nested prompt_tokens_details
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
if (
not hasattr(combined, "prompt_tokens_details")
or not combined.prompt_tokens_details
):
combined.prompt_tokens_details = PromptTokensDetailsWrapper()
# Check what keys exist in the model's prompt_tokens_details
for attr in dir(usage.prompt_tokens_details):
if not attr.startswith("_") and not callable(
getattr(usage.prompt_tokens_details, attr)
):
current_val = getattr(combined.prompt_tokens_details, attr, 0)
new_val = getattr(usage.prompt_tokens_details, attr, 0)
if new_val is not None:
setattr(
combined.prompt_tokens_details,
attr,
current_val + new_val,
)
# Handle nested completion_tokens_details
if (
hasattr(usage, "completion_tokens_details")
and usage.completion_tokens_details
):
if (
not hasattr(combined, "completion_tokens_details")
or not combined.completion_tokens_details
):
combined.completion_tokens_details = CompletionTokensDetails()
# Check what keys exist in the model's completion_tokens_details
for attr in dir(usage.completion_tokens_details):
if not attr.startswith("_") and not callable(
getattr(usage.completion_tokens_details, attr)
):
current_val = getattr(
combined.completion_tokens_details, attr, 0
)
new_val = getattr(usage.completion_tokens_details, attr, 0)
if new_val is not None:
setattr(
combined.completion_tokens_details,
attr,
current_val + new_val,
)
return combined
@staticmethod
def collect_and_combine_usage_from_realtime_stream_results(
results: OpenAIRealtimeStreamList,
) -> Usage:
"""
Collect and combine usage from realtime stream results
"""
collected_usage_objects = (
RealtimeAPITokenUsageProcessor.collect_usage_from_realtime_stream_results(
results
)
)
combined_usage_object = RealtimeAPITokenUsageProcessor.combine_usage_objects(
collected_usage_objects
)
return combined_usage_object
@staticmethod
def create_logging_realtime_object(
usage: Usage, results: OpenAIRealtimeStreamList
) -> LiteLLMRealtimeStreamLoggingObject:
return LiteLLMRealtimeStreamLoggingObject(
usage=usage,
results=results,
)
def handle_realtime_stream_cost_calculation(
results: OpenAIRealtimeStreamList,
combined_usage_object: Usage,
custom_llm_provider: str,
litellm_model_name: str,
) -> float:
"""
Handles the cost calculation for realtime stream responses.
Pick the 'response.done' events. Calculate total cost across all 'response.done' events.
Args:
results: A list of OpenAIRealtimeStreamBaseObject objects
"""
received_model = None
potential_model_names = []
for result in results:
if result["type"] == "session.created":
received_model = cast(OpenAIRealtimeStreamSessionEvents, result)["session"][
"model"
]
potential_model_names.append(received_model)
potential_model_names.append(litellm_model_name)
input_cost_per_token = 0.0
output_cost_per_token = 0.0
for model_name in potential_model_names:
try:
_input_cost_per_token, _output_cost_per_token = generic_cost_per_token(
model=model_name,
usage=combined_usage_object,
custom_llm_provider=custom_llm_provider,
)
except Exception:
continue
input_cost_per_token += _input_cost_per_token
output_cost_per_token += _output_cost_per_token
break # exit if we find a valid model
total_cost = input_cost_per_token + output_cost_per_token
return total_cost

View file

@ -1,10 +1,19 @@
# used for /metrics endpoint on LiteLLM Proxy
#### What this does ####
# On success, log events to Prometheus
import asyncio
import sys
from datetime import datetime, timedelta
from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Literal,
Optional,
Tuple,
cast,
)
import litellm
from litellm._logging import print_verbose, verbose_logger
@ -14,6 +23,11 @@ from litellm.types.integrations.prometheus import *
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
if TYPE_CHECKING:
from apscheduler.schedulers.asyncio import AsyncIOScheduler
else:
AsyncIOScheduler = Any
class PrometheusLogger(CustomLogger):
# Class variables or attributes
@ -359,8 +373,6 @@ class PrometheusLogger(CustomLogger):
label_name="litellm_requests_metric"
),
)
self._initialize_prometheus_startup_metrics()
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e
@ -988,9 +1000,9 @@ class PrometheusLogger(CustomLogger):
):
try:
verbose_logger.debug("setting remaining tokens requests metric")
standard_logging_payload: Optional[
StandardLoggingPayload
] = request_kwargs.get("standard_logging_object")
standard_logging_payload: Optional[StandardLoggingPayload] = (
request_kwargs.get("standard_logging_object")
)
if standard_logging_payload is None:
return
@ -1337,24 +1349,6 @@ class PrometheusLogger(CustomLogger):
return max_budget - spend
def _initialize_prometheus_startup_metrics(self):
"""
Initialize prometheus startup metrics
Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics
"""
if litellm.prometheus_initialize_budget_metrics is not True:
verbose_logger.debug("Prometheus: skipping budget metrics initialization")
return
try:
if asyncio.get_running_loop():
asyncio.create_task(self._initialize_remaining_budget_metrics())
except RuntimeError as e: # no running event loop
verbose_logger.exception(
f"No running event loop - skipping budget metrics initialization: {str(e)}"
)
async def _initialize_budget_metrics(
self,
data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
@ -1475,12 +1469,41 @@ class PrometheusLogger(CustomLogger):
data_type="keys",
)
async def _initialize_remaining_budget_metrics(self):
async def initialize_remaining_budget_metrics(self):
"""
Initialize remaining budget metrics for all teams to avoid metric discrepancies.
Handler for initializing remaining budget metrics for all teams to avoid metric discrepancies.
Runs when prometheus logger starts up.
- If redis cache is available, we use the pod lock manager to acquire a lock and initialize the metrics.
- Ensures only one pod emits the metrics at a time.
- If redis cache is not available, we initialize the metrics directly.
"""
from litellm.constants import PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
from litellm.proxy.proxy_server import proxy_logging_obj
pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
# if using redis, ensure only one pod emits the metrics at a time
if pod_lock_manager and pod_lock_manager.redis_cache:
if await pod_lock_manager.acquire_lock(
cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
):
try:
await self._initialize_remaining_budget_metrics()
finally:
await pod_lock_manager.release_lock(
cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
)
else:
# if not using redis, initialize the metrics directly
await self._initialize_remaining_budget_metrics()
async def _initialize_remaining_budget_metrics(self):
"""
Helper to initialize remaining budget metrics for all teams and API keys.
"""
verbose_logger.debug("Emitting key, team budget metrics....")
await self._initialize_team_budget_metrics()
await self._initialize_api_key_budget_metrics()
@ -1737,6 +1760,36 @@ class PrometheusLogger(CustomLogger):
return (end_time - start_time).total_seconds()
return None
@staticmethod
def initialize_budget_metrics_cron_job(scheduler: AsyncIOScheduler):
"""
Initialize budget metrics as a cron job. This job runs every `PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES` minutes.
It emits the current remaining budget metrics for all Keys and Teams.
"""
from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.prometheus import PrometheusLogger
prometheus_loggers: List[CustomLogger] = (
litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=PrometheusLogger
)
)
# we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them
verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers))
if len(prometheus_loggers) > 0:
prometheus_logger = cast(PrometheusLogger, prometheus_loggers[0])
verbose_logger.debug(
"Initializing remaining budget metrics as a cron job executing every %s minutes"
% PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
)
scheduler.add_job(
prometheus_logger.initialize_remaining_budget_metrics,
"interval",
minutes=PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES,
)
@staticmethod
def _mount_metrics_endpoint(premium_user: bool):
"""

View file

@ -110,5 +110,8 @@ def get_litellm_params(
"azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
"bucket_name": kwargs.get("bucket_name"),
"vertex_credentials": kwargs.get("vertex_credentials"),
"vertex_project": kwargs.get("vertex_project"),
}
return litellm_params

View file

@ -1,8 +1,8 @@
from typing import Literal, Optional
import litellm
from litellm import LlmProviders
from litellm.exceptions import BadRequestError
from litellm.types.utils import LlmProviders, LlmProvidersSet
def get_supported_openai_params( # noqa: PLR0915
@ -30,6 +30,20 @@ def get_supported_openai_params( # noqa: PLR0915
except BadRequestError:
return None
if custom_llm_provider in LlmProvidersSet:
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider)
)
elif custom_llm_provider.split("/")[0] in LlmProvidersSet:
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders(custom_llm_provider.split("/")[0])
)
else:
provider_config = None
if provider_config and request_type == "chat_completion":
return provider_config.get_supported_openai_params(model=model)
if custom_llm_provider == "bedrock":
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama":
@ -226,7 +240,8 @@ def get_supported_openai_params( # noqa: PLR0915
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
model=model, provider=LlmProviders.CUSTOM
)
return provider_config.get_supported_openai_params(model=model)
if provider_config:
return provider_config.get_supported_openai_params(model=model)
elif request_type == "embeddings":
return None
elif request_type == "transcription":

View file

@ -32,7 +32,10 @@ from litellm.constants import (
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
)
from litellm.cost_calculator import _select_model_name_for_cost_calc
from litellm.cost_calculator import (
RealtimeAPITokenUsageProcessor,
_select_model_name_for_cost_calc,
)
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
@ -64,6 +67,7 @@ from litellm.types.utils import (
ImageResponse,
LiteLLMBatch,
LiteLLMLoggingBaseClass,
LiteLLMRealtimeStreamLoggingObject,
ModelResponse,
ModelResponseStream,
RawRequestTypedDict,
@ -618,7 +622,6 @@ class Logging(LiteLLMLoggingBaseClass):
] = RawRequestTypedDict(
error=str(e),
)
traceback.print_exc()
_metadata[
"raw_request"
] = "Unable to Log \
@ -899,9 +902,11 @@ class Logging(LiteLLMLoggingBaseClass):
FineTuningJob,
ResponsesAPIResponse,
ResponseCompletedEvent,
LiteLLMRealtimeStreamLoggingObject,
],
cache_hit: Optional[bool] = None,
litellm_model_name: Optional[str] = None,
router_model_id: Optional[str] = None,
) -> Optional[float]:
"""
Calculate response cost using result + logging object variables.
@ -940,6 +945,7 @@ class Logging(LiteLLMLoggingBaseClass):
"custom_pricing": custom_pricing,
"prompt": prompt,
"standard_built_in_tools_params": self.standard_built_in_tools_params,
"router_model_id": router_model_id,
}
except Exception as e: # error creating kwargs for cost calculation
debug_info = StandardLoggingModelCostFailureDebugInformation(
@ -1049,26 +1055,50 @@ class Logging(LiteLLMLoggingBaseClass):
result = self._handle_anthropic_messages_response_logging(result=result)
## if model in model cost map - log the response cost
## else set cost to None
logging_result = result
if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
results=result
)
logging_result = (
RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
usage=combined_usage_object,
results=result,
)
)
# self.model_call_details[
# "response_cost"
# ] = handle_realtime_stream_cost_calculation(
# results=result,
# combined_usage_object=combined_usage_object,
# custom_llm_provider=self.custom_llm_provider,
# litellm_model_name=self.model,
# )
# self.model_call_details["combined_usage_object"] = combined_usage_object
if (
standard_logging_object is None
and result is not None
and self.stream is not True
):
if (
isinstance(result, ModelResponse)
or isinstance(result, ModelResponseStream)
or isinstance(result, EmbeddingResponse)
or isinstance(result, ImageResponse)
or isinstance(result, TranscriptionResponse)
or isinstance(result, TextCompletionResponse)
or isinstance(result, HttpxBinaryResponseContent) # tts
or isinstance(result, RerankResponse)
or isinstance(result, FineTuningJob)
or isinstance(result, LiteLLMBatch)
or isinstance(result, ResponsesAPIResponse)
isinstance(logging_result, ModelResponse)
or isinstance(logging_result, ModelResponseStream)
or isinstance(logging_result, EmbeddingResponse)
or isinstance(logging_result, ImageResponse)
or isinstance(logging_result, TranscriptionResponse)
or isinstance(logging_result, TextCompletionResponse)
or isinstance(logging_result, HttpxBinaryResponseContent) # tts
or isinstance(logging_result, RerankResponse)
or isinstance(logging_result, FineTuningJob)
or isinstance(logging_result, LiteLLMBatch)
or isinstance(logging_result, ResponsesAPIResponse)
or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject)
):
## HIDDEN PARAMS ##
hidden_params = getattr(result, "_hidden_params", {})
hidden_params = getattr(logging_result, "_hidden_params", {})
if hidden_params:
# add to metadata for logging
if self.model_call_details.get("litellm_params") is not None:
@ -1086,7 +1116,7 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["litellm_params"]["metadata"][ # type: ignore
"hidden_params"
] = getattr(
result, "_hidden_params", {}
logging_result, "_hidden_params", {}
)
## RESPONSE COST - Only calculate if not in hidden_params ##
if "response_cost" in hidden_params:
@ -1096,14 +1126,14 @@ class Logging(LiteLLMLoggingBaseClass):
else:
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=result)
] = self._response_cost_calculator(result=logging_result)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
init_response_obj=logging_result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
@ -3122,6 +3152,7 @@ class StandardLoggingPayloadSetup:
prompt_integration: Optional[str] = None,
applied_guardrails: Optional[List[str]] = None,
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
usage_object: Optional[dict] = None,
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@ -3169,6 +3200,7 @@ class StandardLoggingPayloadSetup:
prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
mcp_tool_call_metadata=mcp_tool_call_metadata,
usage_object=usage_object,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
@ -3194,8 +3226,12 @@ class StandardLoggingPayloadSetup:
return clean_metadata
@staticmethod
def get_usage_from_response_obj(response_obj: Optional[dict]) -> Usage:
def get_usage_from_response_obj(
response_obj: Optional[dict], combined_usage_object: Optional[Usage] = None
) -> Usage:
## BASE CASE ##
if combined_usage_object is not None:
return combined_usage_object
if response_obj is None:
return Usage(
prompt_tokens=0,
@ -3324,6 +3360,7 @@ class StandardLoggingPayloadSetup:
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
usage_object=None,
)
if hidden_params is not None:
for key in StandardLoggingHiddenParams.__annotations__.keys():
@ -3440,6 +3477,7 @@ def get_standard_logging_object_payload(
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
usage_object=None,
)
)
@ -3456,8 +3494,12 @@ def get_standard_logging_object_payload(
call_type = kwargs.get("call_type")
cache_hit = kwargs.get("cache_hit", False)
usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(
response_obj=response_obj
response_obj=response_obj,
combined_usage_object=cast(
Optional[Usage], kwargs.get("combined_usage_object")
),
)
id = response_obj.get("id", kwargs.get("litellm_call_id"))
_model_id = metadata.get("model_info", {}).get("id", "")
@ -3496,6 +3538,7 @@ def get_standard_logging_object_payload(
prompt_integration=kwargs.get("prompt_integration", None),
applied_guardrails=kwargs.get("applied_guardrails", None),
mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None),
usage_object=usage.model_dump(),
)
_request_body = proxy_server_request.get("body", {})
@ -3636,6 +3679,7 @@ def get_standard_logging_metadata(
prompt_management_metadata=None,
applied_guardrails=None,
mcp_tool_call_metadata=None,
usage_object=None,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
@ -3730,6 +3774,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
usage_object=None,
)
# Convert numeric values to appropriate types

View file

@ -90,35 +90,45 @@ def _generic_cost_per_character(
return prompt_cost, completion_cost
def _get_prompt_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
def _get_token_base_cost(model_info: ModelInfo, usage: Usage) -> Tuple[float, float]:
"""
Return prompt and completion base costs for a given model and usage.
If input_tokens > 128k and `input_cost_per_token_above_128k_tokens` is set, then we use the `input_cost_per_token_above_128k_tokens` field.
If input_tokens > threshold and `input_cost_per_token_above_[x]k_tokens` or `input_cost_per_token_above_[x]_tokens` is set,
then we use the corresponding threshold cost.
"""
input_cost_per_token_above_128k_tokens = model_info.get(
"input_cost_per_token_above_128k_tokens"
)
if _is_above_128k(usage.prompt_tokens) and input_cost_per_token_above_128k_tokens:
return input_cost_per_token_above_128k_tokens
return model_info["input_cost_per_token"]
prompt_base_cost = model_info["input_cost_per_token"]
completion_base_cost = model_info["output_cost_per_token"]
## CHECK IF ABOVE THRESHOLD
threshold: Optional[float] = None
for key, value in sorted(model_info.items(), reverse=True):
if key.startswith("input_cost_per_token_above_") and value is not None:
try:
# Handle both formats: _above_128k_tokens and _above_128_tokens
threshold_str = key.split("_above_")[1].split("_tokens")[0]
threshold = float(threshold_str.replace("k", "")) * (
1000 if "k" in threshold_str else 1
)
if usage.prompt_tokens > threshold:
prompt_base_cost = cast(
float,
model_info.get(key, prompt_base_cost),
)
completion_base_cost = cast(
float,
model_info.get(
f"output_cost_per_token_above_{threshold_str}_tokens",
completion_base_cost,
),
)
break
except (IndexError, ValueError):
continue
except Exception:
continue
def _get_completion_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
"""
Return completion cost for a given model and usage.
If completion_tokens > 128k and `output_cost_per_token_above_128k_tokens` is set, then we use the `output_cost_per_token_above_128k_tokens` field.
"""
output_cost_per_token_above_128k_tokens = model_info.get(
"output_cost_per_token_above_128k_tokens"
)
if (
_is_above_128k(usage.completion_tokens)
and output_cost_per_token_above_128k_tokens
):
return output_cost_per_token_above_128k_tokens
return model_info["output_cost_per_token"]
return prompt_base_cost, completion_base_cost
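
A self-contained sketch of the threshold lookup above, using an illustrative `model_info` dict shaped like the cost-map entries this function reads (the prices are made up):

```python
# Hypothetical model_info; keys follow the "input_cost_per_token_above_[x]k_tokens" convention.
model_info = {
    "input_cost_per_token": 1.25e-06,
    "output_cost_per_token": 5.0e-06,
    "input_cost_per_token_above_128k_tokens": 2.5e-06,
    "output_cost_per_token_above_128k_tokens": 1.0e-05,
}

def token_base_cost(model_info: dict, prompt_tokens: int) -> tuple:
    prompt_base = model_info["input_cost_per_token"]
    completion_base = model_info["output_cost_per_token"]
    for key, value in sorted(model_info.items(), reverse=True):
        if key.startswith("input_cost_per_token_above_") and value is not None:
            threshold_str = key.split("_above_")[1].split("_tokens")[0]  # e.g. "128k"
            threshold = float(threshold_str.replace("k", "")) * (1000 if "k" in threshold_str else 1)
            if prompt_tokens > threshold:
                # Above the threshold: both input and output switch to the threshold rates.
                prompt_base = value
                completion_base = model_info.get(
                    f"output_cost_per_token_above_{threshold_str}_tokens", completion_base
                )
                break
    return prompt_base, completion_base

print(token_base_cost(model_info, 200_000))  # (2.5e-06, 1e-05): above-128k rates apply
print(token_base_cost(model_info, 50_000))   # (1.25e-06, 5e-06): base rates apply
```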
def calculate_cost_component(
@ -215,7 +225,9 @@ def generic_cost_per_token(
if text_tokens == 0:
text_tokens = usage.prompt_tokens - cache_hit_tokens - audio_tokens
prompt_base_cost = _get_prompt_token_base_cost(model_info=model_info, usage=usage)
prompt_base_cost, completion_base_cost = _get_token_base_cost(
model_info=model_info, usage=usage
)
prompt_cost = float(text_tokens) * prompt_base_cost
@ -253,9 +265,6 @@ def generic_cost_per_token(
)
## CALCULATE OUTPUT COST
completion_base_cost = _get_completion_token_base_cost(
model_info=model_info, usage=usage
)
text_tokens = usage.completion_tokens
audio_tokens = 0
if usage.completion_tokens_details is not None:

View file

@ -36,11 +36,16 @@ class ResponseMetadata:
self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
) -> None:
"""Set hidden parameters on the response"""
## ADD OTHER HIDDEN PARAMS
model_id = kwargs.get("model_info", {}).get("id", None)
new_params = {
"litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
"model_id": kwargs.get("model_info", {}).get("id", None),
"api_base": get_api_base(model=model or "", optional_params=kwargs),
"response_cost": logging_obj._response_cost_calculator(result=self.result),
"model_id": model_id,
"response_cost": logging_obj._response_cost_calculator(
result=self.result, litellm_model_name=model, router_model_id=model_id
),
"additional_headers": process_response_headers(
self._get_value_from_hidden_params("additional_headers") or {}
),

View file

@ -2,7 +2,10 @@
Common utility functions used for translating messages across providers
"""
from typing import Dict, List, Literal, Optional, Union, cast
import io
import mimetypes
from os import PathLike
from typing import Dict, List, Literal, Mapping, Optional, Union, cast
from litellm.types.llms.openai import (
AllMessageValues,
@ -10,7 +13,13 @@ from litellm.types.llms.openai import (
ChatCompletionFileObject,
ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
from litellm.types.utils import (
Choices,
ExtractedFileData,
FileTypes,
ModelResponse,
StreamingChoices,
)
DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
content="Please continue.", role="user"
@ -348,3 +357,99 @@ def update_messages_with_model_file_ids(
)
file_object_file_field["file_id"] = provider_file_id
return messages
def extract_file_data(file_data: FileTypes) -> ExtractedFileData:
"""
Extracts and processes file data from various input formats.
Args:
file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content
Returns:
ExtractedFileData containing:
- filename: Name of the file if provided
- content: The file content in bytes
- content_type: MIME type of the file
- headers: Any additional headers
"""
# Parse the file_data based on its type
filename = None
file_content = None
content_type = None
file_headers: Mapping[str, str] = {}
if isinstance(file_data, tuple):
if len(file_data) == 2:
filename, file_content = file_data
elif len(file_data) == 3:
filename, file_content, content_type = file_data
elif len(file_data) == 4:
filename, file_content, content_type, file_headers = file_data
else:
file_content = file_data
# Convert content to bytes
if isinstance(file_content, (str, PathLike)):
# If it's a path, open and read the file
with open(file_content, "rb") as f:
content = f.read()
elif isinstance(file_content, io.IOBase):
# If it's a file-like object
content = file_content.read()
if isinstance(content, str):
content = content.encode("utf-8")
# Reset file pointer to beginning
file_content.seek(0)
elif isinstance(file_content, bytes):
content = file_content
else:
raise ValueError(f"Unsupported file content type: {type(file_content)}")
# Use provided content type or guess based on filename
if not content_type:
content_type = (
mimetypes.guess_type(filename)[0]
if filename
else "application/octet-stream"
)
return ExtractedFileData(
filename=filename,
content=content,
content_type=content_type,
headers=file_headers,
)
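
A quick usage sketch for `extract_file_data` (file contents are illustrative; the import path is the module this hunk belongs to):

```python
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data

# Tuple form: (filename, content, content_type, headers); the last two are optional.
extracted = extract_file_data(("batch.jsonl", b'{"custom_id": "1"}', "application/jsonl"))
print(extracted["filename"])       # batch.jsonl
print(extracted["content_type"])   # application/jsonl
print(len(extracted["content"]))   # payload size in bytes

# Bare bytes also work; the MIME type falls back to application/octet-stream.
print(extract_file_data(b"hello")["content_type"])
```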
def unpack_defs(schema, defs):
properties = schema.get("properties", None)
if properties is None:
return
for name, value in properties.items():
ref_key = value.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
properties[name] = ref
continue
anyof = value.get("anyOf", None)
if anyof is not None:
for i, atype in enumerate(anyof):
ref_key = atype.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
anyof[i] = ref
continue
items = value.get("items", None)
if items is not None:
ref_key = items.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
value["items"] = ref
continue
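
For intuition, a small sketch of what `unpack_defs` does to a Pydantic-style JSON schema (the schema content is made up; the import matches the module path used elsewhere in this diff):

```python
from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs

schema = {
    "properties": {
        "address": {"$ref": "#/$defs/Address"},
        "aliases": {"items": {"$ref": "#/$defs/Alias"}},
    }
}
defs = {
    "Address": {"properties": {"city": {"type": "string"}}},
    "Alias": {"properties": {"name": {"type": "string"}}},
}

unpack_defs(schema, defs)
# $ref pointers are resolved in place, so the schema no longer depends on $defs:
assert schema["properties"]["address"] == {"properties": {"city": {"type": "string"}}}
assert schema["properties"]["aliases"]["items"] == {"properties": {"name": {"type": "string"}}}
```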

View file

@ -1,7 +1,6 @@
import copy
import json
import re
import traceback
import uuid
import xml.etree.ElementTree as ET
from enum import Enum
@ -748,7 +747,6 @@ def convert_to_anthropic_image_obj(
data=base64_data,
)
except Exception as e:
traceback.print_exc()
if "Error: Unable to fetch image from URL" in str(e):
raise e
raise Exception(
@ -3442,6 +3440,8 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
}
]
"""
from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs
tool_block_list: List[BedrockToolBlock] = []
for tool in tools:
parameters = tool.get("function", {}).get(
@ -3455,6 +3455,13 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
description = tool.get("function", {}).get(
"description", name
) # converse api requires a description
defs = parameters.pop("$defs", {})
defs_copy = copy.deepcopy(defs)
# flatten the defs
for _, value in defs_copy.items():
unpack_defs(value, defs_copy)
unpack_defs(parameters, defs_copy)
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
tool_spec = BedrockToolSpecBlock(
inputSchema=tool_input_schema, name=name, description=description

View file

@ -30,6 +30,11 @@ import json
from typing import Any, Dict, List, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.types.llms.openai import (
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
)
from .litellm_logging import Logging as LiteLLMLogging
@ -53,7 +58,12 @@ class RealTimeStreaming:
self.websocket = websocket
self.backend_ws = backend_ws
self.logging_obj = logging_obj
self.messages: List = []
self.messages: List[
Union[
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
]
] = []
self.input_message: Dict = {}
_logged_real_time_event_types = litellm.logged_real_time_event_types
@ -62,10 +72,14 @@ class RealTimeStreaming:
_logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
self.logged_real_time_event_types = _logged_real_time_event_types
def _should_store_message(self, message: Union[str, bytes]) -> bool:
if isinstance(message, bytes):
message = message.decode("utf-8")
message_obj = json.loads(message)
def _should_store_message(
self,
message_obj: Union[
dict,
OpenAIRealtimeStreamSessionEvents,
OpenAIRealtimeStreamResponseBaseObject,
],
) -> bool:
_msg_type = message_obj["type"]
if self.logged_real_time_event_types == "*":
return True
@ -75,8 +89,22 @@ class RealTimeStreaming:
def store_message(self, message: Union[str, bytes]):
"""Store message in list"""
if self._should_store_message(message):
self.messages.append(message)
if isinstance(message, bytes):
message = message.decode("utf-8")
message_obj = json.loads(message)
try:
if (
message_obj.get("type") == "session.created"
or message_obj.get("type") == "session.updated"
):
message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj) # type: ignore
else:
message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj) # type: ignore
except Exception as e:
verbose_logger.debug(f"Error parsing message for logging: {e}")
raise e
if self._should_store_message(message_obj):
self.messages.append(message_obj)
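
The parsing step in isolation: raw websocket frames are decoded, JSON-parsed, and routed to either a session event or a response event wrapper before the `_should_store_message` filter runs. A minimal sketch of that routing, using plain dicts in place of the typed wrappers:

```python
import json
from typing import Union

def route_realtime_event(message: Union[str, bytes]):
    """Decode a raw realtime frame and tag it as a session or response event."""
    if isinstance(message, bytes):
        message = message.decode("utf-8")
    event = json.loads(message)
    kind = "session" if event.get("type") in ("session.created", "session.updated") else "response"
    return kind, event

print(route_realtime_event(b'{"type": "session.created", "session": {}}'))
# ('session', {'type': 'session.created', 'session': {}})
```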
def store_input(self, message: dict):
"""Store input message"""

View file

@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
import copy
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import httpx # type: ignore
@ -301,12 +301,17 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
messages=messages,
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
litellm_params=litellm_params,
)
config = ProviderConfigManager.get_provider_chat_config(
model=model,
provider=LlmProviders(custom_llm_provider),
)
if config is None:
raise ValueError(
f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
)
data = config.transform_request(
model=model,
@ -487,29 +492,10 @@ class ModelResponseIterator:
return False
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
usage_block = Usage(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
+ anthropic_usage_chunk.get("output_tokens", 0),
return AnthropicConfig().calculate_usage(
usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
)
cache_creation_input_tokens = anthropic_usage_chunk.get(
"cache_creation_input_tokens"
)
if cache_creation_input_tokens is not None and isinstance(
cache_creation_input_tokens, int
):
usage_block["cache_creation_input_tokens"] = cache_creation_input_tokens
cache_read_input_tokens = anthropic_usage_chunk.get("cache_read_input_tokens")
if cache_read_input_tokens is not None and isinstance(
cache_read_input_tokens, int
):
usage_block["cache_read_input_tokens"] = cache_read_input_tokens
return usage_block
def _content_block_delta_helper(
self, chunk: dict
) -> Tuple[

View file

@ -682,6 +682,45 @@ class AnthropicConfig(BaseConfig):
reasoning_content += block["thinking"]
return text_content, citations, thinking_blocks, reasoning_content, tool_calls
def calculate_usage(
self, usage_object: dict, reasoning_content: Optional[str]
) -> Usage:
prompt_tokens = usage_object.get("input_tokens", 0)
completion_tokens = usage_object.get("output_tokens", 0)
_usage = usage_object
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
completion_token_details = (
CompletionTokensDetailsWrapper(
reasoning_tokens=token_counter(
text=reasoning_content, count_response_tokens=True
)
)
if reasoning_content
else None
)
total_tokens = prompt_tokens + completion_tokens
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
completion_tokens_details=completion_token_details,
)
return usage
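
A worked example of the accounting above (the numbers and the import path are assumptions for the sketch): cache-read tokens are folded into the prompt count, while cache-creation tokens are only tracked separately.

```python
from litellm.llms.anthropic.chat.transformation import AnthropicConfig  # path assumed

# Hypothetical usage block returned by Anthropic's /v1/messages.
usage_object = {
    "input_tokens": 100,
    "output_tokens": 50,
    "cache_read_input_tokens": 20,
    "cache_creation_input_tokens": 30,
}

usage = AnthropicConfig().calculate_usage(usage_object=usage_object, reasoning_content=None)
assert usage.prompt_tokens == 120        # 100 input + 20 cache-read tokens
assert usage.completion_tokens == 50
assert usage.total_tokens == 170         # cache-creation tokens are not added here
assert usage.prompt_tokens_details.cached_tokens == 20
```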
def transform_response(
self,
model: str,
@ -772,45 +811,14 @@ class AnthropicConfig(BaseConfig):
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
_usage = completion_response["usage"]
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
usage = self.calculate_usage(
usage_object=completion_response["usage"],
reasoning_content=reasoning_content,
)
setattr(model_response, "usage", usage) # type: ignore
model_response.created = int(time.time())
model_response.model = completion_response["model"]
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
prompt_tokens += cache_creation_input_tokens
if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
completion_token_details = (
CompletionTokensDetailsWrapper(
reasoning_tokens=token_counter(
text=reasoning_content, count_response_tokens=True
)
)
if reasoning_content
else None
)
total_tokens = prompt_tokens + completion_tokens
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
completion_tokens_details=completion_token_details,
)
setattr(model_response, "usage", usage) # type: ignore
model_response._hidden_params = _hidden_params
return model_response
@ -868,6 +876,7 @@ class AnthropicConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -262,6 +262,7 @@ class BaseConfig(ABC):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig):
) -> List[OpenAICreateFileRequestOptionalParams]:
pass
def get_complete_url(
def get_complete_file_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
return api_base or ""
data: CreateFileRequest,
):
return self.get_complete_url(
api_base=api_base,
api_key=api_key,
model=model,
optional_params=optional_params,
litellm_params=litellm_params,
)
@abstractmethod
def transform_create_file_request(
@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig):
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
) -> dict:
) -> Union[dict, str, bytes]:
pass
@abstractmethod

View file

@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
ChatCompletionUserMessage,
OpenAIChatCompletionToolParam,
OpenAIMessageContentListBlock,
)
from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage
@ -211,13 +212,29 @@ class AmazonConverseConfig(BaseConfig):
)
return _tool
def _apply_tool_call_transformation(
self,
tools: List[OpenAIChatCompletionToolParam],
model: str,
non_default_params: dict,
optional_params: dict,
):
optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=tools
)
if (
"meta.llama3-3-70b-instruct-v1:0" in model
and non_default_params.get("stream", False) is True
):
optional_params["fake_stream"] = True
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
messages: Optional[List[AllMessageValues]] = None,
) -> dict:
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
@ -286,8 +303,11 @@ class AmazonConverseConfig(BaseConfig):
if param == "top_p":
optional_params["topP"] = value
if param == "tools" and isinstance(value, list):
optional_params = self._add_tools_to_optional_params(
optional_params=optional_params, tools=value
self._apply_tool_call_transformation(
tools=cast(List[OpenAIChatCompletionToolParam], value),
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
)
if param == "tool_choice":
_tool_choice_value = self.map_tool_choice_values(
@ -633,8 +653,10 @@ class AmazonConverseConfig(BaseConfig):
cache_read_input_tokens = usage["cacheReadInputTokens"]
input_tokens += cache_read_input_tokens
if "cacheWriteInputTokens" in usage:
"""
Do not increment prompt_tokens with cacheWriteInputTokens
"""
cache_creation_input_tokens = usage["cacheWriteInputTokens"]
input_tokens += cache_creation_input_tokens
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
@ -811,6 +833,7 @@ class AmazonConverseConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -1,13 +1,13 @@
import types
from typing import List, Optional
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.cohere.chat.transformation import CohereChatConfig
class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
@ -19,7 +19,6 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
"""
max_tokens: Optional[int] = None
temperature: Optional[float] = None
return_likelihood: Optional[str] = None
def __init__(
@ -55,11 +54,10 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
}
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"max_tokens",
"temperature",
"stream",
]
supported_params = CohereChatConfig.get_supported_openai_params(
self, model=model
)
return supported_params
def map_openai_params(
self,
@ -68,11 +66,10 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
model: str,
drop_params: bool,
) -> dict:
for k, v in non_default_params.items():
if k == "stream":
optional_params["stream"] = v
if k == "temperature":
optional_params["temperature"] = v
if k == "max_tokens":
optional_params["max_tokens"] = v
return optional_params
return CohereChatConfig.map_openai_params(
self,
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params,
)

View file

@ -6,14 +6,21 @@ Inherits from `AmazonConverseConfig`
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
"""
from typing import List
from typing import Any, List, Optional
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ..converse_transformation import AmazonConverseConfig
from .base_invoke_transformation import AmazonInvokeConfig
class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
"""
Config for sending `nova` requests to `/bedrock/invoke/`
"""
@ -21,6 +28,20 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_supported_openai_params(self, model: str) -> list:
return AmazonConverseConfig.get_supported_openai_params(self, model)
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
return AmazonConverseConfig.map_openai_params(
self, non_default_params, optional_params, model, drop_params
)
def transform_request(
self,
model: str,
@ -29,7 +50,8 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
litellm_params: dict,
headers: dict,
) -> dict:
_transformed_nova_request = super().transform_request(
_transformed_nova_request = AmazonConverseConfig.transform_request(
self,
model=model,
messages=messages,
optional_params=optional_params,
@ -45,6 +67,35 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
)
return bedrock_invoke_nova_request
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: Logging,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
return AmazonConverseConfig.transform_response(
self,
model,
raw_response,
model_response,
logging_obj,
request_data,
messages,
optional_params,
litellm_params,
encoding,
api_key,
json_mode,
)
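
A note on the design: `AmazonInvokeNovaConfig` now inherits from both `AmazonInvokeConfig` (invoke-endpoint plumbing) and `AmazonConverseConfig` (Nova's Converse-style translation), and each override delegates to the intended parent by name so Python's MRO cannot silently route a call to the wrong base. A generic sketch of the pattern (class names below are illustrative, not LiteLLM's):

```python
class InvokePlumbing:
    def transform_request(self, data: dict) -> dict:
        return {"endpoint": "invoke", **data}

class ConverseTranslation:
    def transform_request(self, data: dict) -> dict:
        return {"schema": "converse", **data}

class NovaInvoke(InvokePlumbing, ConverseTranslation):
    def transform_request(self, data: dict) -> dict:
        # Explicit delegation: we want the Converse translation, not the MRO default
        # (which would resolve to InvokePlumbing.transform_request).
        return ConverseTranslation.transform_request(self, data)

print(NovaInvoke().transform_request({"model": "nova"}))
# {'schema': 'converse', 'model': 'nova'}
```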
def _filter_allowed_fields(
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
) -> dict:

View file

@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -218,6 +218,10 @@ class BaseLLMAIOHTTPHandler:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
if provider_config is None:
raise ValueError(
f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
)
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
@ -225,6 +229,7 @@ class BaseLLMAIOHTTPHandler:
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
api_base=api_base,
)
@ -494,6 +499,7 @@ class BaseLLMAIOHTTPHandler:
model=model,
messages=[{"role": "user", "content": "test"}],
optional_params=optional_params,
litellm_params=litellm_params,
api_base=api_base,
)

View file

@ -192,7 +192,7 @@ class AsyncHTTPHandler:
async def post(
self,
url: str,
data: Optional[Union[dict, str]] = None, # type: ignore
data: Optional[Union[dict, str, bytes]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@ -427,7 +427,7 @@ class AsyncHTTPHandler:
self,
url: str,
client: httpx.AsyncClient,
data: Optional[Union[dict, str]] = None, # type: ignore
data: Optional[Union[dict, str, bytes]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@ -527,7 +527,7 @@ class HTTPHandler:
def post(
self,
url: str,
data: Optional[Union[dict, str]] = None,
data: Optional[Union[dict, str, bytes]] = None,
json: Optional[Union[dict, str, List]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@ -573,7 +573,6 @@ class HTTPHandler:
setattr(e, "text", error_text)
setattr(e, "status_code", e.response.status_code)
raise e
except Exception as e:
raise e

View file

@ -234,6 +234,10 @@ class BaseLLMHTTPHandler:
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
if provider_config is None:
raise ValueError(
f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
)
# get config from model, custom llm provider
headers = provider_config.validate_environment(
@ -243,6 +247,7 @@ class BaseLLMHTTPHandler:
messages=messages,
optional_params=optional_params,
api_base=api_base,
litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
@ -621,6 +626,7 @@ class BaseLLMHTTPHandler:
model=model,
messages=[],
optional_params=optional_params,
litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
@ -892,6 +898,7 @@ class BaseLLMHTTPHandler:
model=model,
messages=[],
optional_params=optional_params,
litellm_params=litellm_params,
)
if client is None or not isinstance(client, HTTPHandler):
@ -1224,15 +1231,19 @@ class BaseLLMHTTPHandler:
model="",
messages=[],
optional_params={},
litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
api_base = provider_config.get_complete_file_url(
api_base=api_base,
api_key=api_key,
model="",
optional_params={},
litellm_params=litellm_params,
data=create_file_data,
)
if api_base is None:
raise ValueError("api_base is required for create_file")
# Get the transformed request data for both steps
transformed_request = provider_config.transform_create_file_request(
@ -1259,48 +1270,57 @@ class BaseLLMHTTPHandler:
else:
sync_httpx_client = client
try:
# Step 1: Initial request to get upload URL
initial_response = sync_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
if isinstance(transformed_request, str) or isinstance(
transformed_request, bytes
):
upload_response = sync_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
url=api_base,
headers=headers,
data=transformed_request,
timeout=timeout,
)
else:
try:
# Step 1: Initial request to get upload URL
initial_response = sync_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
upload_response = sync_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
timeout=timeout,
)
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
async def async_create_file(
self,
transformed_request: dict,
transformed_request: Union[bytes, str, dict],
litellm_params: dict,
provider_config: BaseFilesConfig,
headers: dict,
@ -1319,45 +1339,54 @@ class BaseLLMHTTPHandler:
else:
async_httpx_client = client
try:
# Step 1: Initial request to get upload URL
initial_response = await async_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
if isinstance(transformed_request, str) or isinstance(
transformed_request, bytes
):
upload_response = await async_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
url=api_base,
headers=headers,
data=transformed_request,
timeout=timeout,
)
else:
try:
# Step 1: Initial request to get upload URL
initial_response = await async_httpx_client.post(
url=api_base,
headers={
**headers,
**transformed_request["initial_request"]["headers"],
},
data=json.dumps(transformed_request["initial_request"]["data"]),
timeout=timeout,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
# Extract upload URL from response headers
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
except Exception as e:
verbose_logger.exception(f"Error creating file: {e}")
raise self._handle_error(
e=e,
provider_config=provider_config,
)
if not upload_url:
raise ValueError("Failed to get upload URL from initial request")
# Step 2: Upload the actual file
upload_response = await async_httpx_client.post(
url=upload_url,
headers=transformed_request["upload_request"]["headers"],
data=transformed_request["upload_request"]["data"],
timeout=timeout,
)
except Exception as e:
verbose_logger.exception(f"Error creating file: {e}")
raise self._handle_error(
e=e,
provider_config=provider_config,
)
return provider_config.transform_create_file_response(
model=None,
raw_response=upload_response,
logging_obj=logging_obj,
litellm_params=litellm_params,
)
def list_files(self):
"""

View file

@ -27,7 +27,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
strip_name_from_messages,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.anthropic import AnthropicMessagesTool
from litellm.types.llms.anthropic import AllAnthropicToolsValues
from litellm.types.llms.databricks import (
AllDatabricksContentValues,
DatabricksChoice,
@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
@ -160,7 +161,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
]
def convert_anthropic_tool_to_databricks_tool(
self, tool: Optional[AnthropicMessagesTool]
self, tool: Optional[AllAnthropicToolsValues]
) -> Optional[DatabricksTool]:
if tool is None:
return None
@ -173,6 +174,19 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
),
)
def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]:
# if not claude, send as is
if "claude" not in model:
return tools
# if claude, convert to anthropic tool and then to databricks tool
anthropic_tools = self._map_tools(tools=tools)
databricks_tools = [
cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool))
for tool in anthropic_tools
]
return databricks_tools
def map_response_format_to_databricks_tool(
self,
model: str,
@ -202,6 +216,10 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
mapped_params = super().map_openai_params(
non_default_params, optional_params, model, drop_params
)
if "tools" in mapped_params:
mapped_params["tools"] = self._map_openai_to_dbrx_tool(
model=model, tools=mapped_params["tools"]
)
if (
"max_completion_tokens" in non_default_params
and replace_max_completion_tokens_with_max_tokens
@ -240,6 +258,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
non_default_params.get("reasoning_effort")
)
optional_params.pop("reasoning_effort", None)
## handle thinking tokens
self.update_optional_params_with_thinking_tokens(
non_default_params=non_default_params, optional_params=mapped_params
@ -498,7 +517,10 @@ class DatabricksChatResponseIterator(BaseModelResponseIterator):
message.content = ""
choice["delta"]["content"] = message.content
choice["delta"]["tool_calls"] = None
elif tool_calls:
for _tc in tool_calls:
if _tc.get("function", {}).get("arguments") == "{}":
_tc["function"]["arguments"] = "" # avoid invalid json
# extract the content str
content_str = DatabricksConfig.extract_content_str(
choice["delta"].get("content")

View file

@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -2,7 +2,11 @@ from typing import List, Literal, Optional, Tuple, Union, cast
import litellm
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionImageObject,
OpenAIChatCompletionToolParam,
)
from litellm.types.utils import ProviderSpecificModelInfo
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
@ -150,6 +154,14 @@ class FireworksAIConfig(OpenAIGPTConfig):
] = f"{content['image_url']['url']}#transform=inline"
return content
def _transform_tools(
self, tools: List[OpenAIChatCompletionToolParam]
) -> List[OpenAIChatCompletionToolParam]:
for tool in tools:
if tool.get("type") == "function":
tool["function"].pop("strict", None)
return tools
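
A small sketch of what the transformation above does to a tool definition (the tool itself is illustrative); the OpenAI-only `strict` flag is dropped before the request is built, presumably because the upstream API rejects it:

```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "strict": True,  # OpenAI-specific field
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }
]

# Same transformation as above, inlined for illustration.
for tool in tools:
    if tool.get("type") == "function":
        tool["function"].pop("strict", None)

assert "strict" not in tools[0]["function"]
```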
def _transform_messages_helper(
self, messages: List[AllMessageValues], model: str, litellm_params: dict
) -> List[AllMessageValues]:
@ -196,6 +208,9 @@ class FireworksAIConfig(OpenAIGPTConfig):
messages = self._transform_messages_helper(
messages=messages, model=model, litellm_params=litellm_params
)
if "tools" in optional_params and optional_params["tools"] is not None:
tools = self._transform_tools(tools=optional_params["tools"])
optional_params["tools"] = tools
return super().transform_request(
model=model,
messages=messages,

View file

@ -41,6 +41,7 @@ class FireworksAIMixin:
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API.
For vertex ai, check out the vertex_ai/files/handler.py file.
"""
import time
from typing import List, Mapping, Optional
from typing import List, Optional
import httpx
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
if file_data is None:
raise ValueError("File data is required")
# Parse the file_data based on its type
filename = None
file_content = None
content_type = None
file_headers: Mapping[str, str] = {}
if isinstance(file_data, tuple):
if len(file_data) == 2:
filename, file_content = file_data
elif len(file_data) == 3:
filename, file_content, content_type = file_data
elif len(file_data) == 4:
filename, file_content, content_type, file_headers = file_data
else:
file_content = file_data
# Handle the file content based on its type
import io
from os import PathLike
# Convert content to bytes
if isinstance(file_content, (str, PathLike)):
# If it's a path, open and read the file
with open(file_content, "rb") as f:
content = f.read()
elif isinstance(file_content, io.IOBase):
# If it's a file-like object
content = file_content.read()
if isinstance(content, str):
content = content.encode("utf-8")
elif isinstance(file_content, bytes):
content = file_content
else:
raise ValueError(f"Unsupported file content type: {type(file_content)}")
# Use the common utility function to extract file data
extracted_data = extract_file_data(file_data)
# Get file size
file_size = len(content)
# Use provided content type or guess based on filename
if not content_type:
import mimetypes
content_type = (
mimetypes.guess_type(filename)[0]
if filename
else "application/octet-stream"
)
file_size = len(extracted_data["content"])
# Step 1: Initial resumable upload request
headers = {
"X-Goog-Upload-Protocol": "resumable",
"X-Goog-Upload-Command": "start",
"X-Goog-Upload-Header-Content-Length": str(file_size),
"X-Goog-Upload-Header-Content-Type": content_type,
"X-Goog-Upload-Header-Content-Type": extracted_data["content_type"],
"Content-Type": "application/json",
}
headers.update(file_headers) # Add any custom headers
headers.update(extracted_data["headers"]) # Add any custom headers
# Initial metadata request body
initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
initial_data = {
"file": {
"display_name": extracted_data["filename"] or str(int(time.time()))
}
}
# Step 2: Actual file upload data
upload_headers = {
@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
return {
"initial_request": {"headers": headers, "data": initial_data},
"upload_request": {"headers": upload_headers, "data": content},
"upload_request": {
"headers": upload_headers,
"data": extracted_data["content"],
},
}
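
Downstream, the handler uses this two-part payload for Google's resumable upload protocol: the first POST sends only metadata plus the `X-Goog-Upload-*` start headers, and the `X-Goog-Upload-URL` returned in the response headers receives the raw bytes in a second POST. A rough httpx sketch; the endpoint, key, and the step-2 command headers are assumptions based on the public Files API docs:

```python
import httpx

api_base = "https://generativelanguage.googleapis.com/upload/v1beta/files?key=GEMINI_API_KEY"  # placeholder
file_bytes = b'{"custom_id": "1"}'

# Step 1: start a resumable upload; only metadata is sent.
initial = httpx.post(
    api_base,
    headers={
        "X-Goog-Upload-Protocol": "resumable",
        "X-Goog-Upload-Command": "start",
        "X-Goog-Upload-Header-Content-Length": str(len(file_bytes)),
        "X-Goog-Upload-Header-Content-Type": "application/jsonl",
        "Content-Type": "application/json",
    },
    json={"file": {"display_name": "batch.jsonl"}},
)
upload_url = initial.headers["X-Goog-Upload-URL"]

# Step 2: send the actual bytes to the URL returned in step 1.
upload = httpx.post(
    upload_url,
    headers={"X-Goog-Upload-Command": "upload, finalize", "X-Goog-Upload-Offset": "0"},
    content=file_bytes,
)
```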
def transform_create_file_response(

View file

@ -1,6 +1,6 @@
import logging
import os
from typing import TYPE_CHECKING, Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import BaseLLMException
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
logger = logging.getLogger(__name__)
BASE_URL = "https://router.huggingface.co"
@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
return HuggingFaceError(
status_code=status_code, message=error_message, headers=headers
)
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
"""
@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
if api_base is not None:
complete_url = api_base
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
complete_url = str(os.getenv("HF_API_BASE")) or str(
os.getenv("HUGGINGFACE_API_BASE")
)
elif model.startswith(("http://", "https://")):
complete_url = model
# 4. Default construction with provider
@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
)
mapped_model = provider_mapping["providerId"]
messages = self._transform_messages(messages=messages, model=mapped_model)
return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
return dict(
ChatCompletionRequest(
model=mapped_model, messages=messages, **optional_params
)
)

View file

@ -1,15 +1,6 @@
import json
import os
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Union,
get_args,
)
from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
import httpx
@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug
]
def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
def get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
return pipeline_tag
async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
async def async_get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
hf_task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
litellm_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
model=model,
optional_params=optional_params,
messages=[],
litellm_params=litellm_params,
)
task_type = optional_params.pop("input_type", None)
task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=HF_HUB_URL
)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
embed_url = (
f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
)
## ROUTING ##
if aembedding is True:

View file

@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@ -10,6 +10,11 @@ from ...openai.chat.gpt_transformation import OpenAIGPTConfig
class LiteLLMProxyChatConfig(OpenAIGPTConfig):
def get_supported_openai_params(self, model: str) -> List:
list = super().get_supported_openai_params(model)
list.append("thinking")
return list
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:

View file

@ -36,6 +36,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
## Load Config

View file

@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -32,6 +32,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
if "https" in model:
completion_url = model
@ -123,6 +124,7 @@ def embedding(
model=model,
messages=[],
optional_params=optional_params,
litellm_params={},
)
response = litellm.module_level_client.post(
embeddings_url, headers=headers, json=data

View file

@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -228,10 +228,10 @@ class PredibaseChatCompletion:
api_key: str,
logging_obj,
optional_params: dict,
litellm_params: dict,
tenant_id: str,
timeout: Union[float, httpx.Timeout],
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
@ -241,6 +241,7 @@ class PredibaseChatCompletion:
messages=messages,
optional_params=optional_params,
model=model,
litellm_params=litellm_params,
)
completion_url = ""
input_text = ""

View file

@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -141,6 +141,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
# Start a prediction and get the prediction URL
version_id = replicate_config.model_to_version_id(model)

View file

@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM):
model: str,
data: dict,
messages: List[AllMessageValues],
litellm_params: dict,
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM):
data=data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
credentials=credentials,
aws_region_name=aws_region_name,
)
@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": _data,
"optional_params": optional_params,
"litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": data,
"optional_params": optional_params,
"litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": data,
"optional_params": optional_params,
"litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,

View file

@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -1,13 +1,14 @@
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
import re
from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType
from litellm.types.llms.vertex_ai import PartType, Schema
class VertexAIError(BaseLLMException):
@ -168,6 +169,9 @@ def _build_vertex_schema(parameters: dict):
"""
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
"""
# Get valid fields from Schema TypedDict
valid_schema_fields = set(get_type_hints(Schema).keys())
defs = parameters.pop("$defs", {})
# flatten the defs
for name, value in defs.items():
@ -181,52 +185,49 @@ def _build_vertex_schema(parameters: dict):
convert_anyof_null_to_nullable(parameters)
add_object_type(parameters)
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
# * http://cl/586221780
strip_field(parameters, field_name="title")
strip_field(
parameters, field_name="$schema"
) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
strip_field(
parameters, field_name="$id"
) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors.
return parameters
# Filter out fields that don't exist in Schema
filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
return filtered_parameters
def unpack_defs(schema, defs):
properties = schema.get("properties", None)
if properties is None:
return
def filter_schema_fields(
schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None
) -> Dict[str, Any]:
"""
Recursively filter a schema dictionary to keep only valid fields.
"""
if processed is None:
processed = set()
for name, value in properties.items():
ref_key = value.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
properties[name] = ref
# Handle circular references
schema_id = id(schema_dict)
if schema_id in processed:
return schema_dict
processed.add(schema_id)
if not isinstance(schema_dict, dict):
return schema_dict
result = {}
for key, value in schema_dict.items():
if key not in valid_fields:
continue
anyof = value.get("anyOf", None)
if anyof is not None:
for i, atype in enumerate(anyof):
ref_key = atype.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
anyof[i] = ref
continue
if key == "properties" and isinstance(value, dict):
result[key] = {
k: filter_schema_fields(v, valid_fields, processed)
for k, v in value.items()
}
elif key == "items" and isinstance(value, dict):
result[key] = filter_schema_fields(value, valid_fields, processed)
elif key == "anyOf" and isinstance(value, list):
result[key] = [
filter_schema_fields(item, valid_fields, processed) for item in value # type: ignore
]
else:
result[key] = value
items = value.get("items", None)
if items is not None:
ref_key = items.get("$ref", None)
if ref_key is not None:
ref = defs[ref_key.split("defs/")[-1]]
unpack_defs(ref, defs)
value["items"] = ref
continue
return result
def convert_anyof_null_to_nullable(schema, depth=0):

View file
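
The new `filter_schema_fields` helper prunes any JSON-schema keys that the Vertex `Schema` TypedDict does not declare, recursing through `properties`, `items`, and `anyOf` while guarding against circular references via object ids. Below is a standalone sketch of the same idea; `ToySchema` and the sample schema are illustrative stand-ins, not the real litellm types.

```python
from typing import Any, Dict, Optional, Set, TypedDict, get_type_hints


class ToySchema(TypedDict, total=False):
    # Stand-in for the Vertex `Schema` TypedDict; only these keys survive filtering.
    type: str
    properties: dict
    items: dict
    anyOf: list
    nullable: bool


def filter_fields(schema: Dict[str, Any], valid: Set[str], seen: Optional[set] = None) -> Dict[str, Any]:
    seen = seen if seen is not None else set()
    if id(schema) in seen:  # circular-reference guard
        return schema
    seen.add(id(schema))
    out: Dict[str, Any] = {}
    for key, value in schema.items():
        if key not in valid:  # drop keys Vertex does not understand, e.g. "title"
            continue
        if key == "properties" and isinstance(value, dict):
            out[key] = {k: filter_fields(v, valid, seen) for k, v in value.items()}
        elif key == "items" and isinstance(value, dict):
            out[key] = filter_fields(value, valid, seen)
        elif key == "anyOf" and isinstance(value, list):
            out[key] = [filter_fields(v, valid, seen) for v in value]
        else:
            out[key] = value
    return out


valid_fields = set(get_type_hints(ToySchema).keys())
pydantic_style_schema = {
    "type": "object",
    "title": "Person",  # unsupported field, removed by the filter
    "properties": {"name": {"type": "string", "title": "Name"}},
}
print(filter_fields(pydantic_style_schema, valid_fields))
# {'type': 'object', 'properties': {'name': {'type': 'string'}}}
```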

@ -1,3 +1,4 @@
import asyncio
from typing import Any, Coroutine, Optional, Union
import httpx
@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from .transformation import VertexAIFilesTransformation
from .transformation import VertexAIJsonlFilesTransformation
vertex_ai_files_transformation = VertexAIFilesTransformation()
vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
class VertexAIFilesHandler(GCSBucketBase):
@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase):
timeout=timeout,
max_retries=max_retries,
)
return None # type: ignore
else:
return asyncio.run(
self.async_create_file(
create_file_data=create_file_data,
api_base=api_base,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,
vertex_location=vertex_location,
timeout=timeout,
max_retries=max_retries,
)
)

View file
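
For the non-async path, the handler now drives the async upload with `asyncio.run(...)` instead of returning `None`. A minimal sketch of that sync-over-async bridge, with placeholder names rather than the litellm API:

```python
import asyncio


async def _async_create_file(payload: dict) -> dict:
    # Placeholder for the real async upload call.
    await asyncio.sleep(0)
    return {"id": "file-123", "payload": payload}


def create_file(payload: dict, *, _is_async: bool = False):
    if _is_async:
        # Caller is already inside an event loop; hand back the coroutine.
        return _async_create_file(payload)
    # Synchronous caller: run the coroutine to completion here.
    return asyncio.run(_async_create_file(payload))


print(create_file({"purpose": "batch"}))
```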

@ -1,7 +1,17 @@
import json
import os
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from httpx import Headers, Response
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
)
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.llms.openai import (
AllMessageValues,
CreateFileRequest,
FileTypes,
OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
PathLike,
)
from litellm.types.llms.vertex_ai import GcsBucketResponse
from litellm.types.utils import ExtractedFileData, LlmProviders
from ..common_utils import VertexAIError
from ..vertex_llm_base import VertexBase
class VertexAIFilesTransformation(VertexGeminiConfig):
class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
"""
Config for VertexAI Files
"""
def __init__(self):
self.jsonl_transformation = VertexAIJsonlFilesTransformation()
super().__init__()
@property
def custom_llm_provider(self) -> LlmProviders:
return LlmProviders.VERTEX_AI
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if not api_key:
api_key, _ = self.get_access_token(
credentials=litellm_params.get("vertex_credentials"),
project_id=litellm_params.get("vertex_project"),
)
if not api_key:
raise ValueError("api_key is required")
headers["Authorization"] = f"Bearer {api_key}"
return headers
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
"""
Helper to extract content from various OpenAI file types and return as string.
Handles:
- Direct content (str, bytes, IO[bytes])
- Tuple formats: (filename, content, [content_type], [headers])
- PathLike objects
"""
content: Union[str, bytes] = b""
# Extract file content from tuple if necessary
if isinstance(openai_file_content, tuple):
# Take the second element which is always the file content
file_content = openai_file_content[1]
else:
file_content = openai_file_content
# Handle different file content types
if isinstance(file_content, str):
# String content can be used directly
content = file_content
elif isinstance(file_content, bytes):
# Bytes content can be decoded
content = file_content
elif isinstance(file_content, PathLike): # PathLike
with open(str(file_content), "rb") as f:
content = f.read()
elif hasattr(file_content, "read"): # IO[bytes]
# File-like objects need to be read
content = file_content.read()
# Ensure content is string
if isinstance(content, bytes):
content = content.decode("utf-8")
return content
def _get_gcs_object_name_from_batch_jsonl(
self,
openai_jsonl_content: List[Dict[str, Any]],
) -> str:
"""
Gets a unique GCS object name for the VertexAI batch prediction job
named as: litellm-vertex-{model}-{uuid}
"""
_model = openai_jsonl_content[0].get("body", {}).get("model", "")
if "publishers/google/models" not in _model:
_model = f"publishers/google/models/{_model}"
object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
return object_name
def get_object_name(
self, extracted_file_data: ExtractedFileData, purpose: str
) -> str:
"""
Get the object name for the request
"""
extracted_file_data_content = extracted_file_data.get("content")
if extracted_file_data_content is None:
raise ValueError("file content is required")
if purpose == "batch":
## 1. If jsonl, check if there's a model name
file_content = self._get_content_from_openai_file(
extracted_file_data_content
)
# Split into lines and parse each line as JSON
openai_jsonl_content = [
json.loads(line) for line in file_content.splitlines() if line.strip()
]
if len(openai_jsonl_content) > 0:
return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)
## 2. If not jsonl, return the filename
filename = extracted_file_data.get("filename")
if filename:
return filename
## 3. If no file name, return timestamp
return str(int(time.time()))
def get_complete_file_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: Dict,
litellm_params: Dict,
data: CreateFileRequest,
) -> str:
"""
Get the complete url for the request
"""
bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME")
if not bucket_name:
raise ValueError("GCS bucket_name is required")
file_data = data.get("file")
purpose = data.get("purpose")
if file_data is None:
raise ValueError("file is required")
if purpose is None:
raise ValueError("purpose is required")
extracted_file_data = extract_file_data(file_data)
object_name = self.get_object_name(extracted_file_data, purpose)
endpoint = (
f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
)
api_base = api_base or "https://storage.googleapis.com"
if not api_base:
raise ValueError("api_base is required")
return f"{api_base}/{endpoint}"
def get_supported_openai_params(
self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
return []
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
return optional_params
def _map_openai_to_vertex_params(
self,
openai_request_body: Dict[str, Any],
) -> Dict[str, Any]:
"""
wrapper to call VertexGeminiConfig.map_openai_params
"""
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
config = VertexGeminiConfig()
_model = openai_request_body.get("model", "")
vertex_params = config.map_openai_params(
model=_model,
non_default_params=openai_request_body,
optional_params={},
drop_params=False,
)
return vertex_params
def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
self, openai_jsonl_content: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Transforms OpenAI JSONL content to VertexAI JSONL content
jsonl body for vertex is {"request": <request_body>}
Example Vertex jsonl
{"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
{"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
"""
vertex_jsonl_content = []
for _openai_jsonl_content in openai_jsonl_content:
openai_request_body = _openai_jsonl_content.get("body") or {}
vertex_request_body = _transform_request_body(
messages=openai_request_body.get("messages", []),
model=openai_request_body.get("model", ""),
optional_params=self._map_openai_to_vertex_params(openai_request_body),
custom_llm_provider="vertex_ai",
litellm_params={},
cached_content=None,
)
vertex_jsonl_content.append({"request": vertex_request_body})
return vertex_jsonl_content
def transform_create_file_request(
self,
model: str,
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
) -> Union[bytes, str, dict]:
"""
2 Cases:
1. Handle basic file upload
2. Handle batch file upload (.jsonl)
"""
file_data = create_file_data.get("file")
if file_data is None:
raise ValueError("file is required")
extracted_file_data = extract_file_data(file_data)
extracted_file_data_content = extracted_file_data.get("content")
if (
create_file_data.get("purpose") == "batch"
and extracted_file_data.get("content_type") == "application/jsonl"
and extracted_file_data_content is not None
):
## 1. If jsonl, check if there's a model name
file_content = self._get_content_from_openai_file(
extracted_file_data_content
)
# Split into lines and parse each line as JSON
openai_jsonl_content = [
json.loads(line) for line in file_content.splitlines() if line.strip()
]
vertex_jsonl_content = (
self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
openai_jsonl_content
)
)
return json.dumps(vertex_jsonl_content)
elif isinstance(extracted_file_data_content, bytes):
return extracted_file_data_content
else:
raise ValueError("Unsupported file content type")
def transform_create_file_response(
self,
model: Optional[str],
raw_response: Response,
logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
) -> OpenAIFileObject:
"""
Transform VertexAI File upload response into OpenAI-style FileObject
"""
response_json = raw_response.json()
try:
response_object = GcsBucketResponse(**response_json) # type: ignore
except Exception as e:
raise VertexAIError(
status_code=raw_response.status_code,
message=f"Error reading GCS bucket response: {e}",
headers=raw_response.headers,
)
gcs_id = response_object.get("id", "")
# Remove the last numeric ID from the path
gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
return OpenAIFileObject(
purpose=response_object.get("purpose", "batch"),
id=f"gs://{gcs_id}",
filename=response_object.get("name", ""),
created_at=_convert_vertex_datetime_to_openai_datetime(
vertex_datetime=response_object.get("timeCreated", "")
),
status="uploaded",
bytes=int(response_object.get("size", 0)),
object="file",
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
return VertexAIError(
status_code=status_code, message=error_message, headers=headers
)
class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
"""
Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
"""

View file
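
For batch uploads, each OpenAI JSONL line's `body` is rewritten into a Vertex `{"request": ...}` line before the file lands in GCS. A rough standalone illustration of that reshaping; the message-to-`contents` mapping is simplified here, while the real code goes through `_transform_request_body` and `VertexGeminiConfig.map_openai_params`.

```python
import json

openai_jsonl = [
    {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "gemini-2.0-flash-001",
            "messages": [{"role": "user", "content": "Describe this video."}],
        },
    }
]

vertex_jsonl = []
for line in openai_jsonl:
    body = line.get("body") or {}
    # Simplified message -> contents mapping; the litellm transformation also
    # handles files, images, audio and generationConfig parameters.
    contents = [
        {"role": m["role"], "parts": [{"text": m["content"]}]}
        for m in body.get("messages", [])
    ]
    vertex_jsonl.append({"request": {"contents": contents}})

print("\n".join(json.dumps(line) for line in vertex_jsonl))
# {"request": {"contents": [{"role": "user", "parts": [{"text": "Describe this video."}]}]}}
```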

@ -208,25 +208,24 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
elif element["type"] == "input_audio":
audio_element = cast(ChatCompletionAudioObject, element)
if audio_element["input_audio"].get("data") is not None:
_part = PartType(
inline_data=BlobType(
data=audio_element["input_audio"]["data"],
mime_type="audio/{}".format(
audio_element["input_audio"]["format"]
),
)
_part = _process_gemini_image(
image_url=audio_element["input_audio"]["data"],
format=audio_element["input_audio"].get("format"),
)
_parts.append(_part)
elif element["type"] == "file":
file_element = cast(ChatCompletionFileObject, element)
file_id = file_element["file"].get("file_id")
format = file_element["file"].get("format")
if not file_id:
continue
file_data = file_element["file"].get("file_data")
passed_file = file_id or file_data
if passed_file is None:
raise Exception(
"Unknown file type. Please pass in a file_id or file_data"
)
try:
_part = _process_gemini_image(
image_url=file_id, format=format
image_url=passed_file, format=format
)
_parts.append(_part)
except Exception:

View file
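
With this change a `file` content part may reference either a `file_id` (for example a `gs://` URI) or inline `file_data`; both now flow through `_process_gemini_image`. A hedged usage sketch, where the model name and bucket path are placeholders and valid Vertex credentials plus a readable file are assumed:

```python
from litellm import completion

response = completion(
    model="vertex_ai/gemini-2.0-flash-001",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Summarize this document."},
                {
                    "type": "file",
                    "file": {
                        # Either file_id (a gs:// or https URI) or file_data may be set.
                        "file_id": "gs://my-bucket/reports/q1.pdf",
                        "format": "application/pdf",
                    },
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)
```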

@ -240,6 +240,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
gtool_func_declarations = []
googleSearch: Optional[dict] = None
googleSearchRetrieval: Optional[dict] = None
enterpriseWebSearch: Optional[dict] = None
code_execution: Optional[dict] = None
# remove 'additionalProperties' from tools
value = _remove_additional_properties(value)
@ -273,6 +274,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
googleSearch = tool["googleSearch"]
elif tool.get("googleSearchRetrieval", None) is not None:
googleSearchRetrieval = tool["googleSearchRetrieval"]
elif tool.get("enterpriseWebSearch", None) is not None:
enterpriseWebSearch = tool["enterpriseWebSearch"]
elif tool.get("code_execution", None) is not None:
code_execution = tool["code_execution"]
elif openai_function_object is not None:
@ -299,6 +302,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
_tools["googleSearch"] = googleSearch
if googleSearchRetrieval is not None:
_tools["googleSearchRetrieval"] = googleSearchRetrieval
if enterpriseWebSearch is not None:
_tools["enterpriseWebSearch"] = enterpriseWebSearch
if code_execution is not None:
_tools["code_execution"] = code_execution
return [_tools]
@ -374,7 +379,11 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
optional_params["responseLogprobs"] = value
elif param == "top_logprobs":
optional_params["logprobs"] = value
elif (param == "tools" or param == "functions") and isinstance(value, list):
elif (
(param == "tools" or param == "functions")
and isinstance(value, list)
and value
):
optional_params["tools"] = self._map_function(value=value)
optional_params["litellm_param_is_function_call"] = (
True if param == "functions" else False
@ -739,9 +748,6 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
chat_completion_logprobs = self._transform_logprobs(
logprobs_result=candidate["logprobsResult"]
)
# Handle avgLogprobs for Gemini Flash 2.0
elif "avgLogprobs" in candidate:
chat_completion_logprobs = candidate["avgLogprobs"]
if tools:
chat_completion_message["tool_calls"] = tools
@ -896,6 +902,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
@ -1013,7 +1020,7 @@ class VertexLLM(VertexBase):
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
litellm_params: dict,
logger_fn=None,
api_base: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None,
@ -1054,6 +1061,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
## LOGGING
@ -1140,6 +1148,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
request_body = await async_transform_request_body(**data) # type: ignore
@ -1313,6 +1322,7 @@ class VertexLLM(VertexBase):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
)
## TRANSFORMATION ##

View file
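
`enterpriseWebSearch` is now forwarded to the Vertex `tools` config alongside `googleSearch` and `googleSearchRetrieval`, and the new guard skips tool mapping entirely when `tools`/`functions` is an empty list. A minimal sketch of sending the new tool from the SDK; the model name is a placeholder and a Vertex AI project with enterprise web grounding enabled is assumed:

```python
from litellm import completion

tools = [{"enterpriseWebSearch": {}}]  # forwarded as-is into the Vertex tools block

response = completion(
    model="vertex_ai/gemini-2.0-flash-001",
    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
    tools=tools,
)
print(response)
```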

@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
optional_params=optional_params,
api_key=auth_header,
api_base=api_base,
litellm_params=litellm_params,
)
## LOGGING

View file

@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -2,9 +2,10 @@ import types
from typing import Optional
import litellm
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
class VertexAIAi21Config:
class VertexAIAi21Config(OpenAIGPTConfig):
"""
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
@ -40,9 +41,6 @@ class VertexAIAi21Config:
and v is not None
}
def get_supported_openai_params(self):
return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
def map_openai_params(
self,
non_default_params: dict,

View file

@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
@ -22,7 +21,7 @@ else:
GoogleCredentialsObject = Any
class VertexBase(BaseLLM):
class VertexBase:
def __init__(self) -> None:
super().__init__()
self.access_token: Optional[str] = None

View file

@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
messages=messages,
optional_params=optional_params,
api_key=api_key,
litellm_params=litellm_params,
)
## UPDATE PAYLOAD (optional params)

View file

@ -165,6 +165,7 @@ class IBMWatsonXMixin:
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@ -3616,6 +3616,7 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "bedrock":
if isinstance(input, str):

View file

@ -380,6 +380,7 @@
"supports_tool_choice": true,
"supports_native_streaming": false,
"supported_modalities": ["text", "image"],
"supported_output_modalities": ["text"],
"supported_endpoints": ["/v1/responses", "/v1/batch"]
},
"o1-pro-2025-03-19": {
@ -401,6 +402,7 @@
"supports_tool_choice": true,
"supports_native_streaming": false,
"supported_modalities": ["text", "image"],
"supported_output_modalities": ["text"],
"supported_endpoints": ["/v1/responses", "/v1/batch"]
},
"o1": {
@ -1286,6 +1288,68 @@
"supports_system_messages": true,
"supports_tool_choice": true
},
"azure/gpt-4o-realtime-preview-2024-12-17": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000005,
"input_cost_per_audio_token": 0.00004,
"cache_read_input_token_cost": 0.0000025,
"output_cost_per_token": 0.00002,
"output_cost_per_audio_token": 0.00008,
"litellm_provider": "azure",
"mode": "chat",
"supported_modalities": ["text", "audio"],
"supported_output_modalities": ["text", "audio"],
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_audio_input": true,
"supports_audio_output": true,
"supports_system_messages": true,
"supports_tool_choice": true
},
"azure/us/gpt-4o-realtime-preview-2024-12-17": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 5.5e-6,
"input_cost_per_audio_token": 44e-6,
"cache_read_input_token_cost": 2.75e-6,
"cache_read_input_audio_token_cost": 2.5e-6,
"output_cost_per_token": 22e-6,
"output_cost_per_audio_token": 80e-6,
"litellm_provider": "azure",
"mode": "chat",
"supported_modalities": ["text", "audio"],
"supported_output_modalities": ["text", "audio"],
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_audio_input": true,
"supports_audio_output": true,
"supports_system_messages": true,
"supports_tool_choice": true
},
"azure/eu/gpt-4o-realtime-preview-2024-12-17": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 5.5e-6,
"input_cost_per_audio_token": 44e-6,
"cache_read_input_token_cost": 2.75e-6,
"cache_read_input_audio_token_cost": 2.5e-6,
"output_cost_per_token": 22e-6,
"output_cost_per_audio_token": 80e-6,
"litellm_provider": "azure",
"mode": "chat",
"supported_modalities": ["text", "audio"],
"supported_output_modalities": ["text", "audio"],
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_audio_input": true,
"supports_audio_output": true,
"supports_system_messages": true,
"supports_tool_choice": true
},
"azure/gpt-4o-realtime-preview-2024-10-01": {
"max_tokens": 4096,
"max_input_tokens": 128000,
@ -2300,6 +2364,18 @@
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview",
"supports_tool_choice": true
},
"azure_ai/mistral-large-latest": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000006,
"litellm_provider": "azure_ai",
"supports_function_calling": true,
"mode": "chat",
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview",
"supports_tool_choice": true
},
"azure_ai/ministral-3b": {
"max_tokens": 4096,
"max_input_tokens": 128000,
@ -2397,25 +2473,26 @@
"max_tokens": 4096,
"max_input_tokens": 131072,
"max_output_tokens": 4096,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"input_cost_per_token": 0.000000075,
"output_cost_per_token": 0.0000003,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_function_calling": true,
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
"source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
},
"azure_ai/Phi-4-multimodal-instruct": {
"max_tokens": 4096,
"max_input_tokens": 131072,
"max_output_tokens": 4096,
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"input_cost_per_token": 0.00000008,
"input_cost_per_audio_token": 0.000004,
"output_cost_per_token": 0.00032,
"litellm_provider": "azure_ai",
"mode": "chat",
"supports_audio_input": true,
"supports_function_calling": true,
"supports_vision": true,
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
"source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
},
"azure_ai/Phi-4": {
"max_tokens": 16384,
@ -3455,7 +3532,7 @@
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.000004,
"cache_creation_input_token_cost": 0.000001,
"cache_read_input_token_cost": 0.0000008,
"cache_read_input_token_cost": 0.00000008,
"litellm_provider": "anthropic",
"mode": "chat",
"supports_function_calling": true,
@ -4499,20 +4576,10 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_image": 0,
"input_cost_per_video_per_second": 0,
"input_cost_per_audio_per_second": 0,
"input_cost_per_token": 0,
"input_cost_per_character": 0,
"input_cost_per_token_above_128k_tokens": 0,
"input_cost_per_character_above_128k_tokens": 0,
"input_cost_per_image_above_128k_tokens": 0,
"input_cost_per_video_per_second_above_128k_tokens": 0,
"input_cost_per_audio_per_second_above_128k_tokens": 0,
"output_cost_per_token": 0,
"output_cost_per_character": 0,
"output_cost_per_token_above_128k_tokens": 0,
"output_cost_per_character_above_128k_tokens": 0,
"input_cost_per_token": 0.00000125,
"input_cost_per_token_above_200k_tokens": 0.0000025,
"output_cost_per_token": 0.00001,
"output_cost_per_token_above_200k_tokens": 0.000015,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_system_messages": true,
@ -4523,6 +4590,9 @@
"supports_pdf_input": true,
"supports_response_schema": true,
"supports_tool_choice": true,
"supported_endpoints": ["/v1/chat/completions", "/v1/completions"],
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"gemini-2.0-pro-exp-02-05": {
@ -4535,20 +4605,10 @@
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_image": 0,
"input_cost_per_video_per_second": 0,
"input_cost_per_audio_per_second": 0,
"input_cost_per_token": 0,
"input_cost_per_character": 0,
"input_cost_per_token_above_128k_tokens": 0,
"input_cost_per_character_above_128k_tokens": 0,
"input_cost_per_image_above_128k_tokens": 0,
"input_cost_per_video_per_second_above_128k_tokens": 0,
"input_cost_per_audio_per_second_above_128k_tokens": 0,
"output_cost_per_token": 0,
"output_cost_per_character": 0,
"output_cost_per_token_above_128k_tokens": 0,
"output_cost_per_character_above_128k_tokens": 0,
"input_cost_per_token": 0.00000125,
"input_cost_per_token_above_200k_tokens": 0.0000025,
"output_cost_per_token": 0.00001,
"output_cost_per_token_above_200k_tokens": 0.000015,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_system_messages": true,
@ -4559,6 +4619,9 @@
"supports_pdf_input": true,
"supports_response_schema": true,
"supports_tool_choice": true,
"supported_endpoints": ["/v1/chat/completions", "/v1/completions"],
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"gemini-2.0-flash-exp": {
@ -4592,6 +4655,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing",
"supports_tool_choice": true
},
@ -4616,6 +4681,8 @@
"supports_response_schema": true,
"supports_audio_output": true,
"supports_tool_choice": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"gemini-2.0-flash-thinking-exp": {
@ -4649,6 +4716,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@ -4683,6 +4752,8 @@
"supports_vision": true,
"supports_response_schema": false,
"supports_audio_output": false,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@ -4708,6 +4779,7 @@
"supports_audio_output": true,
"supports_audio_input": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"supports_tool_choice": true,
"source": "https://ai.google.dev/pricing#2_0flash"
},
@ -4730,6 +4802,32 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
"gemini-2.0-flash-lite-001": {
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 50,
"input_cost_per_audio_token": 0.000000075,
"input_cost_per_token": 0.000000075,
"output_cost_per_token": 0.0000003,
"litellm_provider": "vertex_ai-language-models",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@ -4795,6 +4893,7 @@
"supports_audio_output": true,
"supports_audio_input": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"supports_tool_choice": true,
"source": "https://ai.google.dev/pricing#2_0flash"
},
@ -4820,6 +4919,8 @@
"supports_response_schema": true,
"supports_audio_output": true,
"supports_tool_choice": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.0-flash-lite"
},
"gemini/gemini-2.0-flash-001": {
@ -4845,6 +4946,8 @@
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"source": "https://ai.google.dev/pricing#2_0flash"
},
"gemini/gemini-2.5-pro-preview-03-25": {
@ -4859,9 +4962,9 @@
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 0.0000007,
"input_cost_per_token": 0.00000125,
"input_cost_per_token_above_128k_tokens": 0.0000025,
"input_cost_per_token_above_200k_tokens": 0.0000025,
"output_cost_per_token": 0.0000010,
"output_cost_per_token_above_128k_tokens": 0.000015,
"output_cost_per_token_above_200k_tokens": 0.000015,
"litellm_provider": "gemini",
"mode": "chat",
"rpm": 10000,
@ -4872,6 +4975,8 @@
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro-preview"
},
"gemini/gemini-2.0-flash-exp": {
@ -4907,6 +5012,8 @@
"supports_audio_output": true,
"tpm": 4000000,
"rpm": 10,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@ -4933,6 +5040,8 @@
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash-lite"
},
"gemini/gemini-2.0-flash-thinking-exp": {
@ -4968,6 +5077,8 @@
"supports_audio_output": true,
"tpm": 4000000,
"rpm": 10,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@ -5004,6 +5115,8 @@
"supports_audio_output": true,
"tpm": 4000000,
"rpm": 10,
"supported_modalities": ["text", "image", "audio", "video"],
"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@ -8444,7 +8557,8 @@
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.0000020,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_tool_choice": true
},
"bedrock/*/1-month-commitment/cohere.command-text-v14": {
"max_tokens": 4096,
@ -8453,7 +8567,8 @@
"input_cost_per_second": 0.011,
"output_cost_per_second": 0.011,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_tool_choice": true
},
"bedrock/*/6-month-commitment/cohere.command-text-v14": {
"max_tokens": 4096,
@ -8462,7 +8577,8 @@
"input_cost_per_second": 0.0066027,
"output_cost_per_second": 0.0066027,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_tool_choice": true
},
"cohere.command-light-text-v14": {
"max_tokens": 4096,
@ -8471,7 +8587,8 @@
"input_cost_per_token": 0.0000003,
"output_cost_per_token": 0.0000006,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_tool_choice": true
},
"bedrock/*/1-month-commitment/cohere.command-light-text-v14": {
"max_tokens": 4096,
@ -8480,7 +8597,8 @@
"input_cost_per_second": 0.001902,
"output_cost_per_second": 0.001902,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_tool_choice": true
},
"bedrock/*/6-month-commitment/cohere.command-light-text-v14": {
"max_tokens": 4096,
@ -8489,7 +8607,8 @@
"input_cost_per_second": 0.0011416,
"output_cost_per_second": 0.0011416,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_tool_choice": true
},
"cohere.command-r-plus-v1:0": {
"max_tokens": 4096,
@ -8498,7 +8617,8 @@
"input_cost_per_token": 0.0000030,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_tool_choice": true
},
"cohere.command-r-v1:0": {
"max_tokens": 4096,
@ -8507,7 +8627,8 @@
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015,
"litellm_provider": "bedrock",
"mode": "chat"
"mode": "chat",
"supports_tool_choice": true
},
"cohere.embed-english-v3": {
"max_tokens": 512,

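The new `*_above_200k_tokens` fields replace the older 128k tiers for these Gemini entries: the higher per-token rate kicks in once a request crosses the 200k threshold. A rough worked sketch using the numbers from the `gemini/gemini-2.5-pro-preview-03-25` entry above; the tiering logic is simplified and litellm's own cost calculator remains the source of truth.

```python
def tiered_cost(tokens: int, base_rate: float, above_rate: float, threshold: int = 200_000) -> float:
    # Simplified: apply the higher rate to the whole count once the threshold
    # is crossed, mirroring how the tiered fields are keyed in the JSON.
    rate = above_rate if tokens > threshold else base_rate
    return tokens * rate


input_tokens, output_tokens = 250_000, 2_000
prompt_cost = tiered_cost(input_tokens, 1.25e-6, 2.5e-6)      # input_cost_per_token / _above_200k_tokens
completion_cost = tiered_cost(output_tokens, 1.0e-6, 1.5e-5)  # output_cost_per_token / _above_200k_tokens
print(f"${prompt_cost + completion_cost:.4f}")  # $0.6270
```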
Some files were not shown because too many files have changed in this diff.