Merge remote-tracking branch 'origin/main' into feat--parse-user-from-headers
|
@ -610,6 +610,8 @@ jobs:
|
|||
name: Install Dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install --upgrade pip wheel setuptools
|
||||
python -m pip install -r requirements.txt
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "respx==0.21.1"
|
||||
|
@ -1125,6 +1127,7 @@ jobs:
|
|||
name: Install Dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install wheel setuptools
|
||||
python -m pip install -r requirements.txt
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-retry==1.6.3"
|
||||
|
|
|
@ -12,8 +12,7 @@ WORKDIR /app
|
|||
USER root
|
||||
|
||||
# Install build dependencies
|
||||
RUN apk update && \
|
||||
apk add --no-cache gcc python3-dev openssl openssl-dev
|
||||
RUN apk add --no-cache gcc python3-dev openssl openssl-dev
|
||||
|
||||
|
||||
RUN pip install --upgrade pip && \
|
||||
|
@ -52,8 +51,7 @@ FROM $LITELLM_RUNTIME_IMAGE AS runtime
|
|||
USER root
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk update && \
|
||||
apk add --no-cache openssl
|
||||
RUN apk add --no-cache openssl
|
||||
|
||||
WORKDIR /app
|
||||
# Copy the current directory contents into the container at /app
|
||||
|
|
|
@ -2,6 +2,10 @@ apiVersion: v1
|
|||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "litellm.fullname" . }}
|
||||
{{- with .Values.service.annotations }}
|
||||
annotations:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
labels:
|
||||
{{- include "litellm.labels" . | nindent 4 }}
|
||||
spec:
|
||||
|
|
|
@ -35,7 +35,7 @@ RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
|
|||
FROM $LITELLM_RUNTIME_IMAGE AS runtime
|
||||
|
||||
# Update dependencies and clean up
|
||||
RUN apk update && apk upgrade && rm -rf /var/cache/apk/*
|
||||
RUN apk upgrade --no-cache
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
|
|
@ -12,8 +12,7 @@ WORKDIR /app
|
|||
USER root
|
||||
|
||||
# Install build dependencies
|
||||
RUN apk update && \
|
||||
apk add --no-cache gcc python3-dev openssl openssl-dev
|
||||
RUN apk add --no-cache gcc python3-dev openssl openssl-dev
|
||||
|
||||
|
||||
RUN pip install --upgrade pip && \
|
||||
|
@ -44,8 +43,7 @@ FROM $LITELLM_RUNTIME_IMAGE AS runtime
|
|||
USER root
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk update && \
|
||||
apk add --no-cache openssl
|
||||
RUN apk add --no-cache openssl
|
||||
|
||||
WORKDIR /app
|
||||
# Copy the current directory contents into the container at /app
|
||||
|
|
|
@ -438,6 +438,179 @@ assert isinstance(
|
|||
```
|
||||
|
||||
|
||||
### Google Search Tool
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ["GEMINI_API_KEY"] = ".."
|
||||
|
||||
tools = [{"googleSearch": {}}] # 👈 ADD GOOGLE SEARCH
|
||||
|
||||
response = completion(
|
||||
model="gemini/gemini-2.0-flash",
|
||||
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Setup config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gemini-2.0-flash
|
||||
litellm_params:
|
||||
model: gemini/gemini-2.0-flash
|
||||
api_key: os.environ/GEMINI_API_KEY
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Make Request!
|
||||
```bash
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-d '{
|
||||
"model": "gemini-2.0-flash",
|
||||
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||
"tools": [{"googleSearch": {}}]
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### Google Search Retrieval
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ["GEMINI_API_KEY"] = ".."
|
||||
|
||||
tools = [{"googleSearchRetrieval": {}}] # 👈 ADD GOOGLE SEARCH
|
||||
|
||||
response = completion(
|
||||
model="gemini/gemini-2.0-flash",
|
||||
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Setup config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gemini-2.0-flash
|
||||
litellm_params:
|
||||
model: gemini/gemini-2.0-flash
|
||||
api_key: os.environ/GEMINI_API_KEY
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Make Request!
|
||||
```bash
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-d '{
|
||||
"model": "gemini-2.0-flash",
|
||||
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||
"tools": [{"googleSearchRetrieval": {}}]
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
### Code Execution Tool
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ["GEMINI_API_KEY"] = ".."
|
||||
|
||||
tools = [{"codeExecution": {}}] # 👈 ADD GOOGLE SEARCH
|
||||
|
||||
response = completion(
|
||||
model="gemini/gemini-2.0-flash",
|
||||
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Setup config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gemini-2.0-flash
|
||||
litellm_params:
|
||||
model: gemini/gemini-2.0-flash
|
||||
api_key: os.environ/GEMINI_API_KEY
|
||||
```
|
||||
|
||||
2. Start Proxy
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Make Request!
|
||||
```bash
|
||||
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-d '{
|
||||
"model": "gemini-2.0-flash",
|
||||
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||
"tools": [{"codeExecution": {}}]
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## JSON Mode
|
||||
|
||||
<Tabs>
|
||||
|
|
|
@ -398,6 +398,8 @@ curl http://localhost:4000/v1/chat/completions \
|
|||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
You can also use the `enterpriseWebSearch` tool for an [enterprise compliant search](https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise).
|
||||
|
||||
#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**
|
||||
|
||||
|
||||
|
|
|
@ -449,6 +449,7 @@ router_settings:
|
|||
| MICROSOFT_CLIENT_ID | Client ID for Microsoft services
|
||||
| MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
|
||||
| MICROSOFT_TENANT | Tenant ID for Microsoft Azure
|
||||
| MICROSOFT_SERVICE_PRINCIPAL_ID | Service Principal ID for Microsoft Enterprise Application. (This is an advanced feature if you want litellm to auto-assign members to Litellm Teams based on their Microsoft Entra ID Groups)
|
||||
| NO_DOCS | Flag to disable documentation generation
|
||||
| NO_PROXY | List of addresses to bypass proxy
|
||||
| OAUTH_TOKEN_INFO_ENDPOINT | Endpoint for OAuth token info retrieval
|
||||
|
|
|
@ -26,10 +26,12 @@ model_list:
|
|||
- model_name: sagemaker-completion-model
|
||||
litellm_params:
|
||||
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
|
||||
model_info:
|
||||
input_cost_per_second: 0.000420
|
||||
- model_name: sagemaker-embedding-model
|
||||
litellm_params:
|
||||
model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
|
||||
model_info:
|
||||
input_cost_per_second: 0.000420
|
||||
```
|
||||
|
||||
|
@ -55,11 +57,33 @@ model_list:
|
|||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_version: os.envrion/AZURE_API_VERSION
|
||||
model_info:
|
||||
input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
|
||||
output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
|
||||
```
|
||||
|
||||
### Debugging
|
||||
## Override Model Cost Map
|
||||
|
||||
You can override [our model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) with your own custom pricing for a mapped model.
|
||||
|
||||
Just add a `model_info` key to your model in the config, and override the desired keys.
|
||||
|
||||
Example: Override Anthropic's model cost map for the `prod/claude-3-5-sonnet-20241022` model.
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "prod/claude-3-5-sonnet-20241022"
|
||||
litellm_params:
|
||||
model: "anthropic/claude-3-5-sonnet-20241022"
|
||||
api_key: os.environ/ANTHROPIC_PROD_API_KEY
|
||||
model_info:
|
||||
input_cost_per_token: 0.000006
|
||||
output_cost_per_token: 0.00003
|
||||
cache_creation_input_token_cost: 0.0000075
|
||||
cache_read_input_token_cost: 0.0000006
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
If you're custom pricing is not being used or you're seeing errors, please check the following:
|
||||
|
||||
|
|
|
@ -161,6 +161,83 @@ Here's the available UI roles for a LiteLLM Internal User:
|
|||
- `internal_user`: can login, view/create/delete their own keys, view their spend. **Cannot** add new users.
|
||||
- `internal_user_viewer`: can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users.
|
||||
|
||||
## Auto-add SSO users to teams
|
||||
|
||||
This walks through setting up sso auto-add for **Okta, Google SSO**
|
||||
|
||||
### Okta, Google SSO
|
||||
|
||||
1. Specify the JWT field that contains the team ids, that the user belongs to.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
litellm_jwtauth:
|
||||
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
||||
```
|
||||
|
||||
This is assuming your SSO token looks like this. **If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions [here](#debugging-sso-jwt-fields)**
|
||||
|
||||
```
|
||||
{
|
||||
...,
|
||||
"groups": ["team_id_1", "team_id_2"]
|
||||
}
|
||||
```
|
||||
|
||||
2. Create the teams on LiteLLM
|
||||
|
||||
```bash
|
||||
curl -X POST '<PROXY_BASE_URL>/team/new' \
|
||||
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-D '{
|
||||
"team_alias": "team_1",
|
||||
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
|
||||
}'
|
||||
```
|
||||
|
||||
3. Test the SSO flow
|
||||
|
||||
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
|
||||
|
||||
### Microsoft Entra ID SSO group assignment
|
||||
|
||||
Follow this [tutorial for auto-adding sso users to teams with Microsoft Entra ID](https://docs.litellm.ai/docs/tutorials/msft_sso)
|
||||
|
||||
### Debugging SSO JWT fields
|
||||
|
||||
If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions. This guide walks you through setting up a debug callback to view the JWT data during the SSO process.
|
||||
|
||||
|
||||
<Image img={require('../../img/debug_sso.png')} style={{ width: '500px', height: 'auto' }} />
|
||||
<br />
|
||||
|
||||
1. Add `/sso/debug/callback` as a redirect URL in your SSO provider
|
||||
|
||||
In your SSO provider's settings, add the following URL as a new redirect (callback) URL:
|
||||
|
||||
```bash showLineNumbers title="Redirect URL"
|
||||
http://<proxy_base_url>/sso/debug/callback
|
||||
```
|
||||
|
||||
|
||||
2. Navigate to the debug login page on your browser
|
||||
|
||||
Navigate to the following URL on your browser:
|
||||
|
||||
```bash showLineNumbers title="URL to navigate to"
|
||||
https://<proxy_base_url>/sso/debug/login
|
||||
```
|
||||
|
||||
This will initiate the standard SSO flow. You will be redirected to your SSO provider's login screen, and after successful authentication, you will be redirected back to LiteLLM's debug callback route.
|
||||
|
||||
|
||||
3. View the JWT fields
|
||||
|
||||
Once redirected, you should see a page called "SSO Debug Information". This page displays the JWT fields received from your SSO provider (as shown in the image above)
|
||||
|
||||
|
||||
## Advanced
|
||||
### Setting custom logout URLs
|
||||
|
||||
|
@ -196,40 +273,26 @@ This budget does not apply to keys created under non-default teams.
|
|||
|
||||
[**Go Here**](./team_budgets.md)
|
||||
|
||||
### Auto-add SSO users to teams
|
||||
### Set default params for new teams
|
||||
|
||||
1. Specify the JWT field that contains the team ids, that the user belongs to.
|
||||
When you connect litellm to your SSO provider, litellm can auto-create teams. Use this to set the default `models`, `max_budget`, `budget_duration` for these auto-created teams.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
litellm_jwtauth:
|
||||
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
||||
**How it works**
|
||||
|
||||
1. When litellm fetches `groups` from your SSO provider, it will check if the corresponding group_id exists as a `team_id` in litellm.
|
||||
2. If the team_id does not exist, litellm will auto-create a team with the default params you've set.
|
||||
3. If the team_id already exist, litellm will not apply any settings on the team.
|
||||
|
||||
**Usage**
|
||||
|
||||
```yaml showLineNumbers title="Default Params for new teams"
|
||||
litellm_settings:
|
||||
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
|
||||
max_budget: 100 # Optional[float], optional): $100 budget for the team
|
||||
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
|
||||
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
|
||||
```
|
||||
|
||||
This is assuming your SSO token looks like this:
|
||||
```
|
||||
{
|
||||
...,
|
||||
"groups": ["team_id_1", "team_id_2"]
|
||||
}
|
||||
```
|
||||
|
||||
2. Create the teams on LiteLLM
|
||||
|
||||
```bash
|
||||
curl -X POST '<PROXY_BASE_URL>/team/new' \
|
||||
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-D '{
|
||||
"team_alias": "team_1",
|
||||
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
|
||||
}'
|
||||
```
|
||||
|
||||
3. Test the SSO flow
|
||||
|
||||
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
|
||||
|
||||
### Restrict Users from creating personal keys
|
||||
|
||||
|
@ -241,7 +304,7 @@ This will also prevent users from using their session tokens on the test keys ch
|
|||
|
||||
## **All Settings for Self Serve / SSO Flow**
|
||||
|
||||
```yaml
|
||||
```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow"
|
||||
litellm_settings:
|
||||
max_internal_user_budget: 10 # max budget for internal users
|
||||
internal_user_budget_duration: "1mo" # reset every month
|
||||
|
@ -251,6 +314,11 @@ litellm_settings:
|
|||
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
|
||||
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
|
||||
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
|
||||
|
||||
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
|
||||
max_budget: 100 # Optional[float], optional): $100 budget for the team
|
||||
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
|
||||
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
|
||||
|
||||
|
||||
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
|
||||
|
|
162
docs/my-website/docs/tutorials/msft_sso.md
Normal file
|
@ -0,0 +1,162 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
|
||||
# Microsoft SSO: Sync Groups, Members with LiteLLM
|
||||
|
||||
Sync Microsoft SSO Groups, Members with LiteLLM Teams.
|
||||
|
||||
<Image img={require('../../img/litellm_entra_id.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
<br />
|
||||
<br />
|
||||
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- An Azure Entra ID account with administrative access
|
||||
- A LiteLLM Enterprise App set up in your Azure Portal
|
||||
- Access to Microsoft Entra ID (Azure AD)
|
||||
|
||||
|
||||
## Overview of this tutorial
|
||||
|
||||
1. Auto-Create Entra ID Groups on LiteLLM Teams
|
||||
2. Sync Entra ID Team Memberships
|
||||
3. Set default params for new teams and users auto-created on LiteLLM
|
||||
|
||||
## 1. Auto-Create Entra ID Groups on LiteLLM Teams
|
||||
|
||||
In this step, our goal is to have LiteLLM automatically create a new team on the LiteLLM DB when there is a new Group Added to the LiteLLM Enterprise App on Azure Entra ID.
|
||||
|
||||
### 1.1 Create a new group in Entra ID
|
||||
|
||||
|
||||
Navigate to [your Azure Portal](https://portal.azure.com/) > Groups > New Group. Create a new group.
|
||||
|
||||
<Image img={require('../../img/entra_create_team.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
### 1.2 Assign the group to your LiteLLM Enterprise App
|
||||
|
||||
On your Azure Portal, navigate to `Enterprise Applications` > Select your litellm app
|
||||
|
||||
<Image img={require('../../img/msft_enterprise_app.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
<br />
|
||||
<br />
|
||||
|
||||
Once you've selected your litellm app, click on `Users and Groups` > `Add user/group`
|
||||
|
||||
<Image img={require('../../img/msft_enterprise_assign_group.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
<br />
|
||||
|
||||
Now select the group you created in step 1.1. And add it to the LiteLLM Enterprise App. At this point we have added `Production LLM Evals Group` to the LiteLLM Enterprise App. The next steps is having LiteLLM automatically create the `Production LLM Evals Group` on the LiteLLM DB when a new user signs in.
|
||||
|
||||
<Image img={require('../../img/msft_enterprise_select_group.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
|
||||
### 1.3 Sign in to LiteLLM UI via SSO
|
||||
|
||||
Sign into the LiteLLM UI via SSO. You should be redirected to the Entra ID SSO page. This SSO sign in flow will trigger LiteLLM to fetch the latest Groups and Members from Azure Entra ID.
|
||||
|
||||
<Image img={require('../../img/msft_sso_sign_in.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
### 1.4 Check the new team on LiteLLM UI
|
||||
|
||||
On the LiteLLM UI, Navigate to `Teams`, You should see the new team `Production LLM Evals Group` auto-created on LiteLLM.
|
||||
|
||||
<Image img={require('../../img/msft_auto_team.png')} style={{ width: '900px', height: 'auto' }} />
|
||||
|
||||
#### How this works
|
||||
|
||||
When a SSO user signs in to LiteLLM:
|
||||
- LiteLLM automatically fetches the Groups under the LiteLLM Enterprise App
|
||||
- It finds the Production LLM Evals Group assigned to the LiteLLM Enterprise App
|
||||
- LiteLLM checks if this group's ID exists in the LiteLLM Teams Table
|
||||
- Since the ID doesn't exist, LiteLLM automatically creates a new team with:
|
||||
- Name: Production LLM Evals Group
|
||||
- ID: Same as the Entra ID group's ID
|
||||
|
||||
## 2. Sync Entra ID Team Memberships
|
||||
|
||||
In this step, we will have LiteLLM automatically add a user to the `Production LLM Evals` Team on the LiteLLM DB when a new user is added to the `Production LLM Evals` Group in Entra ID.
|
||||
|
||||
### 2.1 Navigate to the `Production LLM Evals` Group in Entra ID
|
||||
|
||||
Navigate to the `Production LLM Evals` Group in Entra ID.
|
||||
|
||||
<Image img={require('../../img/msft_member_1.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
|
||||
### 2.2 Add a member to the group in Entra ID
|
||||
|
||||
Select `Members` > `Add members`
|
||||
|
||||
In this stage you should add the user you want to add to the `Production LLM Evals` Team.
|
||||
|
||||
<Image img={require('../../img/msft_member_2.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
|
||||
|
||||
### 2.3 Sign in as the new user on LiteLLM UI
|
||||
|
||||
Sign in as the new user on LiteLLM UI. You should be redirected to the Entra ID SSO page. This SSO sign in flow will trigger LiteLLM to fetch the latest Groups and Members from Azure Entra ID. During this step LiteLLM sync it's teams, team members with what is available from Entra ID
|
||||
|
||||
<Image img={require('../../img/msft_sso_sign_in.png')} style={{ width: '800px', height: 'auto' }} />
|
||||
|
||||
|
||||
|
||||
### 2.4 Check the team membership on LiteLLM UI
|
||||
|
||||
On the LiteLLM UI, Navigate to `Teams`, You should see the new team `Production LLM Evals Group`. Since your are now a member of the `Production LLM Evals Group` in Entra ID, you should see the new team `Production LLM Evals Group` on the LiteLLM UI.
|
||||
|
||||
<Image img={require('../../img/msft_member_3.png')} style={{ width: '900px', height: 'auto' }} />
|
||||
|
||||
## 3. Set default params for new teams auto-created on LiteLLM
|
||||
|
||||
Since litellm auto creates a new team on the LiteLLM DB when there is a new Group Added to the LiteLLM Enterprise App on Azure Entra ID, we can set default params for new teams created.
|
||||
|
||||
This allows you to set a default budget, models, etc for new teams created.
|
||||
|
||||
### 3.1 Set `default_team_params` on litellm
|
||||
|
||||
Navigate to your litellm config file and set the following params
|
||||
|
||||
```yaml showLineNumbers title="litellm config with default_team_params"
|
||||
litellm_settings:
|
||||
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
|
||||
max_budget: 100 # Optional[float], optional): $100 budget for the team
|
||||
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
|
||||
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
|
||||
```
|
||||
|
||||
### 3.2 Auto-create a new team on LiteLLM
|
||||
|
||||
- In this step you should add a new group to the LiteLLM Enterprise App on Azure Entra ID (like we did in step 1.1). We will call this group `Default LiteLLM Prod Team` on Azure Entra ID.
|
||||
- Start litellm proxy server with your config
|
||||
- Sign into LiteLLM UI via SSO
|
||||
- Navigate to `Teams` and you should see the new team `Default LiteLLM Prod Team` auto-created on LiteLLM
|
||||
- Note LiteLLM will set the default params for this new team.
|
||||
|
||||
<Image img={require('../../img/msft_default_settings.png')} style={{ width: '900px', height: 'auto' }} />
|
||||
|
||||
|
||||
## Video Walkthrough
|
||||
|
||||
This walks through setting up sso auto-add for **Microsoft Entra ID**
|
||||
|
||||
Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID
|
||||
|
||||
<iframe width="840" height="500" src="https://www.loom.com/embed/ea711323aa9a496d84a01fd7b2a12f54?sid=c53e238c-5bfd-4135-b8fb-b5b1a08632cf" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
BIN
docs/my-website/img/debug_sso.png
Normal file
After Width: | Height: | Size: 167 KiB |
BIN
docs/my-website/img/entra_create_team.png
Normal file
After Width: | Height: | Size: 180 KiB |
BIN
docs/my-website/img/litellm_entra_id.png
Normal file
After Width: | Height: | Size: 35 KiB |
BIN
docs/my-website/img/msft_auto_team.png
Normal file
After Width: | Height: | Size: 62 KiB |
BIN
docs/my-website/img/msft_default_settings.png
Normal file
After Width: | Height: | Size: 141 KiB |
BIN
docs/my-website/img/msft_enterprise_app.png
Normal file
After Width: | Height: | Size: 292 KiB |
BIN
docs/my-website/img/msft_enterprise_assign_group.png
Normal file
After Width: | Height: | Size: 277 KiB |
BIN
docs/my-website/img/msft_enterprise_select_group.png
Normal file
After Width: | Height: | Size: 245 KiB |
BIN
docs/my-website/img/msft_member_1.png
Normal file
After Width: | Height: | Size: 296 KiB |
BIN
docs/my-website/img/msft_member_2.png
Normal file
After Width: | Height: | Size: 274 KiB |
BIN
docs/my-website/img/msft_member_3.png
Normal file
After Width: | Height: | Size: 186 KiB |
BIN
docs/my-website/img/msft_sso_sign_in.png
Normal file
After Width: | Height: | Size: 818 KiB |
|
@ -435,6 +435,7 @@ const sidebars = {
|
|||
label: "Tutorials",
|
||||
items: [
|
||||
"tutorials/openweb_ui",
|
||||
"tutorials/msft_sso",
|
||||
'tutorials/litellm_proxy_aporia',
|
||||
{
|
||||
type: "category",
|
||||
|
|
161
docs/my-website/src/components/TransformRequestPlayground.tsx
Normal file
|
@ -0,0 +1,161 @@
|
|||
import React, { useState } from 'react';
|
||||
import styles from './transform_request.module.css';
|
||||
|
||||
const DEFAULT_REQUEST = {
|
||||
"model": "bedrock/gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Explain quantum computing in simple terms"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
"stream": true
|
||||
};
|
||||
|
||||
type ViewMode = 'split' | 'request' | 'transformed';
|
||||
|
||||
const TransformRequestPlayground: React.FC = () => {
|
||||
const [request, setRequest] = useState(JSON.stringify(DEFAULT_REQUEST, null, 2));
|
||||
const [transformedRequest, setTransformedRequest] = useState('');
|
||||
const [viewMode, setViewMode] = useState<ViewMode>('split');
|
||||
|
||||
const handleTransform = async () => {
|
||||
try {
|
||||
// Here you would make the actual API call to transform the request
|
||||
// For now, we'll just set a sample response
|
||||
const sampleResponse = `curl -X POST \\
|
||||
https://api.openai.com/v1/chat/completions \\
|
||||
-H 'Authorization: Bearer sk-xxx' \\
|
||||
-H 'Content-Type: application/json' \\
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}'`;
|
||||
setTransformedRequest(sampleResponse);
|
||||
} catch (error) {
|
||||
console.error('Error transforming request:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopy = () => {
|
||||
navigator.clipboard.writeText(transformedRequest);
|
||||
};
|
||||
|
||||
const renderContent = () => {
|
||||
switch (viewMode) {
|
||||
case 'request':
|
||||
return (
|
||||
<div className={styles.panel}>
|
||||
<div className={styles['panel-header']}>
|
||||
<h2>Original Request</h2>
|
||||
<p>The request you would send to LiteLLM /chat/completions endpoint.</p>
|
||||
</div>
|
||||
<textarea
|
||||
className={styles['code-input']}
|
||||
value={request}
|
||||
onChange={(e) => setRequest(e.target.value)}
|
||||
spellCheck={false}
|
||||
/>
|
||||
<div className={styles['panel-footer']}>
|
||||
<button className={styles['transform-button']} onClick={handleTransform}>
|
||||
Transform →
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
case 'transformed':
|
||||
return (
|
||||
<div className={styles.panel}>
|
||||
<div className={styles['panel-header']}>
|
||||
<h2>Transformed Request</h2>
|
||||
<p>How LiteLLM transforms your request for the specified provider.</p>
|
||||
<p className={styles.note}>Note: Sensitive headers are not shown.</p>
|
||||
</div>
|
||||
<div className={styles['code-output-container']}>
|
||||
<pre className={styles['code-output']}>{transformedRequest}</pre>
|
||||
<button className={styles['copy-button']} onClick={handleCopy}>
|
||||
Copy
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
default:
|
||||
return (
|
||||
<>
|
||||
<div className={styles.panel}>
|
||||
<div className={styles['panel-header']}>
|
||||
<h2>Original Request</h2>
|
||||
<p>The request you would send to LiteLLM /chat/completions endpoint.</p>
|
||||
</div>
|
||||
<textarea
|
||||
className={styles['code-input']}
|
||||
value={request}
|
||||
onChange={(e) => setRequest(e.target.value)}
|
||||
spellCheck={false}
|
||||
/>
|
||||
<div className={styles['panel-footer']}>
|
||||
<button className={styles['transform-button']} onClick={handleTransform}>
|
||||
Transform →
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className={styles.panel}>
|
||||
<div className={styles['panel-header']}>
|
||||
<h2>Transformed Request</h2>
|
||||
<p>How LiteLLM transforms your request for the specified provider.</p>
|
||||
<p className={styles.note}>Note: Sensitive headers are not shown.</p>
|
||||
</div>
|
||||
<div className={styles['code-output-container']}>
|
||||
<pre className={styles['code-output']}>{transformedRequest}</pre>
|
||||
<button className={styles['copy-button']} onClick={handleCopy}>
|
||||
Copy
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={styles['transform-playground']}>
|
||||
<div className={styles['view-toggle']}>
|
||||
<button
|
||||
className={viewMode === 'split' ? styles.active : ''}
|
||||
onClick={() => setViewMode('split')}
|
||||
>
|
||||
Split View
|
||||
</button>
|
||||
<button
|
||||
className={viewMode === 'request' ? styles.active : ''}
|
||||
onClick={() => setViewMode('request')}
|
||||
>
|
||||
Request
|
||||
</button>
|
||||
<button
|
||||
className={viewMode === 'transformed' ? styles.active : ''}
|
||||
onClick={() => setViewMode('transformed')}
|
||||
>
|
||||
Transformed
|
||||
</button>
|
||||
</div>
|
||||
<div className={styles['playground-container']}>
|
||||
{renderContent()}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default TransformRequestPlayground;
|
|
@ -65,6 +65,7 @@ from litellm.proxy._types import (
|
|||
KeyManagementSystem,
|
||||
KeyManagementSettings,
|
||||
LiteLLM_UpperboundKeyGenerateParams,
|
||||
NewTeamRequest,
|
||||
)
|
||||
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
|
|||
require_auth_for_metrics_endpoint: Optional[bool] = False
|
||||
argilla_batch_size: Optional[int] = None
|
||||
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
||||
gcs_pub_sub_use_v1: Optional[
|
||||
bool
|
||||
] = False # if you want to use v1 gcs pubsub logged payload
|
||||
gcs_pub_sub_use_v1: Optional[bool] = (
|
||||
False # if you want to use v1 gcs pubsub logged payload
|
||||
)
|
||||
argilla_transformation_object: Optional[Dict[str, Any]] = None
|
||||
_async_input_callback: List[
|
||||
Union[str, Callable, CustomLogger]
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[
|
||||
Union[str, Callable, CustomLogger]
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_failure_callback: List[
|
||||
Union[str, Callable, CustomLogger]
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||
[]
|
||||
) # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||
[]
|
||||
) # internal variable - async custom callbacks are routed here.
|
||||
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||
[]
|
||||
) # internal variable - async custom callbacks are routed here.
|
||||
pre_call_rules: List[Callable] = []
|
||||
post_call_rules: List[Callable] = []
|
||||
turn_off_message_logging: Optional[bool] = False
|
||||
|
@ -146,18 +147,18 @@ log_raw_request_response: bool = False
|
|||
redact_messages_in_exceptions: Optional[bool] = False
|
||||
redact_user_api_key_info: Optional[bool] = False
|
||||
filter_invalid_headers: Optional[bool] = False
|
||||
add_user_information_to_llm_headers: Optional[
|
||||
bool
|
||||
] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
||||
add_user_information_to_llm_headers: Optional[bool] = (
|
||||
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
||||
)
|
||||
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
||||
### end of callbacks #############
|
||||
|
||||
email: Optional[
|
||||
str
|
||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
token: Optional[
|
||||
str
|
||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
email: Optional[str] = (
|
||||
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
token: Optional[str] = (
|
||||
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
telemetry = True
|
||||
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
|
||||
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
||||
|
@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
|
|||
enable_caching_on_provider_specific_optional_params: bool = (
|
||||
False # feature-flag for caching on optional params - e.g. 'top_k'
|
||||
)
|
||||
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
cache: Optional[
|
||||
Cache
|
||||
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||
caching: bool = (
|
||||
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
caching_with_models: bool = (
|
||||
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
cache: Optional[Cache] = (
|
||||
None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||
)
|
||||
default_in_memory_ttl: Optional[float] = None
|
||||
default_redis_ttl: Optional[float] = None
|
||||
default_redis_batch_cache_expiry: Optional[float] = None
|
||||
model_alias_map: Dict[str, str] = {}
|
||||
model_group_alias_map: Dict[str, str] = {}
|
||||
max_budget: float = 0.0 # set the max budget across all providers
|
||||
budget_duration: Optional[
|
||||
str
|
||||
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||
budget_duration: Optional[str] = (
|
||||
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||
)
|
||||
default_soft_budget: float = (
|
||||
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
|
||||
)
|
||||
|
@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False
|
|||
|
||||
_current_cost = 0.0 # private variable, used if max budget is set
|
||||
error_logs: Dict = {}
|
||||
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
||||
add_function_to_prompt: bool = (
|
||||
False # if function calling not supported by api, append function call details to system prompt
|
||||
)
|
||||
client_session: Optional[httpx.Client] = None
|
||||
aclient_session: Optional[httpx.AsyncClient] = None
|
||||
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
||||
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
model_cost_map_url: str = (
|
||||
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
)
|
||||
suppress_debug_info = False
|
||||
dynamodb_table_name: Optional[str] = None
|
||||
s3_callback_params: Optional[Dict] = None
|
||||
|
@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None
|
|||
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
||||
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
|
||||
default_internal_user_params: Optional[Dict] = None
|
||||
default_team_params: Optional[Union[NewTeamRequest, Dict]] = None
|
||||
default_team_settings: Optional[List] = None
|
||||
max_user_budget: Optional[float] = None
|
||||
default_max_internal_user_budget: Optional[float] = None
|
||||
|
@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
|
|||
custom_prometheus_metadata_labels: List[str] = []
|
||||
#### REQUEST PRIORITIZATION ####
|
||||
priority_reservation: Optional[Dict[str, float]] = None
|
||||
force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
||||
force_ipv4: bool = (
|
||||
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
||||
)
|
||||
module_level_aclient = AsyncHTTPHandler(
|
||||
timeout=request_timeout, client_alias="module level aclient"
|
||||
)
|
||||
|
@ -295,13 +307,13 @@ fallbacks: Optional[List] = None
|
|||
context_window_fallbacks: Optional[List] = None
|
||||
content_policy_fallbacks: Optional[List] = None
|
||||
allowed_fails: int = 3
|
||||
num_retries_per_request: Optional[
|
||||
int
|
||||
] = None # for the request overall (incl. fallbacks + model retries)
|
||||
num_retries_per_request: Optional[int] = (
|
||||
None # for the request overall (incl. fallbacks + model retries)
|
||||
)
|
||||
####### SECRET MANAGERS #####################
|
||||
secret_manager_client: Optional[
|
||||
Any
|
||||
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||
secret_manager_client: Optional[Any] = (
|
||||
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||
)
|
||||
_google_kms_resource_name: Optional[str] = None
|
||||
_key_management_system: Optional[KeyManagementSystem] = None
|
||||
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
||||
|
@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem
|
|||
from .types.utils import GenericStreamingChunk
|
||||
|
||||
custom_provider_map: List[CustomLLMItem] = []
|
||||
_custom_providers: List[
|
||||
str
|
||||
] = [] # internal helper util, used to track names of custom providers
|
||||
disable_hf_tokenizer_download: Optional[
|
||||
bool
|
||||
] = None # disable huggingface tokenizer download. Defaults to openai clk100
|
||||
_custom_providers: List[str] = (
|
||||
[]
|
||||
) # internal helper util, used to track names of custom providers
|
||||
disable_hf_tokenizer_download: Optional[bool] = (
|
||||
None # disable huggingface tokenizer download. Defaults to openai clk100
|
||||
)
|
||||
global_disable_no_log_param: bool = False
|
||||
|
|
|
@ -480,6 +480,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
|
|||
|
||||
########################### Logging Callback Constants ###########################
|
||||
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
||||
PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
|
||||
MCP_TOOL_NAME_PREFIX = "mcp_tool"
|
||||
|
||||
########################### LiteLLM Proxy Specific Constants ###########################
|
||||
|
@ -514,6 +515,7 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
|
|||
|
||||
########################### DB CRON JOB NAMES ###########################
|
||||
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
|
||||
PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
|
||||
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
|
||||
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
|
||||
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
|
||||
|
|
|
@ -16,7 +16,10 @@ from litellm.constants import (
|
|||
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
|
||||
StandardBuiltInToolCostTracking,
|
||||
)
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import (
|
||||
_generic_cost_per_character,
|
||||
generic_cost_per_token,
|
||||
)
|
||||
from litellm.llms.anthropic.cost_calculation import (
|
||||
cost_per_token as anthropic_cost_per_token,
|
||||
)
|
||||
|
@ -54,12 +57,16 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
|
|||
from litellm.responses.utils import ResponseAPILoggingUtils
|
||||
from litellm.types.llms.openai import (
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAIRealtimeStreamList,
|
||||
OpenAIRealtimeStreamResponseBaseObject,
|
||||
OpenAIRealtimeStreamSessionEvents,
|
||||
ResponseAPIUsage,
|
||||
ResponsesAPIResponse,
|
||||
)
|
||||
from litellm.types.rerank import RerankBilledUnits, RerankResponse
|
||||
from litellm.types.utils import (
|
||||
CallTypesLiteral,
|
||||
LiteLLMRealtimeStreamLoggingObject,
|
||||
LlmProviders,
|
||||
LlmProvidersSet,
|
||||
ModelInfo,
|
||||
|
@ -397,6 +404,7 @@ def _select_model_name_for_cost_calc(
|
|||
base_model: Optional[str] = None,
|
||||
custom_pricing: Optional[bool] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
router_model_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
1. If custom pricing is true, return received model name
|
||||
|
@ -411,12 +419,6 @@ def _select_model_name_for_cost_calc(
|
|||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
if custom_pricing is True:
|
||||
return_model = model
|
||||
|
||||
if base_model is not None:
|
||||
return_model = base_model
|
||||
|
||||
completion_response_model: Optional[str] = None
|
||||
if completion_response is not None:
|
||||
if isinstance(completion_response, BaseModel):
|
||||
|
@ -424,6 +426,16 @@ def _select_model_name_for_cost_calc(
|
|||
elif isinstance(completion_response, dict):
|
||||
completion_response_model = completion_response.get("model", None)
|
||||
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
|
||||
|
||||
if custom_pricing is True:
|
||||
if router_model_id is not None and router_model_id in litellm.model_cost:
|
||||
return_model = router_model_id
|
||||
else:
|
||||
return_model = model
|
||||
|
||||
if base_model is not None:
|
||||
return_model = base_model
|
||||
|
||||
if completion_response_model is None and hidden_params is not None:
|
||||
if (
|
||||
hidden_params.get("model", None) is not None
|
||||
|
@ -553,6 +565,7 @@ def completion_cost( # noqa: PLR0915
|
|||
base_model: Optional[str] = None,
|
||||
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
||||
litellm_model_name: Optional[str] = None,
|
||||
router_model_id: Optional[str] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
|
||||
|
@ -605,18 +618,19 @@ def completion_cost( # noqa: PLR0915
|
|||
completion_response=completion_response
|
||||
)
|
||||
rerank_billed_units: Optional[RerankBilledUnits] = None
|
||||
|
||||
selected_model = _select_model_name_for_cost_calc(
|
||||
model=model,
|
||||
completion_response=completion_response,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
custom_pricing=custom_pricing,
|
||||
base_model=base_model,
|
||||
router_model_id=router_model_id,
|
||||
)
|
||||
|
||||
potential_model_names = [selected_model]
|
||||
if model is not None:
|
||||
potential_model_names.append(model)
|
||||
|
||||
for idx, model in enumerate(potential_model_names):
|
||||
try:
|
||||
verbose_logger.info(
|
||||
|
@ -780,6 +794,25 @@ def completion_cost( # noqa: PLR0915
|
|||
billed_units.get("search_units") or 1
|
||||
) # cohere charges per request by default.
|
||||
completion_tokens = search_units
|
||||
elif call_type == CallTypes.arealtime.value and isinstance(
|
||||
completion_response, LiteLLMRealtimeStreamLoggingObject
|
||||
):
|
||||
if (
|
||||
cost_per_token_usage_object is None
|
||||
or custom_llm_provider is None
|
||||
):
|
||||
raise ValueError(
|
||||
"usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
|
||||
cost_per_token_usage_object,
|
||||
custom_llm_provider,
|
||||
)
|
||||
)
|
||||
return handle_realtime_stream_cost_calculation(
|
||||
results=completion_response.results,
|
||||
combined_usage_object=cost_per_token_usage_object,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_model_name=model,
|
||||
)
|
||||
# Calculate cost based on prompt_tokens, completion_tokens
|
||||
if (
|
||||
"togethercomputer" in model
|
||||
|
@ -909,6 +942,7 @@ def response_cost_calculator(
|
|||
HttpxBinaryResponseContent,
|
||||
RerankResponse,
|
||||
ResponsesAPIResponse,
|
||||
LiteLLMRealtimeStreamLoggingObject,
|
||||
],
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str],
|
||||
|
@ -937,6 +971,7 @@ def response_cost_calculator(
|
|||
prompt: str = "",
|
||||
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
||||
litellm_model_name: Optional[str] = None,
|
||||
router_model_id: Optional[str] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Returns
|
||||
|
@ -967,6 +1002,8 @@ def response_cost_calculator(
|
|||
base_model=base_model,
|
||||
prompt=prompt,
|
||||
standard_built_in_tools_params=standard_built_in_tools_params,
|
||||
litellm_model_name=litellm_model_name,
|
||||
router_model_id=router_model_id,
|
||||
)
|
||||
return response_cost
|
||||
except Exception as e:
|
||||
|
@ -1141,3 +1178,173 @@ def batch_cost_calculator(
|
|||
) # batch cost is usually half of the regular token cost
|
||||
|
||||
return total_prompt_cost, total_completion_cost
|
||||
|
||||
|
||||
class RealtimeAPITokenUsageProcessor:
|
||||
@staticmethod
|
||||
def collect_usage_from_realtime_stream_results(
|
||||
results: OpenAIRealtimeStreamList,
|
||||
) -> List[Usage]:
|
||||
"""
|
||||
Collect usage from realtime stream results
|
||||
"""
|
||||
response_done_events: List[OpenAIRealtimeStreamResponseBaseObject] = cast(
|
||||
List[OpenAIRealtimeStreamResponseBaseObject],
|
||||
[result for result in results if result["type"] == "response.done"],
|
||||
)
|
||||
usage_objects: List[Usage] = []
|
||||
for result in response_done_events:
|
||||
usage_object = (
|
||||
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
||||
result["response"].get("usage", {})
|
||||
)
|
||||
)
|
||||
usage_objects.append(usage_object)
|
||||
return usage_objects
|
||||
|
||||
@staticmethod
|
||||
def combine_usage_objects(usage_objects: List[Usage]) -> Usage:
|
||||
"""
|
||||
Combine multiple Usage objects into a single Usage object, checking model keys for nested values.
|
||||
"""
|
||||
from litellm.types.utils import (
|
||||
CompletionTokensDetails,
|
||||
PromptTokensDetailsWrapper,
|
||||
Usage,
|
||||
)
|
||||
|
||||
combined = Usage()
|
||||
|
||||
# Sum basic token counts
|
||||
for usage in usage_objects:
|
||||
# Handle direct attributes by checking what exists in the model
|
||||
for attr in dir(usage):
|
||||
if not attr.startswith("_") and not callable(getattr(usage, attr)):
|
||||
current_val = getattr(combined, attr, 0)
|
||||
new_val = getattr(usage, attr, 0)
|
||||
if (
|
||||
new_val is not None
|
||||
and isinstance(new_val, (int, float))
|
||||
and isinstance(current_val, (int, float))
|
||||
):
|
||||
setattr(combined, attr, current_val + new_val)
|
||||
# Handle nested prompt_tokens_details
|
||||
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
|
||||
if (
|
||||
not hasattr(combined, "prompt_tokens_details")
|
||||
or not combined.prompt_tokens_details
|
||||
):
|
||||
combined.prompt_tokens_details = PromptTokensDetailsWrapper()
|
||||
|
||||
# Check what keys exist in the model's prompt_tokens_details
|
||||
for attr in dir(usage.prompt_tokens_details):
|
||||
if not attr.startswith("_") and not callable(
|
||||
getattr(usage.prompt_tokens_details, attr)
|
||||
):
|
||||
current_val = getattr(combined.prompt_tokens_details, attr, 0)
|
||||
new_val = getattr(usage.prompt_tokens_details, attr, 0)
|
||||
if new_val is not None:
|
||||
setattr(
|
||||
combined.prompt_tokens_details,
|
||||
attr,
|
||||
current_val + new_val,
|
||||
)
|
||||
|
||||
# Handle nested completion_tokens_details
|
||||
if (
|
||||
hasattr(usage, "completion_tokens_details")
|
||||
and usage.completion_tokens_details
|
||||
):
|
||||
if (
|
||||
not hasattr(combined, "completion_tokens_details")
|
||||
or not combined.completion_tokens_details
|
||||
):
|
||||
combined.completion_tokens_details = CompletionTokensDetails()
|
||||
|
||||
# Check what keys exist in the model's completion_tokens_details
|
||||
for attr in dir(usage.completion_tokens_details):
|
||||
if not attr.startswith("_") and not callable(
|
||||
getattr(usage.completion_tokens_details, attr)
|
||||
):
|
||||
current_val = getattr(
|
||||
combined.completion_tokens_details, attr, 0
|
||||
)
|
||||
new_val = getattr(usage.completion_tokens_details, attr, 0)
|
||||
if new_val is not None:
|
||||
setattr(
|
||||
combined.completion_tokens_details,
|
||||
attr,
|
||||
current_val + new_val,
|
||||
)
|
||||
|
||||
return combined
|
||||
|
||||
@staticmethod
|
||||
def collect_and_combine_usage_from_realtime_stream_results(
|
||||
results: OpenAIRealtimeStreamList,
|
||||
) -> Usage:
|
||||
"""
|
||||
Collect and combine usage from realtime stream results
|
||||
"""
|
||||
collected_usage_objects = (
|
||||
RealtimeAPITokenUsageProcessor.collect_usage_from_realtime_stream_results(
|
||||
results
|
||||
)
|
||||
)
|
||||
combined_usage_object = RealtimeAPITokenUsageProcessor.combine_usage_objects(
|
||||
collected_usage_objects
|
||||
)
|
||||
return combined_usage_object
|
||||
|
||||
@staticmethod
|
||||
def create_logging_realtime_object(
|
||||
usage: Usage, results: OpenAIRealtimeStreamList
|
||||
) -> LiteLLMRealtimeStreamLoggingObject:
|
||||
return LiteLLMRealtimeStreamLoggingObject(
|
||||
usage=usage,
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
def handle_realtime_stream_cost_calculation(
|
||||
results: OpenAIRealtimeStreamList,
|
||||
combined_usage_object: Usage,
|
||||
custom_llm_provider: str,
|
||||
litellm_model_name: str,
|
||||
) -> float:
|
||||
"""
|
||||
Handles the cost calculation for realtime stream responses.
|
||||
|
||||
Pick the 'response.done' events. Calculate total cost across all 'response.done' events.
|
||||
|
||||
Args:
|
||||
results: A list of OpenAIRealtimeStreamBaseObject objects
|
||||
"""
|
||||
received_model = None
|
||||
potential_model_names = []
|
||||
for result in results:
|
||||
if result["type"] == "session.created":
|
||||
received_model = cast(OpenAIRealtimeStreamSessionEvents, result)["session"][
|
||||
"model"
|
||||
]
|
||||
potential_model_names.append(received_model)
|
||||
|
||||
potential_model_names.append(litellm_model_name)
|
||||
input_cost_per_token = 0.0
|
||||
output_cost_per_token = 0.0
|
||||
|
||||
for model_name in potential_model_names:
|
||||
try:
|
||||
_input_cost_per_token, _output_cost_per_token = generic_cost_per_token(
|
||||
model=model_name,
|
||||
usage=combined_usage_object,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
input_cost_per_token += _input_cost_per_token
|
||||
output_cost_per_token += _output_cost_per_token
|
||||
break # exit if we find a valid model
|
||||
total_cost = input_cost_per_token + output_cost_per_token
|
||||
|
||||
return total_cost
|
||||
|
|
|
@ -1,10 +1,19 @@
|
|||
# used for /metrics endpoint on LiteLLM Proxy
|
||||
#### What this does ####
|
||||
# On success, log events to Prometheus
|
||||
import asyncio
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
@ -14,6 +23,11 @@ from litellm.types.integrations.prometheus import *
|
|||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
else:
|
||||
AsyncIOScheduler = Any
|
||||
|
||||
|
||||
class PrometheusLogger(CustomLogger):
|
||||
# Class variables or attributes
|
||||
|
@ -359,8 +373,6 @@ class PrometheusLogger(CustomLogger):
|
|||
label_name="litellm_requests_metric"
|
||||
),
|
||||
)
|
||||
self._initialize_prometheus_startup_metrics()
|
||||
|
||||
except Exception as e:
|
||||
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
||||
raise e
|
||||
|
@ -988,9 +1000,9 @@ class PrometheusLogger(CustomLogger):
|
|||
):
|
||||
try:
|
||||
verbose_logger.debug("setting remaining tokens requests metric")
|
||||
standard_logging_payload: Optional[
|
||||
StandardLoggingPayload
|
||||
] = request_kwargs.get("standard_logging_object")
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = (
|
||||
request_kwargs.get("standard_logging_object")
|
||||
)
|
||||
|
||||
if standard_logging_payload is None:
|
||||
return
|
||||
|
@ -1337,24 +1349,6 @@ class PrometheusLogger(CustomLogger):
|
|||
|
||||
return max_budget - spend
|
||||
|
||||
def _initialize_prometheus_startup_metrics(self):
|
||||
"""
|
||||
Initialize prometheus startup metrics
|
||||
|
||||
Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics
|
||||
"""
|
||||
if litellm.prometheus_initialize_budget_metrics is not True:
|
||||
verbose_logger.debug("Prometheus: skipping budget metrics initialization")
|
||||
return
|
||||
|
||||
try:
|
||||
if asyncio.get_running_loop():
|
||||
asyncio.create_task(self._initialize_remaining_budget_metrics())
|
||||
except RuntimeError as e: # no running event loop
|
||||
verbose_logger.exception(
|
||||
f"No running event loop - skipping budget metrics initialization: {str(e)}"
|
||||
)
|
||||
|
||||
async def _initialize_budget_metrics(
|
||||
self,
|
||||
data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
|
||||
|
@ -1475,12 +1469,41 @@ class PrometheusLogger(CustomLogger):
|
|||
data_type="keys",
|
||||
)
|
||||
|
||||
async def _initialize_remaining_budget_metrics(self):
|
||||
async def initialize_remaining_budget_metrics(self):
|
||||
"""
|
||||
Initialize remaining budget metrics for all teams to avoid metric discrepancies.
|
||||
Handler for initializing remaining budget metrics for all teams to avoid metric discrepancies.
|
||||
|
||||
Runs when prometheus logger starts up.
|
||||
|
||||
- If redis cache is available, we use the pod lock manager to acquire a lock and initialize the metrics.
|
||||
- Ensures only one pod emits the metrics at a time.
|
||||
- If redis cache is not available, we initialize the metrics directly.
|
||||
"""
|
||||
from litellm.constants import PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
|
||||
|
||||
# if using redis, ensure only one pod emits the metrics at a time
|
||||
if pod_lock_manager and pod_lock_manager.redis_cache:
|
||||
if await pod_lock_manager.acquire_lock(
|
||||
cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
|
||||
):
|
||||
try:
|
||||
await self._initialize_remaining_budget_metrics()
|
||||
finally:
|
||||
await pod_lock_manager.release_lock(
|
||||
cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
|
||||
)
|
||||
else:
|
||||
# if not using redis, initialize the metrics directly
|
||||
await self._initialize_remaining_budget_metrics()
|
||||
|
||||
async def _initialize_remaining_budget_metrics(self):
|
||||
"""
|
||||
Helper to initialize remaining budget metrics for all teams and API keys.
|
||||
"""
|
||||
verbose_logger.debug("Emitting key, team budget metrics....")
|
||||
await self._initialize_team_budget_metrics()
|
||||
await self._initialize_api_key_budget_metrics()
|
||||
|
||||
|
@ -1737,6 +1760,36 @@ class PrometheusLogger(CustomLogger):
|
|||
return (end_time - start_time).total_seconds()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def initialize_budget_metrics_cron_job(scheduler: AsyncIOScheduler):
|
||||
"""
|
||||
Initialize budget metrics as a cron job. This job runs every `PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES` minutes.
|
||||
|
||||
It emits the current remaining budget metrics for all Keys and Teams.
|
||||
"""
|
||||
from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.prometheus import PrometheusLogger
|
||||
|
||||
prometheus_loggers: List[CustomLogger] = (
|
||||
litellm.logging_callback_manager.get_custom_loggers_for_type(
|
||||
callback_type=PrometheusLogger
|
||||
)
|
||||
)
|
||||
# we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them
|
||||
verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers))
|
||||
if len(prometheus_loggers) > 0:
|
||||
prometheus_logger = cast(PrometheusLogger, prometheus_loggers[0])
|
||||
verbose_logger.debug(
|
||||
"Initializing remaining budget metrics as a cron job executing every %s minutes"
|
||||
% PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
|
||||
)
|
||||
scheduler.add_job(
|
||||
prometheus_logger.initialize_remaining_budget_metrics,
|
||||
"interval",
|
||||
minutes=PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _mount_metrics_endpoint(premium_user: bool):
|
||||
"""
|
||||
|
|
|
@ -110,5 +110,8 @@ def get_litellm_params(
|
|||
"azure_password": kwargs.get("azure_password"),
|
||||
"max_retries": max_retries,
|
||||
"timeout": kwargs.get("timeout"),
|
||||
"bucket_name": kwargs.get("bucket_name"),
|
||||
"vertex_credentials": kwargs.get("vertex_credentials"),
|
||||
"vertex_project": kwargs.get("vertex_project"),
|
||||
}
|
||||
return litellm_params
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import LlmProviders
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.types.utils import LlmProviders, LlmProvidersSet
|
||||
|
||||
|
||||
def get_supported_openai_params( # noqa: PLR0915
|
||||
|
@ -30,6 +30,20 @@ def get_supported_openai_params( # noqa: PLR0915
|
|||
except BadRequestError:
|
||||
return None
|
||||
|
||||
if custom_llm_provider in LlmProvidersSet:
|
||||
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
|
||||
model=model, provider=LlmProviders(custom_llm_provider)
|
||||
)
|
||||
elif custom_llm_provider.split("/")[0] in LlmProvidersSet:
|
||||
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
|
||||
model=model, provider=LlmProviders(custom_llm_provider.split("/")[0])
|
||||
)
|
||||
else:
|
||||
provider_config = None
|
||||
|
||||
if provider_config and request_type == "chat_completion":
|
||||
return provider_config.get_supported_openai_params(model=model)
|
||||
|
||||
if custom_llm_provider == "bedrock":
|
||||
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "ollama":
|
||||
|
@ -226,7 +240,8 @@ def get_supported_openai_params( # noqa: PLR0915
|
|||
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
|
||||
model=model, provider=LlmProviders.CUSTOM
|
||||
)
|
||||
return provider_config.get_supported_openai_params(model=model)
|
||||
if provider_config:
|
||||
return provider_config.get_supported_openai_params(model=model)
|
||||
elif request_type == "embeddings":
|
||||
return None
|
||||
elif request_type == "transcription":
|
||||
|
|
|
@ -32,7 +32,10 @@ from litellm.constants import (
|
|||
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
|
||||
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
|
||||
)
|
||||
from litellm.cost_calculator import _select_model_name_for_cost_calc
|
||||
from litellm.cost_calculator import (
|
||||
RealtimeAPITokenUsageProcessor,
|
||||
_select_model_name_for_cost_calc,
|
||||
)
|
||||
from litellm.integrations.arize.arize import ArizeLogger
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
@ -64,6 +67,7 @@ from litellm.types.utils import (
|
|||
ImageResponse,
|
||||
LiteLLMBatch,
|
||||
LiteLLMLoggingBaseClass,
|
||||
LiteLLMRealtimeStreamLoggingObject,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
RawRequestTypedDict,
|
||||
|
@ -618,7 +622,6 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
] = RawRequestTypedDict(
|
||||
error=str(e),
|
||||
)
|
||||
traceback.print_exc()
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "Unable to Log \
|
||||
|
@ -899,9 +902,11 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
FineTuningJob,
|
||||
ResponsesAPIResponse,
|
||||
ResponseCompletedEvent,
|
||||
LiteLLMRealtimeStreamLoggingObject,
|
||||
],
|
||||
cache_hit: Optional[bool] = None,
|
||||
litellm_model_name: Optional[str] = None,
|
||||
router_model_id: Optional[str] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Calculate response cost using result + logging object variables.
|
||||
|
@ -940,6 +945,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
"custom_pricing": custom_pricing,
|
||||
"prompt": prompt,
|
||||
"standard_built_in_tools_params": self.standard_built_in_tools_params,
|
||||
"router_model_id": router_model_id,
|
||||
}
|
||||
except Exception as e: # error creating kwargs for cost calculation
|
||||
debug_info = StandardLoggingModelCostFailureDebugInformation(
|
||||
|
@ -1049,26 +1055,50 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
result = self._handle_anthropic_messages_response_logging(result=result)
|
||||
## if model in model cost map - log the response cost
|
||||
## else set cost to None
|
||||
|
||||
logging_result = result
|
||||
|
||||
if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
|
||||
combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
|
||||
results=result
|
||||
)
|
||||
logging_result = (
|
||||
RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
|
||||
usage=combined_usage_object,
|
||||
results=result,
|
||||
)
|
||||
)
|
||||
|
||||
# self.model_call_details[
|
||||
# "response_cost"
|
||||
# ] = handle_realtime_stream_cost_calculation(
|
||||
# results=result,
|
||||
# combined_usage_object=combined_usage_object,
|
||||
# custom_llm_provider=self.custom_llm_provider,
|
||||
# litellm_model_name=self.model,
|
||||
# )
|
||||
# self.model_call_details["combined_usage_object"] = combined_usage_object
|
||||
if (
|
||||
standard_logging_object is None
|
||||
and result is not None
|
||||
and self.stream is not True
|
||||
):
|
||||
if (
|
||||
isinstance(result, ModelResponse)
|
||||
or isinstance(result, ModelResponseStream)
|
||||
or isinstance(result, EmbeddingResponse)
|
||||
or isinstance(result, ImageResponse)
|
||||
or isinstance(result, TranscriptionResponse)
|
||||
or isinstance(result, TextCompletionResponse)
|
||||
or isinstance(result, HttpxBinaryResponseContent) # tts
|
||||
or isinstance(result, RerankResponse)
|
||||
or isinstance(result, FineTuningJob)
|
||||
or isinstance(result, LiteLLMBatch)
|
||||
or isinstance(result, ResponsesAPIResponse)
|
||||
isinstance(logging_result, ModelResponse)
|
||||
or isinstance(logging_result, ModelResponseStream)
|
||||
or isinstance(logging_result, EmbeddingResponse)
|
||||
or isinstance(logging_result, ImageResponse)
|
||||
or isinstance(logging_result, TranscriptionResponse)
|
||||
or isinstance(logging_result, TextCompletionResponse)
|
||||
or isinstance(logging_result, HttpxBinaryResponseContent) # tts
|
||||
or isinstance(logging_result, RerankResponse)
|
||||
or isinstance(logging_result, FineTuningJob)
|
||||
or isinstance(logging_result, LiteLLMBatch)
|
||||
or isinstance(logging_result, ResponsesAPIResponse)
|
||||
or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject)
|
||||
):
|
||||
## HIDDEN PARAMS ##
|
||||
hidden_params = getattr(result, "_hidden_params", {})
|
||||
hidden_params = getattr(logging_result, "_hidden_params", {})
|
||||
if hidden_params:
|
||||
# add to metadata for logging
|
||||
if self.model_call_details.get("litellm_params") is not None:
|
||||
|
@ -1086,7 +1116,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
self.model_call_details["litellm_params"]["metadata"][ # type: ignore
|
||||
"hidden_params"
|
||||
] = getattr(
|
||||
result, "_hidden_params", {}
|
||||
logging_result, "_hidden_params", {}
|
||||
)
|
||||
## RESPONSE COST - Only calculate if not in hidden_params ##
|
||||
if "response_cost" in hidden_params:
|
||||
|
@ -1096,14 +1126,14 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
else:
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(result=result)
|
||||
] = self._response_cost_calculator(result=logging_result)
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
init_response_obj=logging_result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
|
@ -3122,6 +3152,7 @@ class StandardLoggingPayloadSetup:
|
|||
prompt_integration: Optional[str] = None,
|
||||
applied_guardrails: Optional[List[str]] = None,
|
||||
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
|
||||
usage_object: Optional[dict] = None,
|
||||
) -> StandardLoggingMetadata:
|
||||
"""
|
||||
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
||||
|
@ -3169,6 +3200,7 @@ class StandardLoggingPayloadSetup:
|
|||
prompt_management_metadata=prompt_management_metadata,
|
||||
applied_guardrails=applied_guardrails,
|
||||
mcp_tool_call_metadata=mcp_tool_call_metadata,
|
||||
usage_object=usage_object,
|
||||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
|
@ -3194,8 +3226,12 @@ class StandardLoggingPayloadSetup:
|
|||
return clean_metadata
|
||||
|
||||
@staticmethod
|
||||
def get_usage_from_response_obj(response_obj: Optional[dict]) -> Usage:
|
||||
def get_usage_from_response_obj(
|
||||
response_obj: Optional[dict], combined_usage_object: Optional[Usage] = None
|
||||
) -> Usage:
|
||||
## BASE CASE ##
|
||||
if combined_usage_object is not None:
|
||||
return combined_usage_object
|
||||
if response_obj is None:
|
||||
return Usage(
|
||||
prompt_tokens=0,
|
||||
|
@ -3324,6 +3360,7 @@ class StandardLoggingPayloadSetup:
|
|||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
usage_object=None,
|
||||
)
|
||||
if hidden_params is not None:
|
||||
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
||||
|
@ -3440,6 +3477,7 @@ def get_standard_logging_object_payload(
|
|||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
usage_object=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -3456,8 +3494,12 @@ def get_standard_logging_object_payload(
|
|||
call_type = kwargs.get("call_type")
|
||||
cache_hit = kwargs.get("cache_hit", False)
|
||||
usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(
|
||||
response_obj=response_obj
|
||||
response_obj=response_obj,
|
||||
combined_usage_object=cast(
|
||||
Optional[Usage], kwargs.get("combined_usage_object")
|
||||
),
|
||||
)
|
||||
|
||||
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||
|
||||
_model_id = metadata.get("model_info", {}).get("id", "")
|
||||
|
@ -3496,6 +3538,7 @@ def get_standard_logging_object_payload(
|
|||
prompt_integration=kwargs.get("prompt_integration", None),
|
||||
applied_guardrails=kwargs.get("applied_guardrails", None),
|
||||
mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None),
|
||||
usage_object=usage.model_dump(),
|
||||
)
|
||||
|
||||
_request_body = proxy_server_request.get("body", {})
|
||||
|
@ -3636,6 +3679,7 @@ def get_standard_logging_metadata(
|
|||
prompt_management_metadata=None,
|
||||
applied_guardrails=None,
|
||||
mcp_tool_call_metadata=None,
|
||||
usage_object=None,
|
||||
)
|
||||
if isinstance(metadata, dict):
|
||||
# Filter the metadata dictionary to include only the specified keys
|
||||
|
@ -3730,6 +3774,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
|
|||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
usage_object=None,
|
||||
)
|
||||
|
||||
# Convert numeric values to appropriate types
|
||||
|
|
|
@ -90,35 +90,45 @@ def _generic_cost_per_character(
|
|||
return prompt_cost, completion_cost
|
||||
|
||||
|
||||
def _get_prompt_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
|
||||
def _get_token_base_cost(model_info: ModelInfo, usage: Usage) -> Tuple[float, float]:
|
||||
"""
|
||||
Return prompt cost for a given model and usage.
|
||||
|
||||
If input_tokens > 128k and `input_cost_per_token_above_128k_tokens` is set, then we use the `input_cost_per_token_above_128k_tokens` field.
|
||||
If input_tokens > threshold and `input_cost_per_token_above_[x]k_tokens` or `input_cost_per_token_above_[x]_tokens` is set,
|
||||
then we use the corresponding threshold cost.
|
||||
"""
|
||||
input_cost_per_token_above_128k_tokens = model_info.get(
|
||||
"input_cost_per_token_above_128k_tokens"
|
||||
)
|
||||
if _is_above_128k(usage.prompt_tokens) and input_cost_per_token_above_128k_tokens:
|
||||
return input_cost_per_token_above_128k_tokens
|
||||
return model_info["input_cost_per_token"]
|
||||
prompt_base_cost = model_info["input_cost_per_token"]
|
||||
completion_base_cost = model_info["output_cost_per_token"]
|
||||
|
||||
## CHECK IF ABOVE THRESHOLD
|
||||
threshold: Optional[float] = None
|
||||
for key, value in sorted(model_info.items(), reverse=True):
|
||||
if key.startswith("input_cost_per_token_above_") and value is not None:
|
||||
try:
|
||||
# Handle both formats: _above_128k_tokens and _above_128_tokens
|
||||
threshold_str = key.split("_above_")[1].split("_tokens")[0]
|
||||
threshold = float(threshold_str.replace("k", "")) * (
|
||||
1000 if "k" in threshold_str else 1
|
||||
)
|
||||
if usage.prompt_tokens > threshold:
|
||||
prompt_base_cost = cast(
|
||||
float,
|
||||
model_info.get(key, prompt_base_cost),
|
||||
)
|
||||
completion_base_cost = cast(
|
||||
float,
|
||||
model_info.get(
|
||||
f"output_cost_per_token_above_{threshold_str}_tokens",
|
||||
completion_base_cost,
|
||||
),
|
||||
)
|
||||
break
|
||||
except (IndexError, ValueError):
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
def _get_completion_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
|
||||
"""
|
||||
Return prompt cost for a given model and usage.
|
||||
|
||||
If input_tokens > 128k and `input_cost_per_token_above_128k_tokens` is set, then we use the `input_cost_per_token_above_128k_tokens` field.
|
||||
"""
|
||||
output_cost_per_token_above_128k_tokens = model_info.get(
|
||||
"output_cost_per_token_above_128k_tokens"
|
||||
)
|
||||
if (
|
||||
_is_above_128k(usage.completion_tokens)
|
||||
and output_cost_per_token_above_128k_tokens
|
||||
):
|
||||
return output_cost_per_token_above_128k_tokens
|
||||
return model_info["output_cost_per_token"]
|
||||
return prompt_base_cost, completion_base_cost
|
||||
|
||||
|
||||
def calculate_cost_component(
|
||||
|
@ -215,7 +225,9 @@ def generic_cost_per_token(
|
|||
if text_tokens == 0:
|
||||
text_tokens = usage.prompt_tokens - cache_hit_tokens - audio_tokens
|
||||
|
||||
prompt_base_cost = _get_prompt_token_base_cost(model_info=model_info, usage=usage)
|
||||
prompt_base_cost, completion_base_cost = _get_token_base_cost(
|
||||
model_info=model_info, usage=usage
|
||||
)
|
||||
|
||||
prompt_cost = float(text_tokens) * prompt_base_cost
|
||||
|
||||
|
@ -253,9 +265,6 @@ def generic_cost_per_token(
|
|||
)
|
||||
|
||||
## CALCULATE OUTPUT COST
|
||||
completion_base_cost = _get_completion_token_base_cost(
|
||||
model_info=model_info, usage=usage
|
||||
)
|
||||
text_tokens = usage.completion_tokens
|
||||
audio_tokens = 0
|
||||
if usage.completion_tokens_details is not None:
|
||||
|
|
|
@ -36,11 +36,16 @@ class ResponseMetadata:
|
|||
self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
|
||||
) -> None:
|
||||
"""Set hidden parameters on the response"""
|
||||
|
||||
## ADD OTHER HIDDEN PARAMS
|
||||
model_id = kwargs.get("model_info", {}).get("id", None)
|
||||
new_params = {
|
||||
"litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
|
||||
"model_id": kwargs.get("model_info", {}).get("id", None),
|
||||
"api_base": get_api_base(model=model or "", optional_params=kwargs),
|
||||
"response_cost": logging_obj._response_cost_calculator(result=self.result),
|
||||
"model_id": model_id,
|
||||
"response_cost": logging_obj._response_cost_calculator(
|
||||
result=self.result, litellm_model_name=model, router_model_id=model_id
|
||||
),
|
||||
"additional_headers": process_response_headers(
|
||||
self._get_value_from_hidden_params("additional_headers") or {}
|
||||
),
|
||||
|
|
|
@ -2,7 +2,10 @@
|
|||
Common utility functions used for translating messages across providers
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Literal, Optional, Union, cast
|
||||
import io
|
||||
import mimetypes
|
||||
from os import PathLike
|
||||
from typing import Dict, List, Literal, Mapping, Optional, Union, cast
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
|
@ -10,7 +13,13 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionFileObject,
|
||||
ChatCompletionUserMessage,
|
||||
)
|
||||
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
ExtractedFileData,
|
||||
FileTypes,
|
||||
ModelResponse,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
|
||||
content="Please continue.", role="user"
|
||||
|
@ -348,3 +357,99 @@ def update_messages_with_model_file_ids(
|
|||
)
|
||||
file_object_file_field["file_id"] = provider_file_id
|
||||
return messages
|
||||
|
||||
|
||||
def extract_file_data(file_data: FileTypes) -> ExtractedFileData:
|
||||
"""
|
||||
Extracts and processes file data from various input formats.
|
||||
|
||||
Args:
|
||||
file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content
|
||||
|
||||
Returns:
|
||||
ExtractedFileData containing:
|
||||
- filename: Name of the file if provided
|
||||
- content: The file content in bytes
|
||||
- content_type: MIME type of the file
|
||||
- headers: Any additional headers
|
||||
"""
|
||||
# Parse the file_data based on its type
|
||||
filename = None
|
||||
file_content = None
|
||||
content_type = None
|
||||
file_headers: Mapping[str, str] = {}
|
||||
|
||||
if isinstance(file_data, tuple):
|
||||
if len(file_data) == 2:
|
||||
filename, file_content = file_data
|
||||
elif len(file_data) == 3:
|
||||
filename, file_content, content_type = file_data
|
||||
elif len(file_data) == 4:
|
||||
filename, file_content, content_type, file_headers = file_data
|
||||
else:
|
||||
file_content = file_data
|
||||
# Convert content to bytes
|
||||
if isinstance(file_content, (str, PathLike)):
|
||||
# If it's a path, open and read the file
|
||||
with open(file_content, "rb") as f:
|
||||
content = f.read()
|
||||
elif isinstance(file_content, io.IOBase):
|
||||
# If it's a file-like object
|
||||
content = file_content.read()
|
||||
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
# Reset file pointer to beginning
|
||||
file_content.seek(0)
|
||||
elif isinstance(file_content, bytes):
|
||||
content = file_content
|
||||
else:
|
||||
raise ValueError(f"Unsupported file content type: {type(file_content)}")
|
||||
|
||||
# Use provided content type or guess based on filename
|
||||
if not content_type:
|
||||
content_type = (
|
||||
mimetypes.guess_type(filename)[0]
|
||||
if filename
|
||||
else "application/octet-stream"
|
||||
)
|
||||
|
||||
return ExtractedFileData(
|
||||
filename=filename,
|
||||
content=content,
|
||||
content_type=content_type,
|
||||
headers=file_headers,
|
||||
)
|
||||
|
||||
def unpack_defs(schema, defs):
|
||||
properties = schema.get("properties", None)
|
||||
if properties is None:
|
||||
return
|
||||
|
||||
for name, value in properties.items():
|
||||
ref_key = value.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
properties[name] = ref
|
||||
continue
|
||||
|
||||
anyof = value.get("anyOf", None)
|
||||
if anyof is not None:
|
||||
for i, atype in enumerate(anyof):
|
||||
ref_key = atype.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
anyof[i] = ref
|
||||
continue
|
||||
|
||||
items = value.get("items", None)
|
||||
if items is not None:
|
||||
ref_key = items.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
value["items"] = ref
|
||||
continue
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import copy
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from enum import Enum
|
||||
|
@ -748,7 +747,6 @@ def convert_to_anthropic_image_obj(
|
|||
data=base64_data,
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if "Error: Unable to fetch image from URL" in str(e):
|
||||
raise e
|
||||
raise Exception(
|
||||
|
@ -3442,6 +3440,8 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
|
|||
}
|
||||
]
|
||||
"""
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs
|
||||
|
||||
tool_block_list: List[BedrockToolBlock] = []
|
||||
for tool in tools:
|
||||
parameters = tool.get("function", {}).get(
|
||||
|
@ -3455,6 +3455,13 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
|
|||
description = tool.get("function", {}).get(
|
||||
"description", name
|
||||
) # converse api requires a description
|
||||
|
||||
defs = parameters.pop("$defs", {})
|
||||
defs_copy = copy.deepcopy(defs)
|
||||
# flatten the defs
|
||||
for _, value in defs_copy.items():
|
||||
unpack_defs(value, defs_copy)
|
||||
unpack_defs(parameters, defs_copy)
|
||||
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
|
||||
tool_spec = BedrockToolSpecBlock(
|
||||
inputSchema=tool_input_schema, name=name, description=description
|
||||
|
|
|
@ -30,6 +30,11 @@ import json
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import (
|
||||
OpenAIRealtimeStreamResponseBaseObject,
|
||||
OpenAIRealtimeStreamSessionEvents,
|
||||
)
|
||||
|
||||
from .litellm_logging import Logging as LiteLLMLogging
|
||||
|
||||
|
@ -53,7 +58,12 @@ class RealTimeStreaming:
|
|||
self.websocket = websocket
|
||||
self.backend_ws = backend_ws
|
||||
self.logging_obj = logging_obj
|
||||
self.messages: List = []
|
||||
self.messages: List[
|
||||
Union[
|
||||
OpenAIRealtimeStreamResponseBaseObject,
|
||||
OpenAIRealtimeStreamSessionEvents,
|
||||
]
|
||||
] = []
|
||||
self.input_message: Dict = {}
|
||||
|
||||
_logged_real_time_event_types = litellm.logged_real_time_event_types
|
||||
|
@ -62,10 +72,14 @@ class RealTimeStreaming:
|
|||
_logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
|
||||
self.logged_real_time_event_types = _logged_real_time_event_types
|
||||
|
||||
def _should_store_message(self, message: Union[str, bytes]) -> bool:
|
||||
if isinstance(message, bytes):
|
||||
message = message.decode("utf-8")
|
||||
message_obj = json.loads(message)
|
||||
def _should_store_message(
|
||||
self,
|
||||
message_obj: Union[
|
||||
dict,
|
||||
OpenAIRealtimeStreamSessionEvents,
|
||||
OpenAIRealtimeStreamResponseBaseObject,
|
||||
],
|
||||
) -> bool:
|
||||
_msg_type = message_obj["type"]
|
||||
if self.logged_real_time_event_types == "*":
|
||||
return True
|
||||
|
@ -75,8 +89,22 @@ class RealTimeStreaming:
|
|||
|
||||
def store_message(self, message: Union[str, bytes]):
|
||||
"""Store message in list"""
|
||||
if self._should_store_message(message):
|
||||
self.messages.append(message)
|
||||
if isinstance(message, bytes):
|
||||
message = message.decode("utf-8")
|
||||
message_obj = json.loads(message)
|
||||
try:
|
||||
if (
|
||||
message_obj.get("type") == "session.created"
|
||||
or message_obj.get("type") == "session.updated"
|
||||
):
|
||||
message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj) # type: ignore
|
||||
else:
|
||||
message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj) # type: ignore
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error parsing message for logging: {e}")
|
||||
raise e
|
||||
if self._should_store_message(message_obj):
|
||||
self.messages.append(message_obj)
|
||||
|
||||
def store_input(self, message: dict):
|
||||
"""Store input message"""
|
||||
|
|
|
@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
|
|||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
|
@ -301,12 +301,17 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
config = ProviderConfigManager.get_provider_chat_config(
|
||||
model=model,
|
||||
provider=LlmProviders(custom_llm_provider),
|
||||
)
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
|
||||
)
|
||||
|
||||
data = config.transform_request(
|
||||
model=model,
|
||||
|
@ -487,29 +492,10 @@ class ModelResponseIterator:
|
|||
return False
|
||||
|
||||
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
|
||||
usage_block = Usage(
|
||||
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
|
||||
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
|
||||
total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
|
||||
+ anthropic_usage_chunk.get("output_tokens", 0),
|
||||
return AnthropicConfig().calculate_usage(
|
||||
usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
|
||||
)
|
||||
|
||||
cache_creation_input_tokens = anthropic_usage_chunk.get(
|
||||
"cache_creation_input_tokens"
|
||||
)
|
||||
if cache_creation_input_tokens is not None and isinstance(
|
||||
cache_creation_input_tokens, int
|
||||
):
|
||||
usage_block["cache_creation_input_tokens"] = cache_creation_input_tokens
|
||||
|
||||
cache_read_input_tokens = anthropic_usage_chunk.get("cache_read_input_tokens")
|
||||
if cache_read_input_tokens is not None and isinstance(
|
||||
cache_read_input_tokens, int
|
||||
):
|
||||
usage_block["cache_read_input_tokens"] = cache_read_input_tokens
|
||||
|
||||
return usage_block
|
||||
|
||||
def _content_block_delta_helper(
|
||||
self, chunk: dict
|
||||
) -> Tuple[
|
||||
|
|
|
@ -682,6 +682,45 @@ class AnthropicConfig(BaseConfig):
|
|||
reasoning_content += block["thinking"]
|
||||
return text_content, citations, thinking_blocks, reasoning_content, tool_calls
|
||||
|
||||
def calculate_usage(
|
||||
self, usage_object: dict, reasoning_content: Optional[str]
|
||||
) -> Usage:
|
||||
prompt_tokens = usage_object.get("input_tokens", 0)
|
||||
completion_tokens = usage_object.get("output_tokens", 0)
|
||||
_usage = usage_object
|
||||
cache_creation_input_tokens: int = 0
|
||||
cache_read_input_tokens: int = 0
|
||||
|
||||
if "cache_creation_input_tokens" in _usage:
|
||||
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
|
||||
if "cache_read_input_tokens" in _usage:
|
||||
cache_read_input_tokens = _usage["cache_read_input_tokens"]
|
||||
prompt_tokens += cache_read_input_tokens
|
||||
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
completion_token_details = (
|
||||
CompletionTokensDetailsWrapper(
|
||||
reasoning_tokens=token_counter(
|
||||
text=reasoning_content, count_response_tokens=True
|
||||
)
|
||||
)
|
||||
if reasoning_content
|
||||
else None
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
completion_tokens_details=completion_token_details,
|
||||
)
|
||||
return usage
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -772,45 +811,14 @@ class AnthropicConfig(BaseConfig):
|
|||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = completion_response["usage"]["input_tokens"]
|
||||
completion_tokens = completion_response["usage"]["output_tokens"]
|
||||
_usage = completion_response["usage"]
|
||||
cache_creation_input_tokens: int = 0
|
||||
cache_read_input_tokens: int = 0
|
||||
usage = self.calculate_usage(
|
||||
usage_object=completion_response["usage"],
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
setattr(model_response, "usage", usage) # type: ignore
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = completion_response["model"]
|
||||
if "cache_creation_input_tokens" in _usage:
|
||||
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
|
||||
prompt_tokens += cache_creation_input_tokens
|
||||
if "cache_read_input_tokens" in _usage:
|
||||
cache_read_input_tokens = _usage["cache_read_input_tokens"]
|
||||
prompt_tokens += cache_read_input_tokens
|
||||
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
completion_token_details = (
|
||||
CompletionTokensDetailsWrapper(
|
||||
reasoning_tokens=token_counter(
|
||||
text=reasoning_content, count_response_tokens=True
|
||||
)
|
||||
)
|
||||
if reasoning_content
|
||||
else None
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
completion_tokens_details=completion_token_details,
|
||||
)
|
||||
|
||||
setattr(model_response, "usage", usage) # type: ignore
|
||||
|
||||
model_response._hidden_params = _hidden_params
|
||||
return model_response
|
||||
|
@ -868,6 +876,7 @@ class AnthropicConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
|
|
@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -262,6 +262,7 @@ class BaseConfig(ABC):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig):
|
|||
) -> List[OpenAICreateFileRequestOptionalParams]:
|
||||
pass
|
||||
|
||||
def get_complete_url(
|
||||
def get_complete_file_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
return api_base or ""
|
||||
data: CreateFileRequest,
|
||||
):
|
||||
return self.get_complete_url(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def transform_create_file_request(
|
||||
|
@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig):
|
|||
create_file_data: CreateFileRequest,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> dict:
|
||||
) -> Union[dict, str, bytes]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
ChatCompletionUserMessage,
|
||||
OpenAIChatCompletionToolParam,
|
||||
OpenAIMessageContentListBlock,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage
|
||||
|
@ -211,13 +212,29 @@ class AmazonConverseConfig(BaseConfig):
|
|||
)
|
||||
return _tool
|
||||
|
||||
def _apply_tool_call_transformation(
|
||||
self,
|
||||
tools: List[OpenAIChatCompletionToolParam],
|
||||
model: str,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
):
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=tools
|
||||
)
|
||||
|
||||
if (
|
||||
"meta.llama3-3-70b-instruct-v1:0" in model
|
||||
and non_default_params.get("stream", False) is True
|
||||
):
|
||||
optional_params["fake_stream"] = True
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
messages: Optional[List[AllMessageValues]] = None,
|
||||
) -> dict:
|
||||
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
|
||||
|
||||
|
@ -286,8 +303,11 @@ class AmazonConverseConfig(BaseConfig):
|
|||
if param == "top_p":
|
||||
optional_params["topP"] = value
|
||||
if param == "tools" and isinstance(value, list):
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=value
|
||||
self._apply_tool_call_transformation(
|
||||
tools=cast(List[OpenAIChatCompletionToolParam], value),
|
||||
model=model,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
if param == "tool_choice":
|
||||
_tool_choice_value = self.map_tool_choice_values(
|
||||
|
@ -633,8 +653,10 @@ class AmazonConverseConfig(BaseConfig):
|
|||
cache_read_input_tokens = usage["cacheReadInputTokens"]
|
||||
input_tokens += cache_read_input_tokens
|
||||
if "cacheWriteInputTokens" in usage:
|
||||
"""
|
||||
Do not increment prompt_tokens with cacheWriteInputTokens
|
||||
"""
|
||||
cache_creation_input_tokens = usage["cacheWriteInputTokens"]
|
||||
input_tokens += cache_creation_input_tokens
|
||||
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
|
@ -811,6 +833,7 @@ class AmazonConverseConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.cohere.chat.transformation import CohereChatConfig
|
||||
|
||||
|
||||
class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
|
||||
class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
|
||||
|
||||
|
@ -19,7 +19,6 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
|
|||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
return_likelihood: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
|
@ -55,11 +54,10 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
|
|||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"stream",
|
||||
]
|
||||
supported_params = CohereChatConfig.get_supported_openai_params(
|
||||
self, model=model
|
||||
)
|
||||
return supported_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
|
@ -68,11 +66,10 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
|
|||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for k, v in non_default_params.items():
|
||||
if k == "stream":
|
||||
optional_params["stream"] = v
|
||||
if k == "temperature":
|
||||
optional_params["temperature"] = v
|
||||
if k == "max_tokens":
|
||||
optional_params["max_tokens"] = v
|
||||
return optional_params
|
||||
return CohereChatConfig.map_openai_params(
|
||||
self,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
|
|
@ -6,14 +6,21 @@ Inherits from `AmazonConverseConfig`
|
|||
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ..converse_transformation import AmazonConverseConfig
|
||||
from .base_invoke_transformation import AmazonInvokeConfig
|
||||
|
||||
|
||||
class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
|
||||
class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
|
||||
"""
|
||||
Config for sending `nova` requests to `/bedrock/invoke/`
|
||||
"""
|
||||
|
@ -21,6 +28,20 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
|
|||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
return AmazonConverseConfig.get_supported_openai_params(self, model)
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return AmazonConverseConfig.map_openai_params(
|
||||
self, non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -29,7 +50,8 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
|
|||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
_transformed_nova_request = super().transform_request(
|
||||
_transformed_nova_request = AmazonConverseConfig.transform_request(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
|
@ -45,6 +67,35 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
|
|||
)
|
||||
return bedrock_invoke_nova_request
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: Logging,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> litellm.ModelResponse:
|
||||
return AmazonConverseConfig.transform_response(
|
||||
self,
|
||||
model,
|
||||
raw_response,
|
||||
model_response,
|
||||
logging_obj,
|
||||
request_data,
|
||||
messages,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
encoding,
|
||||
api_key,
|
||||
json_mode,
|
||||
)
|
||||
|
||||
def _filter_allowed_fields(
|
||||
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
|
||||
) -> dict:
|
||||
|
|
|
@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -218,6 +218,10 @@ class BaseLLMAIOHTTPHandler:
|
|||
provider_config = ProviderConfigManager.get_provider_chat_config(
|
||||
model=model, provider=litellm.LlmProviders(custom_llm_provider)
|
||||
)
|
||||
if provider_config is None:
|
||||
raise ValueError(
|
||||
f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
|
||||
)
|
||||
# get config from model, custom llm provider
|
||||
headers = provider_config.validate_environment(
|
||||
api_key=api_key,
|
||||
|
@ -225,6 +229,7 @@ class BaseLLMAIOHTTPHandler:
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
|
@ -494,6 +499,7 @@ class BaseLLMAIOHTTPHandler:
|
|||
model=model,
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
|
|
|
@ -192,7 +192,7 @@ class AsyncHTTPHandler:
|
|||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
data: Optional[Union[dict, str]] = None, # type: ignore
|
||||
data: Optional[Union[dict, str, bytes]] = None, # type: ignore
|
||||
json: Optional[dict] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
|
@ -427,7 +427,7 @@ class AsyncHTTPHandler:
|
|||
self,
|
||||
url: str,
|
||||
client: httpx.AsyncClient,
|
||||
data: Optional[Union[dict, str]] = None, # type: ignore
|
||||
data: Optional[Union[dict, str, bytes]] = None, # type: ignore
|
||||
json: Optional[dict] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
|
@ -527,7 +527,7 @@ class HTTPHandler:
|
|||
def post(
|
||||
self,
|
||||
url: str,
|
||||
data: Optional[Union[dict, str]] = None,
|
||||
data: Optional[Union[dict, str, bytes]] = None,
|
||||
json: Optional[Union[dict, str, List]] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
|
@ -573,7 +573,6 @@ class HTTPHandler:
|
|||
setattr(e, "text", error_text)
|
||||
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -234,6 +234,10 @@ class BaseLLMHTTPHandler:
|
|||
provider_config = ProviderConfigManager.get_provider_chat_config(
|
||||
model=model, provider=litellm.LlmProviders(custom_llm_provider)
|
||||
)
|
||||
if provider_config is None:
|
||||
raise ValueError(
|
||||
f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
|
||||
)
|
||||
|
||||
# get config from model, custom llm provider
|
||||
headers = provider_config.validate_environment(
|
||||
|
@ -243,6 +247,7 @@ class BaseLLMHTTPHandler:
|
|||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
api_base = provider_config.get_complete_url(
|
||||
|
@ -621,6 +626,7 @@ class BaseLLMHTTPHandler:
|
|||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
api_base = provider_config.get_complete_url(
|
||||
|
@ -892,6 +898,7 @@ class BaseLLMHTTPHandler:
|
|||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
|
@ -1224,15 +1231,19 @@ class BaseLLMHTTPHandler:
|
|||
model="",
|
||||
messages=[],
|
||||
optional_params={},
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
api_base = provider_config.get_complete_url(
|
||||
api_base = provider_config.get_complete_file_url(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model="",
|
||||
optional_params={},
|
||||
litellm_params=litellm_params,
|
||||
data=create_file_data,
|
||||
)
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for create_file")
|
||||
|
||||
# Get the transformed request data for both steps
|
||||
transformed_request = provider_config.transform_create_file_request(
|
||||
|
@ -1259,48 +1270,57 @@ class BaseLLMHTTPHandler:
|
|||
else:
|
||||
sync_httpx_client = client
|
||||
|
||||
try:
|
||||
# Step 1: Initial request to get upload URL
|
||||
initial_response = sync_httpx_client.post(
|
||||
url=api_base,
|
||||
headers={
|
||||
**headers,
|
||||
**transformed_request["initial_request"]["headers"],
|
||||
},
|
||||
data=json.dumps(transformed_request["initial_request"]["data"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract upload URL from response headers
|
||||
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
|
||||
|
||||
if not upload_url:
|
||||
raise ValueError("Failed to get upload URL from initial request")
|
||||
|
||||
# Step 2: Upload the actual file
|
||||
if isinstance(transformed_request, str) or isinstance(
|
||||
transformed_request, bytes
|
||||
):
|
||||
upload_response = sync_httpx_client.post(
|
||||
url=upload_url,
|
||||
headers=transformed_request["upload_request"]["headers"],
|
||||
data=transformed_request["upload_request"]["data"],
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=transformed_request,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Step 1: Initial request to get upload URL
|
||||
initial_response = sync_httpx_client.post(
|
||||
url=api_base,
|
||||
headers={
|
||||
**headers,
|
||||
**transformed_request["initial_request"]["headers"],
|
||||
},
|
||||
data=json.dumps(transformed_request["initial_request"]["data"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return provider_config.transform_create_file_response(
|
||||
model=None,
|
||||
raw_response=upload_response,
|
||||
logging_obj=logging_obj,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
# Extract upload URL from response headers
|
||||
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
|
||||
|
||||
except Exception as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
if not upload_url:
|
||||
raise ValueError("Failed to get upload URL from initial request")
|
||||
|
||||
# Step 2: Upload the actual file
|
||||
upload_response = sync_httpx_client.post(
|
||||
url=upload_url,
|
||||
headers=transformed_request["upload_request"]["headers"],
|
||||
data=transformed_request["upload_request"]["data"],
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
return provider_config.transform_create_file_response(
|
||||
model=None,
|
||||
raw_response=upload_response,
|
||||
logging_obj=logging_obj,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
async def async_create_file(
|
||||
self,
|
||||
transformed_request: dict,
|
||||
transformed_request: Union[bytes, str, dict],
|
||||
litellm_params: dict,
|
||||
provider_config: BaseFilesConfig,
|
||||
headers: dict,
|
||||
|
@ -1319,45 +1339,54 @@ class BaseLLMHTTPHandler:
|
|||
else:
|
||||
async_httpx_client = client
|
||||
|
||||
try:
|
||||
# Step 1: Initial request to get upload URL
|
||||
initial_response = await async_httpx_client.post(
|
||||
url=api_base,
|
||||
headers={
|
||||
**headers,
|
||||
**transformed_request["initial_request"]["headers"],
|
||||
},
|
||||
data=json.dumps(transformed_request["initial_request"]["data"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract upload URL from response headers
|
||||
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
|
||||
|
||||
if not upload_url:
|
||||
raise ValueError("Failed to get upload URL from initial request")
|
||||
|
||||
# Step 2: Upload the actual file
|
||||
if isinstance(transformed_request, str) or isinstance(
|
||||
transformed_request, bytes
|
||||
):
|
||||
upload_response = await async_httpx_client.post(
|
||||
url=upload_url,
|
||||
headers=transformed_request["upload_request"]["headers"],
|
||||
data=transformed_request["upload_request"]["data"],
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=transformed_request,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Step 1: Initial request to get upload URL
|
||||
initial_response = await async_httpx_client.post(
|
||||
url=api_base,
|
||||
headers={
|
||||
**headers,
|
||||
**transformed_request["initial_request"]["headers"],
|
||||
},
|
||||
data=json.dumps(transformed_request["initial_request"]["data"]),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return provider_config.transform_create_file_response(
|
||||
model=None,
|
||||
raw_response=upload_response,
|
||||
logging_obj=logging_obj,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
# Extract upload URL from response headers
|
||||
upload_url = initial_response.headers.get("X-Goog-Upload-URL")
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error creating file: {e}")
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
if not upload_url:
|
||||
raise ValueError("Failed to get upload URL from initial request")
|
||||
|
||||
# Step 2: Upload the actual file
|
||||
upload_response = await async_httpx_client.post(
|
||||
url=upload_url,
|
||||
headers=transformed_request["upload_request"]["headers"],
|
||||
data=transformed_request["upload_request"]["data"],
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error creating file: {e}")
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
return provider_config.transform_create_file_response(
|
||||
model=None,
|
||||
raw_response=upload_response,
|
||||
logging_obj=logging_obj,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
def list_files(self):
|
||||
"""
|
||||
|
|
|
@ -27,7 +27,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
|||
strip_name_from_messages,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.types.llms.anthropic import AnthropicMessagesTool
|
||||
from litellm.types.llms.anthropic import AllAnthropicToolsValues
|
||||
from litellm.types.llms.databricks import (
|
||||
AllDatabricksContentValues,
|
||||
DatabricksChoice,
|
||||
|
@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
@ -160,7 +161,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
]
|
||||
|
||||
def convert_anthropic_tool_to_databricks_tool(
|
||||
self, tool: Optional[AnthropicMessagesTool]
|
||||
self, tool: Optional[AllAnthropicToolsValues]
|
||||
) -> Optional[DatabricksTool]:
|
||||
if tool is None:
|
||||
return None
|
||||
|
@ -173,6 +174,19 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
),
|
||||
)
|
||||
|
||||
def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]:
|
||||
# if not claude, send as is
|
||||
if "claude" not in model:
|
||||
return tools
|
||||
|
||||
# if claude, convert to anthropic tool and then to databricks tool
|
||||
anthropic_tools = self._map_tools(tools=tools)
|
||||
databricks_tools = [
|
||||
cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool))
|
||||
for tool in anthropic_tools
|
||||
]
|
||||
return databricks_tools
|
||||
|
||||
def map_response_format_to_databricks_tool(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -202,6 +216,10 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
mapped_params = super().map_openai_params(
|
||||
non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
if "tools" in mapped_params:
|
||||
mapped_params["tools"] = self._map_openai_to_dbrx_tool(
|
||||
model=model, tools=mapped_params["tools"]
|
||||
)
|
||||
if (
|
||||
"max_completion_tokens" in non_default_params
|
||||
and replace_max_completion_tokens_with_max_tokens
|
||||
|
@ -240,6 +258,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
|
|||
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
|
||||
non_default_params.get("reasoning_effort")
|
||||
)
|
||||
optional_params.pop("reasoning_effort", None)
|
||||
## handle thinking tokens
|
||||
self.update_optional_params_with_thinking_tokens(
|
||||
non_default_params=non_default_params, optional_params=mapped_params
|
||||
|
@ -498,7 +517,10 @@ class DatabricksChatResponseIterator(BaseModelResponseIterator):
|
|||
message.content = ""
|
||||
choice["delta"]["content"] = message.content
|
||||
choice["delta"]["tool_calls"] = None
|
||||
|
||||
elif tool_calls:
|
||||
for _tc in tool_calls:
|
||||
if _tc.get("function", {}).get("arguments") == "{}":
|
||||
_tc["function"]["arguments"] = "" # avoid invalid json
|
||||
# extract the content str
|
||||
content_str = DatabricksConfig.extract_content_str(
|
||||
choice["delta"].get("content")
|
||||
|
|
|
@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -2,7 +2,11 @@ from typing import List, Literal, Optional, Tuple, Union, cast
|
|||
|
||||
import litellm
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionImageObject,
|
||||
OpenAIChatCompletionToolParam,
|
||||
)
|
||||
from litellm.types.utils import ProviderSpecificModelInfo
|
||||
|
||||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
@ -150,6 +154,14 @@ class FireworksAIConfig(OpenAIGPTConfig):
|
|||
] = f"{content['image_url']['url']}#transform=inline"
|
||||
return content
|
||||
|
||||
def _transform_tools(
|
||||
self, tools: List[OpenAIChatCompletionToolParam]
|
||||
) -> List[OpenAIChatCompletionToolParam]:
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
tool["function"].pop("strict", None)
|
||||
return tools
|
||||
|
||||
def _transform_messages_helper(
|
||||
self, messages: List[AllMessageValues], model: str, litellm_params: dict
|
||||
) -> List[AllMessageValues]:
|
||||
|
@ -196,6 +208,9 @@ class FireworksAIConfig(OpenAIGPTConfig):
|
|||
messages = self._transform_messages_helper(
|
||||
messages=messages, model=model, litellm_params=litellm_params
|
||||
)
|
||||
if "tools" in optional_params and optional_params["tools"] is not None:
|
||||
tools = self._transform_tools(tools=optional_params["tools"])
|
||||
optional_params["tools"] = tools
|
||||
return super().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
|
@ -41,6 +41,7 @@ class FireworksAIMixin:
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API.
|
|||
For vertex ai, check out the vertex_ai/files/handler.py file.
|
||||
"""
|
||||
import time
|
||||
from typing import List, Mapping, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||
from litellm.llms.base_llm.files.transformation import (
|
||||
BaseFilesConfig,
|
||||
LiteLLMLoggingObj,
|
||||
|
@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
|
|||
if file_data is None:
|
||||
raise ValueError("File data is required")
|
||||
|
||||
# Parse the file_data based on its type
|
||||
filename = None
|
||||
file_content = None
|
||||
content_type = None
|
||||
file_headers: Mapping[str, str] = {}
|
||||
|
||||
if isinstance(file_data, tuple):
|
||||
if len(file_data) == 2:
|
||||
filename, file_content = file_data
|
||||
elif len(file_data) == 3:
|
||||
filename, file_content, content_type = file_data
|
||||
elif len(file_data) == 4:
|
||||
filename, file_content, content_type, file_headers = file_data
|
||||
else:
|
||||
file_content = file_data
|
||||
|
||||
# Handle the file content based on its type
|
||||
import io
|
||||
from os import PathLike
|
||||
|
||||
# Convert content to bytes
|
||||
if isinstance(file_content, (str, PathLike)):
|
||||
# If it's a path, open and read the file
|
||||
with open(file_content, "rb") as f:
|
||||
content = f.read()
|
||||
elif isinstance(file_content, io.IOBase):
|
||||
# If it's a file-like object
|
||||
content = file_content.read()
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
elif isinstance(file_content, bytes):
|
||||
content = file_content
|
||||
else:
|
||||
raise ValueError(f"Unsupported file content type: {type(file_content)}")
|
||||
# Use the common utility function to extract file data
|
||||
extracted_data = extract_file_data(file_data)
|
||||
|
||||
# Get file size
|
||||
file_size = len(content)
|
||||
|
||||
# Use provided content type or guess based on filename
|
||||
if not content_type:
|
||||
import mimetypes
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(filename)[0]
|
||||
if filename
|
||||
else "application/octet-stream"
|
||||
)
|
||||
file_size = len(extracted_data["content"])
|
||||
|
||||
# Step 1: Initial resumable upload request
|
||||
headers = {
|
||||
"X-Goog-Upload-Protocol": "resumable",
|
||||
"X-Goog-Upload-Command": "start",
|
||||
"X-Goog-Upload-Header-Content-Length": str(file_size),
|
||||
"X-Goog-Upload-Header-Content-Type": content_type,
|
||||
"X-Goog-Upload-Header-Content-Type": extracted_data["content_type"],
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
headers.update(file_headers) # Add any custom headers
|
||||
headers.update(extracted_data["headers"]) # Add any custom headers
|
||||
|
||||
# Initial metadata request body
|
||||
initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
|
||||
initial_data = {
|
||||
"file": {
|
||||
"display_name": extracted_data["filename"] or str(int(time.time()))
|
||||
}
|
||||
}
|
||||
|
||||
# Step 2: Actual file upload data
|
||||
upload_headers = {
|
||||
|
@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
|
|||
|
||||
return {
|
||||
"initial_request": {"headers": headers, "data": initial_data},
|
||||
"upload_request": {"headers": upload_headers, "data": content},
|
||||
"upload_request": {
|
||||
"headers": upload_headers,
|
||||
"data": extracted_data["content"],
|
||||
},
|
||||
}
|
||||
|
||||
def transform_create_file_response(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
|||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BASE_URL = "https://router.huggingface.co"
|
||||
|
@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
|
|||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
|
|||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
|
||||
return HuggingFaceError(
|
||||
status_code=status_code, message=error_message, headers=headers
|
||||
)
|
||||
|
||||
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
|
@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
|
|||
if api_base is not None:
|
||||
complete_url = api_base
|
||||
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
|
||||
complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
|
||||
complete_url = str(os.getenv("HF_API_BASE")) or str(
|
||||
os.getenv("HUGGINGFACE_API_BASE")
|
||||
)
|
||||
elif model.startswith(("http://", "https://")):
|
||||
complete_url = model
|
||||
# 4. Default construction with provider
|
||||
|
@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
|
|||
)
|
||||
mapped_model = provider_mapping["providerId"]
|
||||
messages = self._transform_messages(messages=messages, model=mapped_model)
|
||||
return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
|
||||
return dict(
|
||||
ChatCompletionRequest(
|
||||
model=mapped_model, messages=messages, **optional_params
|
||||
)
|
||||
)
|
||||
|
|
|
@ -1,15 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
get_args,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug
|
|||
]
|
||||
|
||||
|
||||
|
||||
def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
|
||||
def get_hf_task_embedding_for_model(
|
||||
model: str, task_type: Optional[str], api_base: str
|
||||
) -> Optional[str]:
|
||||
if task_type is not None:
|
||||
if task_type in get_args(hf_tasks_embeddings):
|
||||
return task_type
|
||||
|
@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
|
|||
return pipeline_tag
|
||||
|
||||
|
||||
async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
|
||||
async def async_get_hf_task_embedding_for_model(
|
||||
model: str, task_type: Optional[str], api_base: str
|
||||
) -> Optional[str]:
|
||||
if task_type is not None:
|
||||
if task_type in get_args(hf_tasks_embeddings):
|
||||
return task_type
|
||||
|
@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
input: List,
|
||||
optional_params: dict,
|
||||
) -> dict:
|
||||
hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
hf_task = await async_get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
|
||||
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
|
||||
|
||||
|
@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
task_type = optional_params.pop("input_type", None)
|
||||
|
||||
if call_type == "sync":
|
||||
hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
hf_task = get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
elif call_type == "async":
|
||||
return self._async_transform_input(
|
||||
model=model, task_type=task_type, embed_url=embed_url, input=input
|
||||
|
@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
input: list,
|
||||
model_response: EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
encoding: Callable,
|
||||
api_key: Optional[str] = None,
|
||||
|
@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
model=model,
|
||||
optional_params=optional_params,
|
||||
messages=[],
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
task_type = optional_params.pop("input_type", None)
|
||||
task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
|
||||
task = get_hf_task_embedding_for_model(
|
||||
model=model, task_type=task_type, api_base=HF_HUB_URL
|
||||
)
|
||||
# print_verbose(f"{model}, {task}")
|
||||
embed_url = ""
|
||||
if "https" in model:
|
||||
|
@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
|
|||
elif "HUGGINGFACE_API_BASE" in os.environ:
|
||||
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
|
||||
else:
|
||||
embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
|
||||
embed_url = (
|
||||
f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
|
||||
)
|
||||
|
||||
## ROUTING ##
|
||||
if aembedding is True:
|
||||
|
|
|
@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
|
|
@ -10,6 +10,11 @@ from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
|||
|
||||
|
||||
class LiteLLMProxyChatConfig(OpenAIGPTConfig):
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
list = super().get_supported_openai_params(model)
|
||||
list.append("thinking")
|
||||
return list
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
|
|
@ -36,6 +36,7 @@ def completion(
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## Load Config
|
||||
|
|
|
@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -32,6 +32,7 @@ def completion(
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if "https" in model:
|
||||
completion_url = model
|
||||
|
@ -123,6 +124,7 @@ def embedding(
|
|||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params={},
|
||||
)
|
||||
response = litellm.module_level_client.post(
|
||||
embeddings_url, headers=headers, json=data
|
||||
|
|
|
@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -228,10 +228,10 @@ class PredibaseChatCompletion:
|
|||
api_key: str,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
tenant_id: str,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: dict = {},
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
|
@ -241,6 +241,7 @@ class PredibaseChatCompletion:
|
|||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
completion_url = ""
|
||||
input_text = ""
|
||||
|
|
|
@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -141,6 +141,7 @@ def completion(
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
# Start a prediction and get the prediction URL
|
||||
version_id = replicate_config.model_to_version_id(model)
|
||||
|
|
|
@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
model: str,
|
||||
data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
litellm_params: dict,
|
||||
optional_params: dict,
|
||||
aws_region_name: str,
|
||||
extra_headers: Optional[dict] = None,
|
||||
|
@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
request = AWSRequest(
|
||||
method="POST", url=api_base, data=encoded_data, headers=headers
|
||||
|
@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
credentials=credentials,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
|
@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
"model": model,
|
||||
"data": _data,
|
||||
"optional_params": optional_params,
|
||||
"litellm_params": litellm_params,
|
||||
"credentials": credentials,
|
||||
"aws_region_name": aws_region_name,
|
||||
"messages": messages,
|
||||
|
@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
"model": model,
|
||||
"data": data,
|
||||
"optional_params": optional_params,
|
||||
"litellm_params": litellm_params,
|
||||
"credentials": credentials,
|
||||
"aws_region_name": aws_region_name,
|
||||
"messages": messages,
|
||||
|
@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
"model": model,
|
||||
"data": data,
|
||||
"optional_params": optional_params,
|
||||
"litellm_params": litellm_params,
|
||||
"credentials": credentials,
|
||||
"aws_region_name": aws_region_name,
|
||||
"messages": messages,
|
||||
|
|
|
@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
|
|
@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
|
||||
import re
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.vertex_ai import PartType
|
||||
from litellm.types.llms.vertex_ai import PartType, Schema
|
||||
|
||||
|
||||
class VertexAIError(BaseLLMException):
|
||||
|
@ -168,6 +169,9 @@ def _build_vertex_schema(parameters: dict):
|
|||
"""
|
||||
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
|
||||
"""
|
||||
# Get valid fields from Schema TypedDict
|
||||
valid_schema_fields = set(get_type_hints(Schema).keys())
|
||||
|
||||
defs = parameters.pop("$defs", {})
|
||||
# flatten the defs
|
||||
for name, value in defs.items():
|
||||
|
@ -181,52 +185,49 @@ def _build_vertex_schema(parameters: dict):
|
|||
convert_anyof_null_to_nullable(parameters)
|
||||
add_object_type(parameters)
|
||||
# Postprocessing
|
||||
# 4. Suppress unnecessary title generation:
|
||||
# * https://github.com/pydantic/pydantic/issues/1051
|
||||
# * http://cl/586221780
|
||||
strip_field(parameters, field_name="title")
|
||||
|
||||
strip_field(
|
||||
parameters, field_name="$schema"
|
||||
) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
|
||||
strip_field(
|
||||
parameters, field_name="$id"
|
||||
) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors.
|
||||
|
||||
return parameters
|
||||
# Filter out fields that don't exist in Schema
|
||||
filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
|
||||
return filtered_parameters
|
||||
|
||||
|
||||
def unpack_defs(schema, defs):
|
||||
properties = schema.get("properties", None)
|
||||
if properties is None:
|
||||
return
|
||||
def filter_schema_fields(
|
||||
schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Recursively filter a schema dictionary to keep only valid fields.
|
||||
"""
|
||||
if processed is None:
|
||||
processed = set()
|
||||
|
||||
for name, value in properties.items():
|
||||
ref_key = value.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
properties[name] = ref
|
||||
# Handle circular references
|
||||
schema_id = id(schema_dict)
|
||||
if schema_id in processed:
|
||||
return schema_dict
|
||||
processed.add(schema_id)
|
||||
|
||||
if not isinstance(schema_dict, dict):
|
||||
return schema_dict
|
||||
|
||||
result = {}
|
||||
for key, value in schema_dict.items():
|
||||
if key not in valid_fields:
|
||||
continue
|
||||
|
||||
anyof = value.get("anyOf", None)
|
||||
if anyof is not None:
|
||||
for i, atype in enumerate(anyof):
|
||||
ref_key = atype.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
anyof[i] = ref
|
||||
continue
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
result[key] = {
|
||||
k: filter_schema_fields(v, valid_fields, processed)
|
||||
for k, v in value.items()
|
||||
}
|
||||
elif key == "items" and isinstance(value, dict):
|
||||
result[key] = filter_schema_fields(value, valid_fields, processed)
|
||||
elif key == "anyOf" and isinstance(value, list):
|
||||
result[key] = [
|
||||
filter_schema_fields(item, valid_fields, processed) for item in value # type: ignore
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
items = value.get("items", None)
|
||||
if items is not None:
|
||||
ref_key = items.get("$ref", None)
|
||||
if ref_key is not None:
|
||||
ref = defs[ref_key.split("defs/")[-1]]
|
||||
unpack_defs(ref, defs)
|
||||
value["items"] = ref
|
||||
continue
|
||||
return result
|
||||
|
||||
|
||||
def convert_anyof_null_to_nullable(schema, depth=0):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
from typing import Any, Coroutine, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
|||
from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
|
||||
from .transformation import VertexAIFilesTransformation
|
||||
from .transformation import VertexAIJsonlFilesTransformation
|
||||
|
||||
vertex_ai_files_transformation = VertexAIFilesTransformation()
|
||||
vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
|
||||
|
||||
|
||||
class VertexAIFilesHandler(GCSBucketBase):
|
||||
|
@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase):
|
|||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
return None # type: ignore
|
||||
else:
|
||||
return asyncio.run(
|
||||
self.async_create_file(
|
||||
create_file_data=create_file_data,
|
||||
api_base=api_base,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -1,7 +1,17 @@
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.files.transformation import (
|
||||
BaseFilesConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.llms.vertex_ai.common_utils import (
|
||||
_convert_vertex_datetime_to_openai_datetime,
|
||||
)
|
||||
|
@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
|||
VertexGeminiConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateFileRequest,
|
||||
FileTypes,
|
||||
OpenAICreateFileRequestOptionalParams,
|
||||
OpenAIFileObject,
|
||||
PathLike,
|
||||
)
|
||||
from litellm.types.llms.vertex_ai import GcsBucketResponse
|
||||
from litellm.types.utils import ExtractedFileData, LlmProviders
|
||||
|
||||
from ..common_utils import VertexAIError
|
||||
from ..vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
class VertexAIFilesTransformation(VertexGeminiConfig):
|
||||
class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
|
||||
"""
|
||||
Config for VertexAI Files
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.jsonl_transformation = VertexAIJsonlFilesTransformation()
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> LlmProviders:
|
||||
return LlmProviders.VERTEX_AI
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if not api_key:
|
||||
api_key, _ = self.get_access_token(
|
||||
credentials=litellm_params.get("vertex_credentials"),
|
||||
project_id=litellm_params.get("vertex_project"),
|
||||
)
|
||||
if not api_key:
|
||||
raise ValueError("api_key is required")
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
|
||||
"""
|
||||
Helper to extract content from various OpenAI file types and return as string.
|
||||
|
||||
Handles:
|
||||
- Direct content (str, bytes, IO[bytes])
|
||||
- Tuple formats: (filename, content, [content_type], [headers])
|
||||
- PathLike objects
|
||||
"""
|
||||
content: Union[str, bytes] = b""
|
||||
# Extract file content from tuple if necessary
|
||||
if isinstance(openai_file_content, tuple):
|
||||
# Take the second element which is always the file content
|
||||
file_content = openai_file_content[1]
|
||||
else:
|
||||
file_content = openai_file_content
|
||||
|
||||
# Handle different file content types
|
||||
if isinstance(file_content, str):
|
||||
# String content can be used directly
|
||||
content = file_content
|
||||
elif isinstance(file_content, bytes):
|
||||
# Bytes content can be decoded
|
||||
content = file_content
|
||||
elif isinstance(file_content, PathLike): # PathLike
|
||||
with open(str(file_content), "rb") as f:
|
||||
content = f.read()
|
||||
elif hasattr(file_content, "read"): # IO[bytes]
|
||||
# File-like objects need to be read
|
||||
content = file_content.read()
|
||||
|
||||
# Ensure content is string
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8")
|
||||
|
||||
return content
|
||||
|
||||
def _get_gcs_object_name_from_batch_jsonl(
|
||||
self,
|
||||
openai_jsonl_content: List[Dict[str, Any]],
|
||||
) -> str:
|
||||
"""
|
||||
Gets a unique GCS object name for the VertexAI batch prediction job
|
||||
|
||||
named as: litellm-vertex-{model}-{uuid}
|
||||
"""
|
||||
_model = openai_jsonl_content[0].get("body", {}).get("model", "")
|
||||
if "publishers/google/models" not in _model:
|
||||
_model = f"publishers/google/models/{_model}"
|
||||
object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
|
||||
return object_name
|
||||
|
||||
def get_object_name(
|
||||
self, extracted_file_data: ExtractedFileData, purpose: str
|
||||
) -> str:
|
||||
"""
|
||||
Get the object name for the request
|
||||
"""
|
||||
extracted_file_data_content = extracted_file_data.get("content")
|
||||
|
||||
if extracted_file_data_content is None:
|
||||
raise ValueError("file content is required")
|
||||
|
||||
if purpose == "batch":
|
||||
## 1. If jsonl, check if there's a model name
|
||||
file_content = self._get_content_from_openai_file(
|
||||
extracted_file_data_content
|
||||
)
|
||||
|
||||
# Split into lines and parse each line as JSON
|
||||
openai_jsonl_content = [
|
||||
json.loads(line) for line in file_content.splitlines() if line.strip()
|
||||
]
|
||||
if len(openai_jsonl_content) > 0:
|
||||
return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)
|
||||
|
||||
## 2. If not jsonl, return the filename
|
||||
filename = extracted_file_data.get("filename")
|
||||
if filename:
|
||||
return filename
|
||||
## 3. If no file name, return timestamp
|
||||
return str(int(time.time()))
|
||||
|
||||
def get_complete_file_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: Dict,
|
||||
litellm_params: Dict,
|
||||
data: CreateFileRequest,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete url for the request
|
||||
"""
|
||||
bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME")
|
||||
if not bucket_name:
|
||||
raise ValueError("GCS bucket_name is required")
|
||||
file_data = data.get("file")
|
||||
purpose = data.get("purpose")
|
||||
if file_data is None:
|
||||
raise ValueError("file is required")
|
||||
if purpose is None:
|
||||
raise ValueError("purpose is required")
|
||||
extracted_file_data = extract_file_data(file_data)
|
||||
object_name = self.get_object_name(extracted_file_data, purpose)
|
||||
endpoint = (
|
||||
f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
|
||||
)
|
||||
api_base = api_base or "https://storage.googleapis.com"
|
||||
if not api_base:
|
||||
raise ValueError("api_base is required")
|
||||
|
||||
return f"{api_base}/{endpoint}"
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAICreateFileRequestOptionalParams]:
|
||||
return []
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return optional_params
|
||||
|
||||
def _map_openai_to_vertex_params(
|
||||
self,
|
||||
openai_request_body: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
wrapper to call VertexGeminiConfig.map_openai_params
|
||||
"""
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexGeminiConfig,
|
||||
)
|
||||
|
||||
config = VertexGeminiConfig()
|
||||
_model = openai_request_body.get("model", "")
|
||||
vertex_params = config.map_openai_params(
|
||||
model=_model,
|
||||
non_default_params=openai_request_body,
|
||||
optional_params={},
|
||||
drop_params=False,
|
||||
)
|
||||
return vertex_params
|
||||
|
||||
def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
|
||||
self, openai_jsonl_content: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Transforms OpenAI JSONL content to VertexAI JSONL content
|
||||
|
||||
jsonl body for vertex is {"request": <request_body>}
|
||||
Example Vertex jsonl
|
||||
{"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
|
||||
{"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
|
||||
"""
|
||||
|
||||
vertex_jsonl_content = []
|
||||
for _openai_jsonl_content in openai_jsonl_content:
|
||||
openai_request_body = _openai_jsonl_content.get("body") or {}
|
||||
vertex_request_body = _transform_request_body(
|
||||
messages=openai_request_body.get("messages", []),
|
||||
model=openai_request_body.get("model", ""),
|
||||
optional_params=self._map_openai_to_vertex_params(openai_request_body),
|
||||
custom_llm_provider="vertex_ai",
|
||||
litellm_params={},
|
||||
cached_content=None,
|
||||
)
|
||||
vertex_jsonl_content.append({"request": vertex_request_body})
|
||||
return vertex_jsonl_content
|
||||
|
||||
def transform_create_file_request(
|
||||
self,
|
||||
model: str,
|
||||
create_file_data: CreateFileRequest,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> Union[bytes, str, dict]:
|
||||
"""
|
||||
2 Cases:
|
||||
1. Handle basic file upload
|
||||
2. Handle batch file upload (.jsonl)
|
||||
"""
|
||||
file_data = create_file_data.get("file")
|
||||
if file_data is None:
|
||||
raise ValueError("file is required")
|
||||
extracted_file_data = extract_file_data(file_data)
|
||||
extracted_file_data_content = extracted_file_data.get("content")
|
||||
if (
|
||||
create_file_data.get("purpose") == "batch"
|
||||
and extracted_file_data.get("content_type") == "application/jsonl"
|
||||
and extracted_file_data_content is not None
|
||||
):
|
||||
## 1. If jsonl, check if there's a model name
|
||||
file_content = self._get_content_from_openai_file(
|
||||
extracted_file_data_content
|
||||
)
|
||||
|
||||
# Split into lines and parse each line as JSON
|
||||
openai_jsonl_content = [
|
||||
json.loads(line) for line in file_content.splitlines() if line.strip()
|
||||
]
|
||||
vertex_jsonl_content = (
|
||||
self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
|
||||
openai_jsonl_content
|
||||
)
|
||||
)
|
||||
return json.dumps(vertex_jsonl_content)
|
||||
elif isinstance(extracted_file_data_content, bytes):
|
||||
return extracted_file_data_content
|
||||
else:
|
||||
raise ValueError("Unsupported file content type")
|
||||
|
||||
def transform_create_file_response(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Transform VertexAI File upload response into OpenAI-style FileObject
|
||||
"""
|
||||
response_json = raw_response.json()
|
||||
|
||||
try:
|
||||
response_object = GcsBucketResponse(**response_json) # type: ignore
|
||||
except Exception as e:
|
||||
raise VertexAIError(
|
||||
status_code=raw_response.status_code,
|
||||
message=f"Error reading GCS bucket response: {e}",
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
gcs_id = response_object.get("id", "")
|
||||
# Remove the last numeric ID from the path
|
||||
gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
|
||||
|
||||
return OpenAIFileObject(
|
||||
purpose=response_object.get("purpose", "batch"),
|
||||
id=f"gs://{gcs_id}",
|
||||
filename=response_object.get("name", ""),
|
||||
created_at=_convert_vertex_datetime_to_openai_datetime(
|
||||
vertex_datetime=response_object.get("timeCreated", "")
|
||||
),
|
||||
status="uploaded",
|
||||
bytes=int(response_object.get("size", 0)),
|
||||
object="file",
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
|
||||
) -> BaseLLMException:
|
||||
return VertexAIError(
|
||||
status_code=status_code, message=error_message, headers=headers
|
||||
)
|
||||
|
||||
|
||||
class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
|
||||
"""
|
||||
Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
|
||||
"""
|
||||
|
|
|
@ -208,25 +208,24 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
|
|||
elif element["type"] == "input_audio":
|
||||
audio_element = cast(ChatCompletionAudioObject, element)
|
||||
if audio_element["input_audio"].get("data") is not None:
|
||||
_part = PartType(
|
||||
inline_data=BlobType(
|
||||
data=audio_element["input_audio"]["data"],
|
||||
mime_type="audio/{}".format(
|
||||
audio_element["input_audio"]["format"]
|
||||
),
|
||||
)
|
||||
_part = _process_gemini_image(
|
||||
image_url=audio_element["input_audio"]["data"],
|
||||
format=audio_element["input_audio"].get("format"),
|
||||
)
|
||||
_parts.append(_part)
|
||||
elif element["type"] == "file":
|
||||
file_element = cast(ChatCompletionFileObject, element)
|
||||
file_id = file_element["file"].get("file_id")
|
||||
format = file_element["file"].get("format")
|
||||
|
||||
if not file_id:
|
||||
continue
|
||||
file_data = file_element["file"].get("file_data")
|
||||
passed_file = file_id or file_data
|
||||
if passed_file is None:
|
||||
raise Exception(
|
||||
"Unknown file type. Please pass in a file_id or file_data"
|
||||
)
|
||||
try:
|
||||
_part = _process_gemini_image(
|
||||
image_url=file_id, format=format
|
||||
image_url=passed_file, format=format
|
||||
)
|
||||
_parts.append(_part)
|
||||
except Exception:
|
||||
|
|
|
@ -240,6 +240,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
gtool_func_declarations = []
|
||||
googleSearch: Optional[dict] = None
|
||||
googleSearchRetrieval: Optional[dict] = None
|
||||
enterpriseWebSearch: Optional[dict] = None
|
||||
code_execution: Optional[dict] = None
|
||||
# remove 'additionalProperties' from tools
|
||||
value = _remove_additional_properties(value)
|
||||
|
@ -273,6 +274,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
googleSearch = tool["googleSearch"]
|
||||
elif tool.get("googleSearchRetrieval", None) is not None:
|
||||
googleSearchRetrieval = tool["googleSearchRetrieval"]
|
||||
elif tool.get("enterpriseWebSearch", None) is not None:
|
||||
enterpriseWebSearch = tool["enterpriseWebSearch"]
|
||||
elif tool.get("code_execution", None) is not None:
|
||||
code_execution = tool["code_execution"]
|
||||
elif openai_function_object is not None:
|
||||
|
@ -299,6 +302,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
_tools["googleSearch"] = googleSearch
|
||||
if googleSearchRetrieval is not None:
|
||||
_tools["googleSearchRetrieval"] = googleSearchRetrieval
|
||||
if enterpriseWebSearch is not None:
|
||||
_tools["enterpriseWebSearch"] = enterpriseWebSearch
|
||||
if code_execution is not None:
|
||||
_tools["code_execution"] = code_execution
|
||||
return [_tools]
|
||||
|
@ -374,7 +379,11 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
optional_params["responseLogprobs"] = value
|
||||
elif param == "top_logprobs":
|
||||
optional_params["logprobs"] = value
|
||||
elif (param == "tools" or param == "functions") and isinstance(value, list):
|
||||
elif (
|
||||
(param == "tools" or param == "functions")
|
||||
and isinstance(value, list)
|
||||
and value
|
||||
):
|
||||
optional_params["tools"] = self._map_function(value=value)
|
||||
optional_params["litellm_param_is_function_call"] = (
|
||||
True if param == "functions" else False
|
||||
|
@ -739,9 +748,6 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
chat_completion_logprobs = self._transform_logprobs(
|
||||
logprobs_result=candidate["logprobsResult"]
|
||||
)
|
||||
# Handle avgLogprobs for Gemini Flash 2.0
|
||||
elif "avgLogprobs" in candidate:
|
||||
chat_completion_logprobs = candidate["avgLogprobs"]
|
||||
|
||||
if tools:
|
||||
chat_completion_message["tool_calls"] = tools
|
||||
|
@ -896,6 +902,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
@ -1013,7 +1020,7 @@ class VertexLLM(VertexBase):
|
|||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
litellm_params: dict,
|
||||
logger_fn=None,
|
||||
api_base: Optional[str] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
|
@ -1054,6 +1061,7 @@ class VertexLLM(VertexBase):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
@ -1140,6 +1148,7 @@ class VertexLLM(VertexBase):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
request_body = await async_transform_request_body(**data) # type: ignore
|
||||
|
@ -1313,6 +1322,7 @@ class VertexLLM(VertexBase):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## TRANSFORMATION ##
|
||||
|
|
|
@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
|
|||
optional_params=optional_params,
|
||||
api_key=auth_header,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
|
|
@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -2,9 +2,10 @@ import types
|
|||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class VertexAIAi21Config:
|
||||
class VertexAIAi21Config(OpenAIGPTConfig):
|
||||
"""
|
||||
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
|
||||
|
||||
|
@ -40,9 +41,6 @@ class VertexAIAi21Config:
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
|
|
|
@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
|
|||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.asyncify import asyncify
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
|
||||
|
@ -22,7 +21,7 @@ else:
|
|||
GoogleCredentialsObject = Any
|
||||
|
||||
|
||||
class VertexBase(BaseLLM):
|
||||
class VertexBase:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.access_token: Optional[str] = None
|
||||
|
|
|
@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
|
|
@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
|
|||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
## UPDATE PAYLOAD (optional params)
|
||||
|
|
|
@ -165,6 +165,7 @@ class IBMWatsonXMixin:
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
|
|
|
@ -3616,6 +3616,7 @@ def embedding( # noqa: PLR0915
|
|||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
elif custom_llm_provider == "bedrock":
|
||||
if isinstance(input, str):
|
||||
|
|
|
@ -380,6 +380,7 @@
|
|||
"supports_tool_choice": true,
|
||||
"supports_native_streaming": false,
|
||||
"supported_modalities": ["text", "image"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"supported_endpoints": ["/v1/responses", "/v1/batch"]
|
||||
},
|
||||
"o1-pro-2025-03-19": {
|
||||
|
@ -401,6 +402,7 @@
|
|||
"supports_tool_choice": true,
|
||||
"supports_native_streaming": false,
|
||||
"supported_modalities": ["text", "image"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"supported_endpoints": ["/v1/responses", "/v1/batch"]
|
||||
},
|
||||
"o1": {
|
||||
|
@ -1286,6 +1288,68 @@
|
|||
"supports_system_messages": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"azure/gpt-4o-realtime-preview-2024-12-17": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000005,
|
||||
"input_cost_per_audio_token": 0.00004,
|
||||
"cache_read_input_token_cost": 0.0000025,
|
||||
"output_cost_per_token": 0.00002,
|
||||
"output_cost_per_audio_token": 0.00008,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat",
|
||||
"supported_modalities": ["text", "audio"],
|
||||
"supported_output_modalities": ["text", "audio"],
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_audio_input": true,
|
||||
"supports_audio_output": true,
|
||||
"supports_system_messages": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"azure/us/gpt-4o-realtime-preview-2024-12-17": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 5.5e-6,
|
||||
"input_cost_per_audio_token": 44e-6,
|
||||
"cache_read_input_token_cost": 2.75e-6,
|
||||
"cache_read_input_audio_token_cost": 2.5e-6,
|
||||
"output_cost_per_token": 22e-6,
|
||||
"output_cost_per_audio_token": 80e-6,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat",
|
||||
"supported_modalities": ["text", "audio"],
|
||||
"supported_output_modalities": ["text", "audio"],
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_audio_input": true,
|
||||
"supports_audio_output": true,
|
||||
"supports_system_messages": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"azure/eu/gpt-4o-realtime-preview-2024-12-17": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 5.5e-6,
|
||||
"input_cost_per_audio_token": 44e-6,
|
||||
"cache_read_input_token_cost": 2.75e-6,
|
||||
"cache_read_input_audio_token_cost": 2.5e-6,
|
||||
"output_cost_per_token": 22e-6,
|
||||
"output_cost_per_audio_token": 80e-6,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat",
|
||||
"supported_modalities": ["text", "audio"],
|
||||
"supported_output_modalities": ["text", "audio"],
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_audio_input": true,
|
||||
"supports_audio_output": true,
|
||||
"supports_system_messages": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"azure/gpt-4o-realtime-preview-2024-10-01": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
|
@ -2300,6 +2364,18 @@
|
|||
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"azure_ai/mistral-large-latest": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.000002,
|
||||
"output_cost_per_token": 0.000006,
|
||||
"litellm_provider": "azure_ai",
|
||||
"supports_function_calling": true,
|
||||
"mode": "chat",
|
||||
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"azure_ai/ministral-3b": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
|
@ -2397,25 +2473,26 @@
|
|||
"max_tokens": 4096,
|
||||
"max_input_tokens": 131072,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"input_cost_per_token": 0.000000075,
|
||||
"output_cost_per_token": 0.0000003,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
|
||||
"source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
|
||||
},
|
||||
"azure_ai/Phi-4-multimodal-instruct": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 131072,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"input_cost_per_token": 0.00000008,
|
||||
"input_cost_per_audio_token": 0.000004,
|
||||
"output_cost_per_token": 0.00032,
|
||||
"litellm_provider": "azure_ai",
|
||||
"mode": "chat",
|
||||
"supports_audio_input": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
|
||||
"source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
|
||||
},
|
||||
"azure_ai/Phi-4": {
|
||||
"max_tokens": 16384,
|
||||
|
@ -3455,7 +3532,7 @@
|
|||
"input_cost_per_token": 0.0000008,
|
||||
"output_cost_per_token": 0.000004,
|
||||
"cache_creation_input_token_cost": 0.000001,
|
||||
"cache_read_input_token_cost": 0.0000008,
|
||||
"cache_read_input_token_cost": 0.00000008,
|
||||
"litellm_provider": "anthropic",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
|
@ -4499,20 +4576,10 @@
|
|||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_image": 0,
|
||||
"input_cost_per_video_per_second": 0,
|
||||
"input_cost_per_audio_per_second": 0,
|
||||
"input_cost_per_token": 0,
|
||||
"input_cost_per_character": 0,
|
||||
"input_cost_per_token_above_128k_tokens": 0,
|
||||
"input_cost_per_character_above_128k_tokens": 0,
|
||||
"input_cost_per_image_above_128k_tokens": 0,
|
||||
"input_cost_per_video_per_second_above_128k_tokens": 0,
|
||||
"input_cost_per_audio_per_second_above_128k_tokens": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"output_cost_per_character": 0,
|
||||
"output_cost_per_token_above_128k_tokens": 0,
|
||||
"output_cost_per_character_above_128k_tokens": 0,
|
||||
"input_cost_per_token": 0.00000125,
|
||||
"input_cost_per_token_above_200k_tokens": 0.0000025,
|
||||
"output_cost_per_token": 0.00001,
|
||||
"output_cost_per_token_above_200k_tokens": 0.000015,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
|
@ -4523,6 +4590,9 @@
|
|||
"supports_pdf_input": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_tool_choice": true,
|
||||
"supported_endpoints": ["/v1/chat/completions", "/v1/completions"],
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||
},
|
||||
"gemini-2.0-pro-exp-02-05": {
|
||||
|
@ -4535,20 +4605,10 @@
|
|||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_image": 0,
|
||||
"input_cost_per_video_per_second": 0,
|
||||
"input_cost_per_audio_per_second": 0,
|
||||
"input_cost_per_token": 0,
|
||||
"input_cost_per_character": 0,
|
||||
"input_cost_per_token_above_128k_tokens": 0,
|
||||
"input_cost_per_character_above_128k_tokens": 0,
|
||||
"input_cost_per_image_above_128k_tokens": 0,
|
||||
"input_cost_per_video_per_second_above_128k_tokens": 0,
|
||||
"input_cost_per_audio_per_second_above_128k_tokens": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"output_cost_per_character": 0,
|
||||
"output_cost_per_token_above_128k_tokens": 0,
|
||||
"output_cost_per_character_above_128k_tokens": 0,
|
||||
"input_cost_per_token": 0.00000125,
|
||||
"input_cost_per_token_above_200k_tokens": 0.0000025,
|
||||
"output_cost_per_token": 0.00001,
|
||||
"output_cost_per_token_above_200k_tokens": 0.000015,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
|
@ -4559,6 +4619,9 @@
|
|||
"supports_pdf_input": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_tool_choice": true,
|
||||
"supported_endpoints": ["/v1/chat/completions", "/v1/completions"],
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||
},
|
||||
"gemini-2.0-flash-exp": {
|
||||
|
@ -4592,6 +4655,8 @@
|
|||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_audio_output": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
|
@ -4616,6 +4681,8 @@
|
|||
"supports_response_schema": true,
|
||||
"supports_audio_output": true,
|
||||
"supports_tool_choice": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||
},
|
||||
"gemini-2.0-flash-thinking-exp": {
|
||||
|
@ -4649,6 +4716,8 @@
|
|||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_audio_output": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
|
@ -4683,6 +4752,8 @@
|
|||
"supports_vision": true,
|
||||
"supports_response_schema": false,
|
||||
"supports_audio_output": false,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
|
@ -4708,6 +4779,7 @@
|
|||
"supports_audio_output": true,
|
||||
"supports_audio_input": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"supports_tool_choice": true,
|
||||
"source": "https://ai.google.dev/pricing#2_0flash"
|
||||
},
|
||||
|
@ -4730,6 +4802,32 @@
|
|||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_audio_output": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"gemini-2.0-flash-lite-001": {
|
||||
"max_input_tokens": 1048576,
|
||||
"max_output_tokens": 8192,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_videos_per_prompt": 10,
|
||||
"max_video_length": 1,
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 50,
|
||||
"input_cost_per_audio_token": 0.000000075,
|
||||
"input_cost_per_token": 0.000000075,
|
||||
"output_cost_per_token": 0.0000003,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_audio_output": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
|
@ -4795,6 +4893,7 @@
|
|||
"supports_audio_output": true,
|
||||
"supports_audio_input": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"supports_tool_choice": true,
|
||||
"source": "https://ai.google.dev/pricing#2_0flash"
|
||||
},
|
||||
|
@ -4820,6 +4919,8 @@
|
|||
"supports_response_schema": true,
|
||||
"supports_audio_output": true,
|
||||
"supports_tool_choice": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.0-flash-lite"
|
||||
},
|
||||
"gemini/gemini-2.0-flash-001": {
|
||||
|
@ -4845,6 +4946,8 @@
|
|||
"supports_response_schema": true,
|
||||
"supports_audio_output": false,
|
||||
"supports_tool_choice": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"source": "https://ai.google.dev/pricing#2_0flash"
|
||||
},
|
||||
"gemini/gemini-2.5-pro-preview-03-25": {
|
||||
|
@ -4859,9 +4962,9 @@
|
|||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_audio_token": 0.0000007,
|
||||
"input_cost_per_token": 0.00000125,
|
||||
"input_cost_per_token_above_128k_tokens": 0.0000025,
|
||||
"input_cost_per_token_above_200k_tokens": 0.0000025,
|
||||
"output_cost_per_token": 0.0000010,
|
||||
"output_cost_per_token_above_128k_tokens": 0.000015,
|
||||
"output_cost_per_token_above_200k_tokens": 0.000015,
|
||||
"litellm_provider": "gemini",
|
||||
"mode": "chat",
|
||||
"rpm": 10000,
|
||||
|
@ -4872,6 +4975,8 @@
|
|||
"supports_response_schema": true,
|
||||
"supports_audio_output": false,
|
||||
"supports_tool_choice": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro-preview"
|
||||
},
|
||||
"gemini/gemini-2.0-flash-exp": {
|
||||
|
@ -4907,6 +5012,8 @@
|
|||
"supports_audio_output": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 10,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
|
@ -4933,6 +5040,8 @@
|
|||
"supports_response_schema": true,
|
||||
"supports_audio_output": false,
|
||||
"supports_tool_choice": true,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash-lite"
|
||||
},
|
||||
"gemini/gemini-2.0-flash-thinking-exp": {
|
||||
|
@ -4968,6 +5077,8 @@
|
|||
"supports_audio_output": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 10,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
|
@ -5004,6 +5115,8 @@
|
|||
"supports_audio_output": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 10,
|
||||
"supported_modalities": ["text", "image", "audio", "video"],
|
||||
"supported_output_modalities": ["text", "image"],
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
|
@ -8444,7 +8557,8 @@
|
|||
"input_cost_per_token": 0.0000015,
|
||||
"output_cost_per_token": 0.0000020,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"bedrock/*/1-month-commitment/cohere.command-text-v14": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -8453,7 +8567,8 @@
|
|||
"input_cost_per_second": 0.011,
|
||||
"output_cost_per_second": 0.011,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"bedrock/*/6-month-commitment/cohere.command-text-v14": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -8462,7 +8577,8 @@
|
|||
"input_cost_per_second": 0.0066027,
|
||||
"output_cost_per_second": 0.0066027,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"cohere.command-light-text-v14": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -8471,7 +8587,8 @@
|
|||
"input_cost_per_token": 0.0000003,
|
||||
"output_cost_per_token": 0.0000006,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"bedrock/*/1-month-commitment/cohere.command-light-text-v14": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -8480,7 +8597,8 @@
|
|||
"input_cost_per_second": 0.001902,
|
||||
"output_cost_per_second": 0.001902,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"bedrock/*/6-month-commitment/cohere.command-light-text-v14": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -8489,7 +8607,8 @@
|
|||
"input_cost_per_second": 0.0011416,
|
||||
"output_cost_per_second": 0.0011416,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"cohere.command-r-plus-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -8498,7 +8617,8 @@
|
|||
"input_cost_per_token": 0.0000030,
|
||||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"cohere.command-r-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -8507,7 +8627,8 @@
|
|||
"input_cost_per_token": 0.0000005,
|
||||
"output_cost_per_token": 0.0000015,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"cohere.embed-english-v3": {
|
||||
"max_tokens": 512,
|
||||
|
|