Merge remote-tracking branch 'origin/main' into feat--parse-user-from-headers
|
@ -610,6 +610,8 @@ jobs:
|
||||||
name: Install Dependencies
|
name: Install Dependencies
|
||||||
command: |
|
command: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
pip install wheel
|
||||||
|
pip install --upgrade pip wheel setuptools
|
||||||
python -m pip install -r requirements.txt
|
python -m pip install -r requirements.txt
|
||||||
pip install "pytest==7.3.1"
|
pip install "pytest==7.3.1"
|
||||||
pip install "respx==0.21.1"
|
pip install "respx==0.21.1"
|
||||||
|
@ -1125,6 +1127,7 @@ jobs:
|
||||||
name: Install Dependencies
|
name: Install Dependencies
|
||||||
command: |
|
command: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install wheel setuptools
|
||||||
python -m pip install -r requirements.txt
|
python -m pip install -r requirements.txt
|
||||||
pip install "pytest==7.3.1"
|
pip install "pytest==7.3.1"
|
||||||
pip install "pytest-retry==1.6.3"
|
pip install "pytest-retry==1.6.3"
|
||||||
|
|
|
@ -12,8 +12,7 @@ WORKDIR /app
|
||||||
USER root
|
USER root
|
||||||
|
|
||||||
# Install build dependencies
|
# Install build dependencies
|
||||||
RUN apk update && \
|
RUN apk add --no-cache gcc python3-dev openssl openssl-dev
|
||||||
apk add --no-cache gcc python3-dev openssl openssl-dev
|
|
||||||
|
|
||||||
|
|
||||||
RUN pip install --upgrade pip && \
|
RUN pip install --upgrade pip && \
|
||||||
|
@ -52,8 +51,7 @@ FROM $LITELLM_RUNTIME_IMAGE AS runtime
|
||||||
USER root
|
USER root
|
||||||
|
|
||||||
# Install runtime dependencies
|
# Install runtime dependencies
|
||||||
RUN apk update && \
|
RUN apk add --no-cache openssl
|
||||||
apk add --no-cache openssl
|
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
# Copy the current directory contents into the container at /app
|
# Copy the current directory contents into the container at /app
|
||||||
|
|
|
@ -2,6 +2,10 @@ apiVersion: v1
|
||||||
kind: Service
|
kind: Service
|
||||||
metadata:
|
metadata:
|
||||||
name: {{ include "litellm.fullname" . }}
|
name: {{ include "litellm.fullname" . }}
|
||||||
|
{{- with .Values.service.annotations }}
|
||||||
|
annotations:
|
||||||
|
{{- toYaml . | nindent 4 }}
|
||||||
|
{{- end }}
|
||||||
labels:
|
labels:
|
||||||
{{- include "litellm.labels" . | nindent 4 }}
|
{{- include "litellm.labels" . | nindent 4 }}
|
||||||
spec:
|
spec:
|
||||||
|
|
|
@ -35,7 +35,7 @@ RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
|
||||||
FROM $LITELLM_RUNTIME_IMAGE AS runtime
|
FROM $LITELLM_RUNTIME_IMAGE AS runtime
|
||||||
|
|
||||||
# Update dependencies and clean up
|
# Update dependencies and clean up
|
||||||
RUN apk update && apk upgrade && rm -rf /var/cache/apk/*
|
RUN apk upgrade --no-cache
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,7 @@ WORKDIR /app
|
||||||
USER root
|
USER root
|
||||||
|
|
||||||
# Install build dependencies
|
# Install build dependencies
|
||||||
RUN apk update && \
|
RUN apk add --no-cache gcc python3-dev openssl openssl-dev
|
||||||
apk add --no-cache gcc python3-dev openssl openssl-dev
|
|
||||||
|
|
||||||
|
|
||||||
RUN pip install --upgrade pip && \
|
RUN pip install --upgrade pip && \
|
||||||
|
@ -44,8 +43,7 @@ FROM $LITELLM_RUNTIME_IMAGE AS runtime
|
||||||
USER root
|
USER root
|
||||||
|
|
||||||
# Install runtime dependencies
|
# Install runtime dependencies
|
||||||
RUN apk update && \
|
RUN apk add --no-cache openssl
|
||||||
apk add --no-cache openssl
|
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
# Copy the current directory contents into the container at /app
|
# Copy the current directory contents into the container at /app
|
||||||
|
|
|
@ -438,6 +438,179 @@ assert isinstance(
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Google Search Tool
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["GEMINI_API_KEY"] = ".."
|
||||||
|
|
||||||
|
tools = [{"googleSearch": {}}] # 👈 ADD GOOGLE SEARCH
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="gemini/gemini-2.0-flash",
|
||||||
|
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gemini-2.0-flash
|
||||||
|
litellm_params:
|
||||||
|
model: gemini/gemini-2.0-flash
|
||||||
|
api_key: os.environ/GEMINI_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start Proxy
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Make Request!
|
||||||
|
```bash
|
||||||
|
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-2.0-flash",
|
||||||
|
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||||
|
"tools": [{"googleSearch": {}}]
|
||||||
|
}
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
### Google Search Retrieval
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["GEMINI_API_KEY"] = ".."
|
||||||
|
|
||||||
|
tools = [{"googleSearchRetrieval": {}}] # 👈 ADD GOOGLE SEARCH
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="gemini/gemini-2.0-flash",
|
||||||
|
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gemini-2.0-flash
|
||||||
|
litellm_params:
|
||||||
|
model: gemini/gemini-2.0-flash
|
||||||
|
api_key: os.environ/GEMINI_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start Proxy
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Make Request!
|
||||||
|
```bash
|
||||||
|
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-2.0-flash",
|
||||||
|
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||||
|
"tools": [{"googleSearchRetrieval": {}}]
|
||||||
|
}
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
### Code Execution Tool
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["GEMINI_API_KEY"] = ".."
|
||||||
|
|
||||||
|
tools = [{"codeExecution": {}}] # 👈 ADD GOOGLE SEARCH
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="gemini/gemini-2.0-flash",
|
||||||
|
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gemini-2.0-flash
|
||||||
|
litellm_params:
|
||||||
|
model: gemini/gemini-2.0-flash
|
||||||
|
api_key: os.environ/GEMINI_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start Proxy
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Make Request!
|
||||||
|
```bash
|
||||||
|
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-d '{
|
||||||
|
"model": "gemini-2.0-flash",
|
||||||
|
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
|
||||||
|
"tools": [{"codeExecution": {}}]
|
||||||
|
}
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## JSON Mode
|
## JSON Mode
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
|
|
|
@ -398,6 +398,8 @@ curl http://localhost:4000/v1/chat/completions \
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
You can also use the `enterpriseWebSearch` tool for an [enterprise compliant search](https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise).
|
||||||
|
|
||||||
#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**
|
#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -449,6 +449,7 @@ router_settings:
|
||||||
| MICROSOFT_CLIENT_ID | Client ID for Microsoft services
|
| MICROSOFT_CLIENT_ID | Client ID for Microsoft services
|
||||||
| MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
|
| MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
|
||||||
| MICROSOFT_TENANT | Tenant ID for Microsoft Azure
|
| MICROSOFT_TENANT | Tenant ID for Microsoft Azure
|
||||||
|
| MICROSOFT_SERVICE_PRINCIPAL_ID | Service Principal ID for Microsoft Enterprise Application. (This is an advanced feature if you want litellm to auto-assign members to Litellm Teams based on their Microsoft Entra ID Groups)
|
||||||
| NO_DOCS | Flag to disable documentation generation
|
| NO_DOCS | Flag to disable documentation generation
|
||||||
| NO_PROXY | List of addresses to bypass proxy
|
| NO_PROXY | List of addresses to bypass proxy
|
||||||
| OAUTH_TOKEN_INFO_ENDPOINT | Endpoint for OAuth token info retrieval
|
| OAUTH_TOKEN_INFO_ENDPOINT | Endpoint for OAuth token info retrieval
|
||||||
|
|
|
@ -26,10 +26,12 @@ model_list:
|
||||||
- model_name: sagemaker-completion-model
|
- model_name: sagemaker-completion-model
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
|
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
|
||||||
|
model_info:
|
||||||
input_cost_per_second: 0.000420
|
input_cost_per_second: 0.000420
|
||||||
- model_name: sagemaker-embedding-model
|
- model_name: sagemaker-embedding-model
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
|
model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
|
||||||
|
model_info:
|
||||||
input_cost_per_second: 0.000420
|
input_cost_per_second: 0.000420
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -55,11 +57,33 @@ model_list:
|
||||||
api_key: os.environ/AZURE_API_KEY
|
api_key: os.environ/AZURE_API_KEY
|
||||||
api_base: os.environ/AZURE_API_BASE
|
api_base: os.environ/AZURE_API_BASE
|
||||||
api_version: os.envrion/AZURE_API_VERSION
|
api_version: os.envrion/AZURE_API_VERSION
|
||||||
|
model_info:
|
||||||
input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
|
input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
|
||||||
output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
|
output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
|
||||||
```
|
```
|
||||||
|
|
||||||
### Debugging
|
## Override Model Cost Map
|
||||||
|
|
||||||
|
You can override [our model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) with your own custom pricing for a mapped model.
|
||||||
|
|
||||||
|
Just add a `model_info` key to your model in the config, and override the desired keys.
|
||||||
|
|
||||||
|
Example: Override Anthropic's model cost map for the `prod/claude-3-5-sonnet-20241022` model.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: "prod/claude-3-5-sonnet-20241022"
|
||||||
|
litellm_params:
|
||||||
|
model: "anthropic/claude-3-5-sonnet-20241022"
|
||||||
|
api_key: os.environ/ANTHROPIC_PROD_API_KEY
|
||||||
|
model_info:
|
||||||
|
input_cost_per_token: 0.000006
|
||||||
|
output_cost_per_token: 0.00003
|
||||||
|
cache_creation_input_token_cost: 0.0000075
|
||||||
|
cache_read_input_token_cost: 0.0000006
|
||||||
|
```
|
||||||
|
|
||||||
|
## Debugging
|
||||||
|
|
||||||
If you're custom pricing is not being used or you're seeing errors, please check the following:
|
If you're custom pricing is not being used or you're seeing errors, please check the following:
|
||||||
|
|
||||||
|
|
|
@ -161,6 +161,83 @@ Here's the available UI roles for a LiteLLM Internal User:
|
||||||
- `internal_user`: can login, view/create/delete their own keys, view their spend. **Cannot** add new users.
|
- `internal_user`: can login, view/create/delete their own keys, view their spend. **Cannot** add new users.
|
||||||
- `internal_user_viewer`: can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users.
|
- `internal_user_viewer`: can login, view their own keys, view their own spend. **Cannot** create/delete keys, add new users.
|
||||||
|
|
||||||
|
## Auto-add SSO users to teams
|
||||||
|
|
||||||
|
This walks through setting up sso auto-add for **Okta, Google SSO**
|
||||||
|
|
||||||
|
### Okta, Google SSO
|
||||||
|
|
||||||
|
1. Specify the JWT field that contains the team ids, that the user belongs to.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
master_key: sk-1234
|
||||||
|
litellm_jwtauth:
|
||||||
|
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
||||||
|
```
|
||||||
|
|
||||||
|
This is assuming your SSO token looks like this. **If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions [here](#debugging-sso-jwt-fields)**
|
||||||
|
|
||||||
|
```
|
||||||
|
{
|
||||||
|
...,
|
||||||
|
"groups": ["team_id_1", "team_id_2"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create the teams on LiteLLM
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST '<PROXY_BASE_URL>/team/new' \
|
||||||
|
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-D '{
|
||||||
|
"team_alias": "team_1",
|
||||||
|
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test the SSO flow
|
||||||
|
|
||||||
|
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
|
||||||
|
|
||||||
|
### Microsoft Entra ID SSO group assignment
|
||||||
|
|
||||||
|
Follow this [tutorial for auto-adding sso users to teams with Microsoft Entra ID](https://docs.litellm.ai/docs/tutorials/msft_sso)
|
||||||
|
|
||||||
|
### Debugging SSO JWT fields
|
||||||
|
|
||||||
|
If you need to inspect the JWT fields received from your SSO provider by LiteLLM, follow these instructions. This guide walks you through setting up a debug callback to view the JWT data during the SSO process.
|
||||||
|
|
||||||
|
|
||||||
|
<Image img={require('../../img/debug_sso.png')} style={{ width: '500px', height: 'auto' }} />
|
||||||
|
<br />
|
||||||
|
|
||||||
|
1. Add `/sso/debug/callback` as a redirect URL in your SSO provider
|
||||||
|
|
||||||
|
In your SSO provider's settings, add the following URL as a new redirect (callback) URL:
|
||||||
|
|
||||||
|
```bash showLineNumbers title="Redirect URL"
|
||||||
|
http://<proxy_base_url>/sso/debug/callback
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
2. Navigate to the debug login page on your browser
|
||||||
|
|
||||||
|
Navigate to the following URL on your browser:
|
||||||
|
|
||||||
|
```bash showLineNumbers title="URL to navigate to"
|
||||||
|
https://<proxy_base_url>/sso/debug/login
|
||||||
|
```
|
||||||
|
|
||||||
|
This will initiate the standard SSO flow. You will be redirected to your SSO provider's login screen, and after successful authentication, you will be redirected back to LiteLLM's debug callback route.
|
||||||
|
|
||||||
|
|
||||||
|
3. View the JWT fields
|
||||||
|
|
||||||
|
Once redirected, you should see a page called "SSO Debug Information". This page displays the JWT fields received from your SSO provider (as shown in the image above)
|
||||||
|
|
||||||
|
|
||||||
## Advanced
|
## Advanced
|
||||||
### Setting custom logout URLs
|
### Setting custom logout URLs
|
||||||
|
|
||||||
|
@ -196,40 +273,26 @@ This budget does not apply to keys created under non-default teams.
|
||||||
|
|
||||||
[**Go Here**](./team_budgets.md)
|
[**Go Here**](./team_budgets.md)
|
||||||
|
|
||||||
### Auto-add SSO users to teams
|
### Set default params for new teams
|
||||||
|
|
||||||
1. Specify the JWT field that contains the team ids, that the user belongs to.
|
When you connect litellm to your SSO provider, litellm can auto-create teams. Use this to set the default `models`, `max_budget`, `budget_duration` for these auto-created teams.
|
||||||
|
|
||||||
```yaml
|
**How it works**
|
||||||
general_settings:
|
|
||||||
master_key: sk-1234
|
1. When litellm fetches `groups` from your SSO provider, it will check if the corresponding group_id exists as a `team_id` in litellm.
|
||||||
litellm_jwtauth:
|
2. If the team_id does not exist, litellm will auto-create a team with the default params you've set.
|
||||||
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
3. If the team_id already exist, litellm will not apply any settings on the team.
|
||||||
|
|
||||||
|
**Usage**
|
||||||
|
|
||||||
|
```yaml showLineNumbers title="Default Params for new teams"
|
||||||
|
litellm_settings:
|
||||||
|
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
|
||||||
|
max_budget: 100 # Optional[float], optional): $100 budget for the team
|
||||||
|
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
|
||||||
|
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
|
||||||
```
|
```
|
||||||
|
|
||||||
This is assuming your SSO token looks like this:
|
|
||||||
```
|
|
||||||
{
|
|
||||||
...,
|
|
||||||
"groups": ["team_id_1", "team_id_2"]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Create the teams on LiteLLM
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl -X POST '<PROXY_BASE_URL>/team/new' \
|
|
||||||
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
|
|
||||||
-H 'Content-Type: application/json' \
|
|
||||||
-D '{
|
|
||||||
"team_alias": "team_1",
|
|
||||||
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Test the SSO flow
|
|
||||||
|
|
||||||
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
|
|
||||||
|
|
||||||
### Restrict Users from creating personal keys
|
### Restrict Users from creating personal keys
|
||||||
|
|
||||||
|
@ -241,7 +304,7 @@ This will also prevent users from using their session tokens on the test keys ch
|
||||||
|
|
||||||
## **All Settings for Self Serve / SSO Flow**
|
## **All Settings for Self Serve / SSO Flow**
|
||||||
|
|
||||||
```yaml
|
```yaml showLineNumbers title="All Settings for Self Serve / SSO Flow"
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
max_internal_user_budget: 10 # max budget for internal users
|
max_internal_user_budget: 10 # max budget for internal users
|
||||||
internal_user_budget_duration: "1mo" # reset every month
|
internal_user_budget_duration: "1mo" # reset every month
|
||||||
|
@ -251,6 +314,11 @@ litellm_settings:
|
||||||
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
|
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
|
||||||
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
|
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
|
||||||
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
|
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
|
||||||
|
|
||||||
|
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
|
||||||
|
max_budget: 100 # Optional[float], optional): $100 budget for the team
|
||||||
|
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
|
||||||
|
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
|
||||||
|
|
||||||
|
|
||||||
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
|
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
|
||||||
|
|
162
docs/my-website/docs/tutorials/msft_sso.md
Normal file
|
@ -0,0 +1,162 @@
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
|
||||||
|
# Microsoft SSO: Sync Groups, Members with LiteLLM
|
||||||
|
|
||||||
|
Sync Microsoft SSO Groups, Members with LiteLLM Teams.
|
||||||
|
|
||||||
|
<Image img={require('../../img/litellm_entra_id.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- An Azure Entra ID account with administrative access
|
||||||
|
- A LiteLLM Enterprise App set up in your Azure Portal
|
||||||
|
- Access to Microsoft Entra ID (Azure AD)
|
||||||
|
|
||||||
|
|
||||||
|
## Overview of this tutorial
|
||||||
|
|
||||||
|
1. Auto-Create Entra ID Groups on LiteLLM Teams
|
||||||
|
2. Sync Entra ID Team Memberships
|
||||||
|
3. Set default params for new teams and users auto-created on LiteLLM
|
||||||
|
|
||||||
|
## 1. Auto-Create Entra ID Groups on LiteLLM Teams
|
||||||
|
|
||||||
|
In this step, our goal is to have LiteLLM automatically create a new team on the LiteLLM DB when there is a new Group Added to the LiteLLM Enterprise App on Azure Entra ID.
|
||||||
|
|
||||||
|
### 1.1 Create a new group in Entra ID
|
||||||
|
|
||||||
|
|
||||||
|
Navigate to [your Azure Portal](https://portal.azure.com/) > Groups > New Group. Create a new group.
|
||||||
|
|
||||||
|
<Image img={require('../../img/entra_create_team.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
### 1.2 Assign the group to your LiteLLM Enterprise App
|
||||||
|
|
||||||
|
On your Azure Portal, navigate to `Enterprise Applications` > Select your litellm app
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_enterprise_app.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
|
||||||
|
Once you've selected your litellm app, click on `Users and Groups` > `Add user/group`
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_enterprise_assign_group.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
<br />
|
||||||
|
|
||||||
|
Now select the group you created in step 1.1. And add it to the LiteLLM Enterprise App. At this point we have added `Production LLM Evals Group` to the LiteLLM Enterprise App. The next steps is having LiteLLM automatically create the `Production LLM Evals Group` on the LiteLLM DB when a new user signs in.
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_enterprise_select_group.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
|
||||||
|
### 1.3 Sign in to LiteLLM UI via SSO
|
||||||
|
|
||||||
|
Sign into the LiteLLM UI via SSO. You should be redirected to the Entra ID SSO page. This SSO sign in flow will trigger LiteLLM to fetch the latest Groups and Members from Azure Entra ID.
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_sso_sign_in.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
### 1.4 Check the new team on LiteLLM UI
|
||||||
|
|
||||||
|
On the LiteLLM UI, Navigate to `Teams`, You should see the new team `Production LLM Evals Group` auto-created on LiteLLM.
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_auto_team.png')} style={{ width: '900px', height: 'auto' }} />
|
||||||
|
|
||||||
|
#### How this works
|
||||||
|
|
||||||
|
When a SSO user signs in to LiteLLM:
|
||||||
|
- LiteLLM automatically fetches the Groups under the LiteLLM Enterprise App
|
||||||
|
- It finds the Production LLM Evals Group assigned to the LiteLLM Enterprise App
|
||||||
|
- LiteLLM checks if this group's ID exists in the LiteLLM Teams Table
|
||||||
|
- Since the ID doesn't exist, LiteLLM automatically creates a new team with:
|
||||||
|
- Name: Production LLM Evals Group
|
||||||
|
- ID: Same as the Entra ID group's ID
|
||||||
|
|
||||||
|
## 2. Sync Entra ID Team Memberships
|
||||||
|
|
||||||
|
In this step, we will have LiteLLM automatically add a user to the `Production LLM Evals` Team on the LiteLLM DB when a new user is added to the `Production LLM Evals` Group in Entra ID.
|
||||||
|
|
||||||
|
### 2.1 Navigate to the `Production LLM Evals` Group in Entra ID
|
||||||
|
|
||||||
|
Navigate to the `Production LLM Evals` Group in Entra ID.
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_member_1.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
|
||||||
|
### 2.2 Add a member to the group in Entra ID
|
||||||
|
|
||||||
|
Select `Members` > `Add members`
|
||||||
|
|
||||||
|
In this stage you should add the user you want to add to the `Production LLM Evals` Team.
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_member_2.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### 2.3 Sign in as the new user on LiteLLM UI
|
||||||
|
|
||||||
|
Sign in as the new user on LiteLLM UI. You should be redirected to the Entra ID SSO page. This SSO sign in flow will trigger LiteLLM to fetch the latest Groups and Members from Azure Entra ID. During this step LiteLLM sync it's teams, team members with what is available from Entra ID
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_sso_sign_in.png')} style={{ width: '800px', height: 'auto' }} />
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### 2.4 Check the team membership on LiteLLM UI
|
||||||
|
|
||||||
|
On the LiteLLM UI, Navigate to `Teams`, You should see the new team `Production LLM Evals Group`. Since your are now a member of the `Production LLM Evals Group` in Entra ID, you should see the new team `Production LLM Evals Group` on the LiteLLM UI.
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_member_3.png')} style={{ width: '900px', height: 'auto' }} />
|
||||||
|
|
||||||
|
## 3. Set default params for new teams auto-created on LiteLLM
|
||||||
|
|
||||||
|
Since litellm auto creates a new team on the LiteLLM DB when there is a new Group Added to the LiteLLM Enterprise App on Azure Entra ID, we can set default params for new teams created.
|
||||||
|
|
||||||
|
This allows you to set a default budget, models, etc for new teams created.
|
||||||
|
|
||||||
|
### 3.1 Set `default_team_params` on litellm
|
||||||
|
|
||||||
|
Navigate to your litellm config file and set the following params
|
||||||
|
|
||||||
|
```yaml showLineNumbers title="litellm config with default_team_params"
|
||||||
|
litellm_settings:
|
||||||
|
default_team_params: # Default Params to apply when litellm auto creates a team from SSO IDP provider
|
||||||
|
max_budget: 100 # Optional[float], optional): $100 budget for the team
|
||||||
|
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for the team
|
||||||
|
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by the team
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 Auto-create a new team on LiteLLM
|
||||||
|
|
||||||
|
- In this step you should add a new group to the LiteLLM Enterprise App on Azure Entra ID (like we did in step 1.1). We will call this group `Default LiteLLM Prod Team` on Azure Entra ID.
|
||||||
|
- Start litellm proxy server with your config
|
||||||
|
- Sign into LiteLLM UI via SSO
|
||||||
|
- Navigate to `Teams` and you should see the new team `Default LiteLLM Prod Team` auto-created on LiteLLM
|
||||||
|
- Note LiteLLM will set the default params for this new team.
|
||||||
|
|
||||||
|
<Image img={require('../../img/msft_default_settings.png')} style={{ width: '900px', height: 'auto' }} />
|
||||||
|
|
||||||
|
|
||||||
|
## Video Walkthrough
|
||||||
|
|
||||||
|
This walks through setting up sso auto-add for **Microsoft Entra ID**
|
||||||
|
|
||||||
|
Follow along this video for a walkthrough of how to set this up with Microsoft Entra ID
|
||||||
|
|
||||||
|
<iframe width="840" height="500" src="https://www.loom.com/embed/ea711323aa9a496d84a01fd7b2a12f54?sid=c53e238c-5bfd-4135-b8fb-b5b1a08632cf" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
BIN
docs/my-website/img/debug_sso.png
Normal file
After Width: | Height: | Size: 167 KiB |
BIN
docs/my-website/img/entra_create_team.png
Normal file
After Width: | Height: | Size: 180 KiB |
BIN
docs/my-website/img/litellm_entra_id.png
Normal file
After Width: | Height: | Size: 35 KiB |
BIN
docs/my-website/img/msft_auto_team.png
Normal file
After Width: | Height: | Size: 62 KiB |
BIN
docs/my-website/img/msft_default_settings.png
Normal file
After Width: | Height: | Size: 141 KiB |
BIN
docs/my-website/img/msft_enterprise_app.png
Normal file
After Width: | Height: | Size: 292 KiB |
BIN
docs/my-website/img/msft_enterprise_assign_group.png
Normal file
After Width: | Height: | Size: 277 KiB |
BIN
docs/my-website/img/msft_enterprise_select_group.png
Normal file
After Width: | Height: | Size: 245 KiB |
BIN
docs/my-website/img/msft_member_1.png
Normal file
After Width: | Height: | Size: 296 KiB |
BIN
docs/my-website/img/msft_member_2.png
Normal file
After Width: | Height: | Size: 274 KiB |
BIN
docs/my-website/img/msft_member_3.png
Normal file
After Width: | Height: | Size: 186 KiB |
BIN
docs/my-website/img/msft_sso_sign_in.png
Normal file
After Width: | Height: | Size: 818 KiB |
|
@ -435,6 +435,7 @@ const sidebars = {
|
||||||
label: "Tutorials",
|
label: "Tutorials",
|
||||||
items: [
|
items: [
|
||||||
"tutorials/openweb_ui",
|
"tutorials/openweb_ui",
|
||||||
|
"tutorials/msft_sso",
|
||||||
'tutorials/litellm_proxy_aporia',
|
'tutorials/litellm_proxy_aporia',
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
|
|
161
docs/my-website/src/components/TransformRequestPlayground.tsx
Normal file
|
@ -0,0 +1,161 @@
|
||||||
|
import React, { useState } from 'react';
|
||||||
|
import styles from './transform_request.module.css';
|
||||||
|
|
||||||
|
const DEFAULT_REQUEST = {
|
||||||
|
"model": "bedrock/gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Explain quantum computing in simple terms"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 500,
|
||||||
|
"stream": true
|
||||||
|
};
|
||||||
|
|
||||||
|
type ViewMode = 'split' | 'request' | 'transformed';
|
||||||
|
|
||||||
|
const TransformRequestPlayground: React.FC = () => {
|
||||||
|
const [request, setRequest] = useState(JSON.stringify(DEFAULT_REQUEST, null, 2));
|
||||||
|
const [transformedRequest, setTransformedRequest] = useState('');
|
||||||
|
const [viewMode, setViewMode] = useState<ViewMode>('split');
|
||||||
|
|
||||||
|
const handleTransform = async () => {
|
||||||
|
try {
|
||||||
|
// Here you would make the actual API call to transform the request
|
||||||
|
// For now, we'll just set a sample response
|
||||||
|
const sampleResponse = `curl -X POST \\
|
||||||
|
https://api.openai.com/v1/chat/completions \\
|
||||||
|
-H 'Authorization: Bearer sk-xxx' \\
|
||||||
|
-H 'Content-Type: application/json' \\
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.7
|
||||||
|
}'`;
|
||||||
|
setTransformedRequest(sampleResponse);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error transforming request:', error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleCopy = () => {
|
||||||
|
navigator.clipboard.writeText(transformedRequest);
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderContent = () => {
|
||||||
|
switch (viewMode) {
|
||||||
|
case 'request':
|
||||||
|
return (
|
||||||
|
<div className={styles.panel}>
|
||||||
|
<div className={styles['panel-header']}>
|
||||||
|
<h2>Original Request</h2>
|
||||||
|
<p>The request you would send to LiteLLM /chat/completions endpoint.</p>
|
||||||
|
</div>
|
||||||
|
<textarea
|
||||||
|
className={styles['code-input']}
|
||||||
|
value={request}
|
||||||
|
onChange={(e) => setRequest(e.target.value)}
|
||||||
|
spellCheck={false}
|
||||||
|
/>
|
||||||
|
<div className={styles['panel-footer']}>
|
||||||
|
<button className={styles['transform-button']} onClick={handleTransform}>
|
||||||
|
Transform →
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
case 'transformed':
|
||||||
|
return (
|
||||||
|
<div className={styles.panel}>
|
||||||
|
<div className={styles['panel-header']}>
|
||||||
|
<h2>Transformed Request</h2>
|
||||||
|
<p>How LiteLLM transforms your request for the specified provider.</p>
|
||||||
|
<p className={styles.note}>Note: Sensitive headers are not shown.</p>
|
||||||
|
</div>
|
||||||
|
<div className={styles['code-output-container']}>
|
||||||
|
<pre className={styles['code-output']}>{transformedRequest}</pre>
|
||||||
|
<button className={styles['copy-button']} onClick={handleCopy}>
|
||||||
|
Copy
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
default:
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<div className={styles.panel}>
|
||||||
|
<div className={styles['panel-header']}>
|
||||||
|
<h2>Original Request</h2>
|
||||||
|
<p>The request you would send to LiteLLM /chat/completions endpoint.</p>
|
||||||
|
</div>
|
||||||
|
<textarea
|
||||||
|
className={styles['code-input']}
|
||||||
|
value={request}
|
||||||
|
onChange={(e) => setRequest(e.target.value)}
|
||||||
|
spellCheck={false}
|
||||||
|
/>
|
||||||
|
<div className={styles['panel-footer']}>
|
||||||
|
<button className={styles['transform-button']} onClick={handleTransform}>
|
||||||
|
Transform →
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className={styles.panel}>
|
||||||
|
<div className={styles['panel-header']}>
|
||||||
|
<h2>Transformed Request</h2>
|
||||||
|
<p>How LiteLLM transforms your request for the specified provider.</p>
|
||||||
|
<p className={styles.note}>Note: Sensitive headers are not shown.</p>
|
||||||
|
</div>
|
||||||
|
<div className={styles['code-output-container']}>
|
||||||
|
<pre className={styles['code-output']}>{transformedRequest}</pre>
|
||||||
|
<button className={styles['copy-button']} onClick={handleCopy}>
|
||||||
|
Copy
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={styles['transform-playground']}>
|
||||||
|
<div className={styles['view-toggle']}>
|
||||||
|
<button
|
||||||
|
className={viewMode === 'split' ? styles.active : ''}
|
||||||
|
onClick={() => setViewMode('split')}
|
||||||
|
>
|
||||||
|
Split View
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
className={viewMode === 'request' ? styles.active : ''}
|
||||||
|
onClick={() => setViewMode('request')}
|
||||||
|
>
|
||||||
|
Request
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
className={viewMode === 'transformed' ? styles.active : ''}
|
||||||
|
onClick={() => setViewMode('transformed')}
|
||||||
|
>
|
||||||
|
Transformed
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div className={styles['playground-container']}>
|
||||||
|
{renderContent()}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default TransformRequestPlayground;
|
|
@ -65,6 +65,7 @@ from litellm.proxy._types import (
|
||||||
KeyManagementSystem,
|
KeyManagementSystem,
|
||||||
KeyManagementSettings,
|
KeyManagementSettings,
|
||||||
LiteLLM_UpperboundKeyGenerateParams,
|
LiteLLM_UpperboundKeyGenerateParams,
|
||||||
|
NewTeamRequest,
|
||||||
)
|
)
|
||||||
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
@ -126,19 +127,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
|
||||||
require_auth_for_metrics_endpoint: Optional[bool] = False
|
require_auth_for_metrics_endpoint: Optional[bool] = False
|
||||||
argilla_batch_size: Optional[int] = None
|
argilla_batch_size: Optional[int] = None
|
||||||
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
||||||
gcs_pub_sub_use_v1: Optional[
|
gcs_pub_sub_use_v1: Optional[bool] = (
|
||||||
bool
|
False # if you want to use v1 gcs pubsub logged payload
|
||||||
] = False # if you want to use v1 gcs pubsub logged payload
|
)
|
||||||
argilla_transformation_object: Optional[Dict[str, Any]] = None
|
argilla_transformation_object: Optional[Dict[str, Any]] = None
|
||||||
_async_input_callback: List[
|
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||||
Union[str, Callable, CustomLogger]
|
[]
|
||||||
] = [] # internal variable - async custom callbacks are routed here.
|
) # internal variable - async custom callbacks are routed here.
|
||||||
_async_success_callback: List[
|
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||||
Union[str, Callable, CustomLogger]
|
[]
|
||||||
] = [] # internal variable - async custom callbacks are routed here.
|
) # internal variable - async custom callbacks are routed here.
|
||||||
_async_failure_callback: List[
|
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||||
Union[str, Callable, CustomLogger]
|
[]
|
||||||
] = [] # internal variable - async custom callbacks are routed here.
|
) # internal variable - async custom callbacks are routed here.
|
||||||
pre_call_rules: List[Callable] = []
|
pre_call_rules: List[Callable] = []
|
||||||
post_call_rules: List[Callable] = []
|
post_call_rules: List[Callable] = []
|
||||||
turn_off_message_logging: Optional[bool] = False
|
turn_off_message_logging: Optional[bool] = False
|
||||||
|
@ -146,18 +147,18 @@ log_raw_request_response: bool = False
|
||||||
redact_messages_in_exceptions: Optional[bool] = False
|
redact_messages_in_exceptions: Optional[bool] = False
|
||||||
redact_user_api_key_info: Optional[bool] = False
|
redact_user_api_key_info: Optional[bool] = False
|
||||||
filter_invalid_headers: Optional[bool] = False
|
filter_invalid_headers: Optional[bool] = False
|
||||||
add_user_information_to_llm_headers: Optional[
|
add_user_information_to_llm_headers: Optional[bool] = (
|
||||||
bool
|
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
||||||
] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
)
|
||||||
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
||||||
### end of callbacks #############
|
### end of callbacks #############
|
||||||
|
|
||||||
email: Optional[
|
email: Optional[str] = (
|
||||||
str
|
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
)
|
||||||
token: Optional[
|
token: Optional[str] = (
|
||||||
str
|
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
)
|
||||||
telemetry = True
|
telemetry = True
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
|
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
|
||||||
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
||||||
|
@ -233,20 +234,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
|
||||||
enable_caching_on_provider_specific_optional_params: bool = (
|
enable_caching_on_provider_specific_optional_params: bool = (
|
||||||
False # feature-flag for caching on optional params - e.g. 'top_k'
|
False # feature-flag for caching on optional params - e.g. 'top_k'
|
||||||
)
|
)
|
||||||
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
caching: bool = (
|
||||||
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
cache: Optional[
|
)
|
||||||
Cache
|
caching_with_models: bool = (
|
||||||
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
|
)
|
||||||
|
cache: Optional[Cache] = (
|
||||||
|
None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||||
|
)
|
||||||
default_in_memory_ttl: Optional[float] = None
|
default_in_memory_ttl: Optional[float] = None
|
||||||
default_redis_ttl: Optional[float] = None
|
default_redis_ttl: Optional[float] = None
|
||||||
default_redis_batch_cache_expiry: Optional[float] = None
|
default_redis_batch_cache_expiry: Optional[float] = None
|
||||||
model_alias_map: Dict[str, str] = {}
|
model_alias_map: Dict[str, str] = {}
|
||||||
model_group_alias_map: Dict[str, str] = {}
|
model_group_alias_map: Dict[str, str] = {}
|
||||||
max_budget: float = 0.0 # set the max budget across all providers
|
max_budget: float = 0.0 # set the max budget across all providers
|
||||||
budget_duration: Optional[
|
budget_duration: Optional[str] = (
|
||||||
str
|
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||||
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
)
|
||||||
default_soft_budget: float = (
|
default_soft_budget: float = (
|
||||||
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
|
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
|
||||||
)
|
)
|
||||||
|
@ -255,11 +260,15 @@ forward_traceparent_to_llm_provider: bool = False
|
||||||
|
|
||||||
_current_cost = 0.0 # private variable, used if max budget is set
|
_current_cost = 0.0 # private variable, used if max budget is set
|
||||||
error_logs: Dict = {}
|
error_logs: Dict = {}
|
||||||
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
add_function_to_prompt: bool = (
|
||||||
|
False # if function calling not supported by api, append function call details to system prompt
|
||||||
|
)
|
||||||
client_session: Optional[httpx.Client] = None
|
client_session: Optional[httpx.Client] = None
|
||||||
aclient_session: Optional[httpx.AsyncClient] = None
|
aclient_session: Optional[httpx.AsyncClient] = None
|
||||||
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
||||||
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
model_cost_map_url: str = (
|
||||||
|
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||||
|
)
|
||||||
suppress_debug_info = False
|
suppress_debug_info = False
|
||||||
dynamodb_table_name: Optional[str] = None
|
dynamodb_table_name: Optional[str] = None
|
||||||
s3_callback_params: Optional[Dict] = None
|
s3_callback_params: Optional[Dict] = None
|
||||||
|
@ -268,6 +277,7 @@ default_key_generate_params: Optional[Dict] = None
|
||||||
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
||||||
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
|
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
|
||||||
default_internal_user_params: Optional[Dict] = None
|
default_internal_user_params: Optional[Dict] = None
|
||||||
|
default_team_params: Optional[Union[NewTeamRequest, Dict]] = None
|
||||||
default_team_settings: Optional[List] = None
|
default_team_settings: Optional[List] = None
|
||||||
max_user_budget: Optional[float] = None
|
max_user_budget: Optional[float] = None
|
||||||
default_max_internal_user_budget: Optional[float] = None
|
default_max_internal_user_budget: Optional[float] = None
|
||||||
|
@ -281,7 +291,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
|
||||||
custom_prometheus_metadata_labels: List[str] = []
|
custom_prometheus_metadata_labels: List[str] = []
|
||||||
#### REQUEST PRIORITIZATION ####
|
#### REQUEST PRIORITIZATION ####
|
||||||
priority_reservation: Optional[Dict[str, float]] = None
|
priority_reservation: Optional[Dict[str, float]] = None
|
||||||
force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
force_ipv4: bool = (
|
||||||
|
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
||||||
|
)
|
||||||
module_level_aclient = AsyncHTTPHandler(
|
module_level_aclient = AsyncHTTPHandler(
|
||||||
timeout=request_timeout, client_alias="module level aclient"
|
timeout=request_timeout, client_alias="module level aclient"
|
||||||
)
|
)
|
||||||
|
@ -295,13 +307,13 @@ fallbacks: Optional[List] = None
|
||||||
context_window_fallbacks: Optional[List] = None
|
context_window_fallbacks: Optional[List] = None
|
||||||
content_policy_fallbacks: Optional[List] = None
|
content_policy_fallbacks: Optional[List] = None
|
||||||
allowed_fails: int = 3
|
allowed_fails: int = 3
|
||||||
num_retries_per_request: Optional[
|
num_retries_per_request: Optional[int] = (
|
||||||
int
|
None # for the request overall (incl. fallbacks + model retries)
|
||||||
] = None # for the request overall (incl. fallbacks + model retries)
|
)
|
||||||
####### SECRET MANAGERS #####################
|
####### SECRET MANAGERS #####################
|
||||||
secret_manager_client: Optional[
|
secret_manager_client: Optional[Any] = (
|
||||||
Any
|
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||||
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
)
|
||||||
_google_kms_resource_name: Optional[str] = None
|
_google_kms_resource_name: Optional[str] = None
|
||||||
_key_management_system: Optional[KeyManagementSystem] = None
|
_key_management_system: Optional[KeyManagementSystem] = None
|
||||||
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
||||||
|
@ -1050,10 +1062,10 @@ from .types.llms.custom_llm import CustomLLMItem
|
||||||
from .types.utils import GenericStreamingChunk
|
from .types.utils import GenericStreamingChunk
|
||||||
|
|
||||||
custom_provider_map: List[CustomLLMItem] = []
|
custom_provider_map: List[CustomLLMItem] = []
|
||||||
_custom_providers: List[
|
_custom_providers: List[str] = (
|
||||||
str
|
[]
|
||||||
] = [] # internal helper util, used to track names of custom providers
|
) # internal helper util, used to track names of custom providers
|
||||||
disable_hf_tokenizer_download: Optional[
|
disable_hf_tokenizer_download: Optional[bool] = (
|
||||||
bool
|
None # disable huggingface tokenizer download. Defaults to openai clk100
|
||||||
] = None # disable huggingface tokenizer download. Defaults to openai clk100
|
)
|
||||||
global_disable_no_log_param: bool = False
|
global_disable_no_log_param: bool = False
|
||||||
|
|
|
@ -480,6 +480,7 @@ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when conv
|
||||||
|
|
||||||
########################### Logging Callback Constants ###########################
|
########################### Logging Callback Constants ###########################
|
||||||
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
||||||
|
PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
|
||||||
MCP_TOOL_NAME_PREFIX = "mcp_tool"
|
MCP_TOOL_NAME_PREFIX = "mcp_tool"
|
||||||
|
|
||||||
########################### LiteLLM Proxy Specific Constants ###########################
|
########################### LiteLLM Proxy Specific Constants ###########################
|
||||||
|
@ -514,6 +515,7 @@ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
|
||||||
|
|
||||||
########################### DB CRON JOB NAMES ###########################
|
########################### DB CRON JOB NAMES ###########################
|
||||||
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
|
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
|
||||||
|
PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
|
||||||
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
|
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
|
||||||
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
|
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
|
||||||
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
|
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
|
||||||
|
|
|
@ -16,7 +16,10 @@ from litellm.constants import (
|
||||||
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
|
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
|
||||||
StandardBuiltInToolCostTracking,
|
StandardBuiltInToolCostTracking,
|
||||||
)
|
)
|
||||||
from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
|
from litellm.litellm_core_utils.llm_cost_calc.utils import (
|
||||||
|
_generic_cost_per_character,
|
||||||
|
generic_cost_per_token,
|
||||||
|
)
|
||||||
from litellm.llms.anthropic.cost_calculation import (
|
from litellm.llms.anthropic.cost_calculation import (
|
||||||
cost_per_token as anthropic_cost_per_token,
|
cost_per_token as anthropic_cost_per_token,
|
||||||
)
|
)
|
||||||
|
@ -54,12 +57,16 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
|
||||||
from litellm.responses.utils import ResponseAPILoggingUtils
|
from litellm.responses.utils import ResponseAPILoggingUtils
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
HttpxBinaryResponseContent,
|
HttpxBinaryResponseContent,
|
||||||
|
OpenAIRealtimeStreamList,
|
||||||
|
OpenAIRealtimeStreamResponseBaseObject,
|
||||||
|
OpenAIRealtimeStreamSessionEvents,
|
||||||
ResponseAPIUsage,
|
ResponseAPIUsage,
|
||||||
ResponsesAPIResponse,
|
ResponsesAPIResponse,
|
||||||
)
|
)
|
||||||
from litellm.types.rerank import RerankBilledUnits, RerankResponse
|
from litellm.types.rerank import RerankBilledUnits, RerankResponse
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
CallTypesLiteral,
|
CallTypesLiteral,
|
||||||
|
LiteLLMRealtimeStreamLoggingObject,
|
||||||
LlmProviders,
|
LlmProviders,
|
||||||
LlmProvidersSet,
|
LlmProvidersSet,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
|
@ -397,6 +404,7 @@ def _select_model_name_for_cost_calc(
|
||||||
base_model: Optional[str] = None,
|
base_model: Optional[str] = None,
|
||||||
custom_pricing: Optional[bool] = None,
|
custom_pricing: Optional[bool] = None,
|
||||||
custom_llm_provider: Optional[str] = None,
|
custom_llm_provider: Optional[str] = None,
|
||||||
|
router_model_id: Optional[str] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
1. If custom pricing is true, return received model name
|
1. If custom pricing is true, return received model name
|
||||||
|
@ -411,12 +419,6 @@ def _select_model_name_for_cost_calc(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
)
|
)
|
||||||
|
|
||||||
if custom_pricing is True:
|
|
||||||
return_model = model
|
|
||||||
|
|
||||||
if base_model is not None:
|
|
||||||
return_model = base_model
|
|
||||||
|
|
||||||
completion_response_model: Optional[str] = None
|
completion_response_model: Optional[str] = None
|
||||||
if completion_response is not None:
|
if completion_response is not None:
|
||||||
if isinstance(completion_response, BaseModel):
|
if isinstance(completion_response, BaseModel):
|
||||||
|
@ -424,6 +426,16 @@ def _select_model_name_for_cost_calc(
|
||||||
elif isinstance(completion_response, dict):
|
elif isinstance(completion_response, dict):
|
||||||
completion_response_model = completion_response.get("model", None)
|
completion_response_model = completion_response.get("model", None)
|
||||||
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
|
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
|
||||||
|
|
||||||
|
if custom_pricing is True:
|
||||||
|
if router_model_id is not None and router_model_id in litellm.model_cost:
|
||||||
|
return_model = router_model_id
|
||||||
|
else:
|
||||||
|
return_model = model
|
||||||
|
|
||||||
|
if base_model is not None:
|
||||||
|
return_model = base_model
|
||||||
|
|
||||||
if completion_response_model is None and hidden_params is not None:
|
if completion_response_model is None and hidden_params is not None:
|
||||||
if (
|
if (
|
||||||
hidden_params.get("model", None) is not None
|
hidden_params.get("model", None) is not None
|
||||||
|
@ -553,6 +565,7 @@ def completion_cost( # noqa: PLR0915
|
||||||
base_model: Optional[str] = None,
|
base_model: Optional[str] = None,
|
||||||
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
||||||
litellm_model_name: Optional[str] = None,
|
litellm_model_name: Optional[str] = None,
|
||||||
|
router_model_id: Optional[str] = None,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
|
Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
|
||||||
|
@ -605,18 +618,19 @@ def completion_cost( # noqa: PLR0915
|
||||||
completion_response=completion_response
|
completion_response=completion_response
|
||||||
)
|
)
|
||||||
rerank_billed_units: Optional[RerankBilledUnits] = None
|
rerank_billed_units: Optional[RerankBilledUnits] = None
|
||||||
|
|
||||||
selected_model = _select_model_name_for_cost_calc(
|
selected_model = _select_model_name_for_cost_calc(
|
||||||
model=model,
|
model=model,
|
||||||
completion_response=completion_response,
|
completion_response=completion_response,
|
||||||
custom_llm_provider=custom_llm_provider,
|
custom_llm_provider=custom_llm_provider,
|
||||||
custom_pricing=custom_pricing,
|
custom_pricing=custom_pricing,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
|
router_model_id=router_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
potential_model_names = [selected_model]
|
potential_model_names = [selected_model]
|
||||||
if model is not None:
|
if model is not None:
|
||||||
potential_model_names.append(model)
|
potential_model_names.append(model)
|
||||||
|
|
||||||
for idx, model in enumerate(potential_model_names):
|
for idx, model in enumerate(potential_model_names):
|
||||||
try:
|
try:
|
||||||
verbose_logger.info(
|
verbose_logger.info(
|
||||||
|
@ -780,6 +794,25 @@ def completion_cost( # noqa: PLR0915
|
||||||
billed_units.get("search_units") or 1
|
billed_units.get("search_units") or 1
|
||||||
) # cohere charges per request by default.
|
) # cohere charges per request by default.
|
||||||
completion_tokens = search_units
|
completion_tokens = search_units
|
||||||
|
elif call_type == CallTypes.arealtime.value and isinstance(
|
||||||
|
completion_response, LiteLLMRealtimeStreamLoggingObject
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
cost_per_token_usage_object is None
|
||||||
|
or custom_llm_provider is None
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
|
||||||
|
cost_per_token_usage_object,
|
||||||
|
custom_llm_provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return handle_realtime_stream_cost_calculation(
|
||||||
|
results=completion_response.results,
|
||||||
|
combined_usage_object=cost_per_token_usage_object,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
litellm_model_name=model,
|
||||||
|
)
|
||||||
# Calculate cost based on prompt_tokens, completion_tokens
|
# Calculate cost based on prompt_tokens, completion_tokens
|
||||||
if (
|
if (
|
||||||
"togethercomputer" in model
|
"togethercomputer" in model
|
||||||
|
@ -909,6 +942,7 @@ def response_cost_calculator(
|
||||||
HttpxBinaryResponseContent,
|
HttpxBinaryResponseContent,
|
||||||
RerankResponse,
|
RerankResponse,
|
||||||
ResponsesAPIResponse,
|
ResponsesAPIResponse,
|
||||||
|
LiteLLMRealtimeStreamLoggingObject,
|
||||||
],
|
],
|
||||||
model: str,
|
model: str,
|
||||||
custom_llm_provider: Optional[str],
|
custom_llm_provider: Optional[str],
|
||||||
|
@ -937,6 +971,7 @@ def response_cost_calculator(
|
||||||
prompt: str = "",
|
prompt: str = "",
|
||||||
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
||||||
litellm_model_name: Optional[str] = None,
|
litellm_model_name: Optional[str] = None,
|
||||||
|
router_model_id: Optional[str] = None,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Returns
|
Returns
|
||||||
|
@ -967,6 +1002,8 @@ def response_cost_calculator(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
standard_built_in_tools_params=standard_built_in_tools_params,
|
standard_built_in_tools_params=standard_built_in_tools_params,
|
||||||
|
litellm_model_name=litellm_model_name,
|
||||||
|
router_model_id=router_model_id,
|
||||||
)
|
)
|
||||||
return response_cost
|
return response_cost
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1141,3 +1178,173 @@ def batch_cost_calculator(
|
||||||
) # batch cost is usually half of the regular token cost
|
) # batch cost is usually half of the regular token cost
|
||||||
|
|
||||||
return total_prompt_cost, total_completion_cost
|
return total_prompt_cost, total_completion_cost
|
||||||
|
|
||||||
|
|
||||||
|
class RealtimeAPITokenUsageProcessor:
|
||||||
|
@staticmethod
|
||||||
|
def collect_usage_from_realtime_stream_results(
|
||||||
|
results: OpenAIRealtimeStreamList,
|
||||||
|
) -> List[Usage]:
|
||||||
|
"""
|
||||||
|
Collect usage from realtime stream results
|
||||||
|
"""
|
||||||
|
response_done_events: List[OpenAIRealtimeStreamResponseBaseObject] = cast(
|
||||||
|
List[OpenAIRealtimeStreamResponseBaseObject],
|
||||||
|
[result for result in results if result["type"] == "response.done"],
|
||||||
|
)
|
||||||
|
usage_objects: List[Usage] = []
|
||||||
|
for result in response_done_events:
|
||||||
|
usage_object = (
|
||||||
|
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
||||||
|
result["response"].get("usage", {})
|
||||||
|
)
|
||||||
|
)
|
||||||
|
usage_objects.append(usage_object)
|
||||||
|
return usage_objects
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_usage_objects(usage_objects: List[Usage]) -> Usage:
|
||||||
|
"""
|
||||||
|
Combine multiple Usage objects into a single Usage object, checking model keys for nested values.
|
||||||
|
"""
|
||||||
|
from litellm.types.utils import (
|
||||||
|
CompletionTokensDetails,
|
||||||
|
PromptTokensDetailsWrapper,
|
||||||
|
Usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
combined = Usage()
|
||||||
|
|
||||||
|
# Sum basic token counts
|
||||||
|
for usage in usage_objects:
|
||||||
|
# Handle direct attributes by checking what exists in the model
|
||||||
|
for attr in dir(usage):
|
||||||
|
if not attr.startswith("_") and not callable(getattr(usage, attr)):
|
||||||
|
current_val = getattr(combined, attr, 0)
|
||||||
|
new_val = getattr(usage, attr, 0)
|
||||||
|
if (
|
||||||
|
new_val is not None
|
||||||
|
and isinstance(new_val, (int, float))
|
||||||
|
and isinstance(current_val, (int, float))
|
||||||
|
):
|
||||||
|
setattr(combined, attr, current_val + new_val)
|
||||||
|
# Handle nested prompt_tokens_details
|
||||||
|
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
|
||||||
|
if (
|
||||||
|
not hasattr(combined, "prompt_tokens_details")
|
||||||
|
or not combined.prompt_tokens_details
|
||||||
|
):
|
||||||
|
combined.prompt_tokens_details = PromptTokensDetailsWrapper()
|
||||||
|
|
||||||
|
# Check what keys exist in the model's prompt_tokens_details
|
||||||
|
for attr in dir(usage.prompt_tokens_details):
|
||||||
|
if not attr.startswith("_") and not callable(
|
||||||
|
getattr(usage.prompt_tokens_details, attr)
|
||||||
|
):
|
||||||
|
current_val = getattr(combined.prompt_tokens_details, attr, 0)
|
||||||
|
new_val = getattr(usage.prompt_tokens_details, attr, 0)
|
||||||
|
if new_val is not None:
|
||||||
|
setattr(
|
||||||
|
combined.prompt_tokens_details,
|
||||||
|
attr,
|
||||||
|
current_val + new_val,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle nested completion_tokens_details
|
||||||
|
if (
|
||||||
|
hasattr(usage, "completion_tokens_details")
|
||||||
|
and usage.completion_tokens_details
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
not hasattr(combined, "completion_tokens_details")
|
||||||
|
or not combined.completion_tokens_details
|
||||||
|
):
|
||||||
|
combined.completion_tokens_details = CompletionTokensDetails()
|
||||||
|
|
||||||
|
# Check what keys exist in the model's completion_tokens_details
|
||||||
|
for attr in dir(usage.completion_tokens_details):
|
||||||
|
if not attr.startswith("_") and not callable(
|
||||||
|
getattr(usage.completion_tokens_details, attr)
|
||||||
|
):
|
||||||
|
current_val = getattr(
|
||||||
|
combined.completion_tokens_details, attr, 0
|
||||||
|
)
|
||||||
|
new_val = getattr(usage.completion_tokens_details, attr, 0)
|
||||||
|
if new_val is not None:
|
||||||
|
setattr(
|
||||||
|
combined.completion_tokens_details,
|
||||||
|
attr,
|
||||||
|
current_val + new_val,
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collect_and_combine_usage_from_realtime_stream_results(
|
||||||
|
results: OpenAIRealtimeStreamList,
|
||||||
|
) -> Usage:
|
||||||
|
"""
|
||||||
|
Collect and combine usage from realtime stream results
|
||||||
|
"""
|
||||||
|
collected_usage_objects = (
|
||||||
|
RealtimeAPITokenUsageProcessor.collect_usage_from_realtime_stream_results(
|
||||||
|
results
|
||||||
|
)
|
||||||
|
)
|
||||||
|
combined_usage_object = RealtimeAPITokenUsageProcessor.combine_usage_objects(
|
||||||
|
collected_usage_objects
|
||||||
|
)
|
||||||
|
return combined_usage_object
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_logging_realtime_object(
|
||||||
|
usage: Usage, results: OpenAIRealtimeStreamList
|
||||||
|
) -> LiteLLMRealtimeStreamLoggingObject:
|
||||||
|
return LiteLLMRealtimeStreamLoggingObject(
|
||||||
|
usage=usage,
|
||||||
|
results=results,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_realtime_stream_cost_calculation(
|
||||||
|
results: OpenAIRealtimeStreamList,
|
||||||
|
combined_usage_object: Usage,
|
||||||
|
custom_llm_provider: str,
|
||||||
|
litellm_model_name: str,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Handles the cost calculation for realtime stream responses.
|
||||||
|
|
||||||
|
Pick the 'response.done' events. Calculate total cost across all 'response.done' events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: A list of OpenAIRealtimeStreamBaseObject objects
|
||||||
|
"""
|
||||||
|
received_model = None
|
||||||
|
potential_model_names = []
|
||||||
|
for result in results:
|
||||||
|
if result["type"] == "session.created":
|
||||||
|
received_model = cast(OpenAIRealtimeStreamSessionEvents, result)["session"][
|
||||||
|
"model"
|
||||||
|
]
|
||||||
|
potential_model_names.append(received_model)
|
||||||
|
|
||||||
|
potential_model_names.append(litellm_model_name)
|
||||||
|
input_cost_per_token = 0.0
|
||||||
|
output_cost_per_token = 0.0
|
||||||
|
|
||||||
|
for model_name in potential_model_names:
|
||||||
|
try:
|
||||||
|
_input_cost_per_token, _output_cost_per_token = generic_cost_per_token(
|
||||||
|
model=model_name,
|
||||||
|
usage=combined_usage_object,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
input_cost_per_token += _input_cost_per_token
|
||||||
|
output_cost_per_token += _output_cost_per_token
|
||||||
|
break # exit if we find a valid model
|
||||||
|
total_cost = input_cost_per_token + output_cost_per_token
|
||||||
|
|
||||||
|
return total_cost
|
||||||
|
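The `RealtimeAPITokenUsageProcessor` and `handle_realtime_stream_cost_calculation` additions above boil down to: collect the usage payloads from every 'response.done' event in the websocket session, sum them into one combined usage object, and price that total once. A simplified sketch of the accumulation step, using plain dicts as stand-ins for the OpenAI realtime event objects (assumed shape, not the litellm types):

    from typing import Dict, List

    def combine_realtime_usage(events: List[Dict]) -> Dict[str, int]:
        totals = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        for event in events:
            if event.get("type") != "response.done":
                continue
            usage = event.get("response", {}).get("usage", {})
            for key in totals:
                totals[key] += usage.get(key, 0)
        return totals

    # Example with two completed responses in one websocket session
    events = [
        {"type": "session.created", "session": {"model": "gpt-4o-realtime-preview"}},
        {"type": "response.done", "response": {"usage": {"input_tokens": 10, "output_tokens": 4, "total_tokens": 14}}},
        {"type": "response.done", "response": {"usage": {"input_tokens": 7, "output_tokens": 3, "total_tokens": 10}}},
    ]
    assert combine_realtime_usage(events) == {"input_tokens": 17, "output_tokens": 7, "total_tokens": 24}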
|
|
@ -1,10 +1,19 @@
|
||||||
# used for /metrics endpoint on LiteLLM Proxy
|
# used for /metrics endpoint on LiteLLM Proxy
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, log events to Prometheus
|
# On success, log events to Prometheus
|
||||||
import asyncio
|
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Awaitable, Callable, List, Literal, Optional, Tuple, cast
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import print_verbose, verbose_logger
|
from litellm._logging import print_verbose, verbose_logger
|
||||||
|
@ -14,6 +23,11 @@ from litellm.types.integrations.prometheus import *
|
||||||
from litellm.types.utils import StandardLoggingPayload
|
from litellm.types.utils import StandardLoggingPayload
|
||||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
else:
|
||||||
|
AsyncIOScheduler = Any
|
||||||
|
|
||||||
|
|
||||||
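The import hunk above pulls in `AsyncIOScheduler` only under `TYPE_CHECKING` and aliases it to `Any` at runtime, so the prometheus logger module stays importable even when apscheduler is not installed. A minimal sketch of the same pattern; the `schedule_every` helper is illustrative only.

    from typing import TYPE_CHECKING, Any, Callable

    if TYPE_CHECKING:
        # only evaluated by type checkers; no runtime dependency on apscheduler
        from apscheduler.schedulers.asyncio import AsyncIOScheduler
    else:
        AsyncIOScheduler = Any

    def schedule_every(scheduler: "AsyncIOScheduler", func: Callable, minutes: int) -> None:
        # APScheduler's add_job supports an "interval" trigger with a minutes kwarg
        scheduler.add_job(func, "interval", minutes=minutes)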
class PrometheusLogger(CustomLogger):
|
class PrometheusLogger(CustomLogger):
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
|
@ -359,8 +373,6 @@ class PrometheusLogger(CustomLogger):
|
||||||
label_name="litellm_requests_metric"
|
label_name="litellm_requests_metric"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._initialize_prometheus_startup_metrics()
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
||||||
raise e
|
raise e
|
||||||
|
@ -988,9 +1000,9 @@ class PrometheusLogger(CustomLogger):
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
verbose_logger.debug("setting remaining tokens requests metric")
|
verbose_logger.debug("setting remaining tokens requests metric")
|
||||||
standard_logging_payload: Optional[
|
standard_logging_payload: Optional[StandardLoggingPayload] = (
|
||||||
StandardLoggingPayload
|
request_kwargs.get("standard_logging_object")
|
||||||
] = request_kwargs.get("standard_logging_object")
|
)
|
||||||
|
|
||||||
if standard_logging_payload is None:
|
if standard_logging_payload is None:
|
||||||
return
|
return
|
||||||
|
@ -1337,24 +1349,6 @@ class PrometheusLogger(CustomLogger):
|
||||||
|
|
||||||
return max_budget - spend
|
return max_budget - spend
|
||||||
|
|
||||||
def _initialize_prometheus_startup_metrics(self):
|
|
||||||
"""
|
|
||||||
Initialize prometheus startup metrics
|
|
||||||
|
|
||||||
Helper to create tasks for initializing metrics that are required on startup - eg. remaining budget metrics
|
|
||||||
"""
|
|
||||||
if litellm.prometheus_initialize_budget_metrics is not True:
|
|
||||||
verbose_logger.debug("Prometheus: skipping budget metrics initialization")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if asyncio.get_running_loop():
|
|
||||||
asyncio.create_task(self._initialize_remaining_budget_metrics())
|
|
||||||
except RuntimeError as e: # no running event loop
|
|
||||||
verbose_logger.exception(
|
|
||||||
f"No running event loop - skipping budget metrics initialization: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_budget_metrics(
|
async def _initialize_budget_metrics(
|
||||||
self,
|
self,
|
||||||
data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
|
data_fetch_function: Callable[..., Awaitable[Tuple[List[Any], Optional[int]]]],
|
||||||
|
@ -1475,12 +1469,41 @@ class PrometheusLogger(CustomLogger):
|
||||||
data_type="keys",
|
data_type="keys",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _initialize_remaining_budget_metrics(self):
|
async def initialize_remaining_budget_metrics(self):
|
||||||
"""
|
"""
|
||||||
Initialize remaining budget metrics for all teams to avoid metric discrepancies.
|
Handler for initializing remaining budget metrics for all teams to avoid metric discrepancies.
|
||||||
|
|
||||||
Runs when prometheus logger starts up.
|
Runs when prometheus logger starts up.
|
||||||
|
|
||||||
|
- If redis cache is available, we use the pod lock manager to acquire a lock and initialize the metrics.
|
||||||
|
- Ensures only one pod emits the metrics at a time.
|
||||||
|
- If redis cache is not available, we initialize the metrics directly.
|
||||||
"""
|
"""
|
||||||
|
from litellm.constants import PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
|
||||||
|
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||||
|
|
||||||
|
pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
|
||||||
|
|
||||||
|
# if using redis, ensure only one pod emits the metrics at a time
|
||||||
|
if pod_lock_manager and pod_lock_manager.redis_cache:
|
||||||
|
if await pod_lock_manager.acquire_lock(
|
||||||
|
cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await self._initialize_remaining_budget_metrics()
|
||||||
|
finally:
|
||||||
|
await pod_lock_manager.release_lock(
|
||||||
|
cronjob_id=PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# if not using redis, initialize the metrics directly
|
||||||
|
await self._initialize_remaining_budget_metrics()
|
||||||
|
|
||||||
|
async def _initialize_remaining_budget_metrics(self):
|
||||||
|
"""
|
||||||
|
Helper to initialize remaining budget metrics for all teams and API keys.
|
||||||
|
"""
|
||||||
|
verbose_logger.debug("Emitting key, team budget metrics....")
|
||||||
await self._initialize_team_budget_metrics()
|
await self._initialize_team_budget_metrics()
|
||||||
await self._initialize_api_key_budget_metrics()
|
await self._initialize_api_key_budget_metrics()
|
||||||
|
|
||||||
|
@ -1737,6 +1760,36 @@ class PrometheusLogger(CustomLogger):
|
||||||
return (end_time - start_time).total_seconds()
|
return (end_time - start_time).total_seconds()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def initialize_budget_metrics_cron_job(scheduler: AsyncIOScheduler):
|
||||||
|
"""
|
||||||
|
Initialize budget metrics as a cron job. This job runs every `PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES` minutes.
|
||||||
|
|
||||||
|
It emits the current remaining budget metrics for all Keys and Teams.
|
||||||
|
"""
|
||||||
|
from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.integrations.prometheus import PrometheusLogger
|
||||||
|
|
||||||
|
prometheus_loggers: List[CustomLogger] = (
|
||||||
|
litellm.logging_callback_manager.get_custom_loggers_for_type(
|
||||||
|
callback_type=PrometheusLogger
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them
|
||||||
|
verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers))
|
||||||
|
if len(prometheus_loggers) > 0:
|
||||||
|
prometheus_logger = cast(PrometheusLogger, prometheus_loggers[0])
|
||||||
|
verbose_logger.debug(
|
||||||
|
"Initializing remaining budget metrics as a cron job executing every %s minutes"
|
||||||
|
% PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
|
||||||
|
)
|
||||||
|
scheduler.add_job(
|
||||||
|
prometheus_logger.initialize_remaining_budget_metrics,
|
||||||
|
"interval",
|
||||||
|
minutes=PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES,
|
||||||
|
)
|
||||||
|
|
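`initialize_budget_metrics_cron_job` above looks up the registered `PrometheusLogger` instances and registers `initialize_remaining_budget_metrics` on an interval trigger. A rough sketch of wiring such a job into an async app, assuming apscheduler is installed; the 5-minute interval and the `emit_budget_metrics` coroutine are placeholders, not litellm code.

    import asyncio
    from apscheduler.schedulers.asyncio import AsyncIOScheduler

    async def emit_budget_metrics() -> None:
        # placeholder for PrometheusLogger.initialize_remaining_budget_metrics()
        print("refreshing remaining budget metrics")

    async def main() -> None:
        scheduler = AsyncIOScheduler()
        # run the refresh every 5 minutes; APScheduler awaits coroutine jobs directly
        scheduler.add_job(emit_budget_metrics, "interval", minutes=5)
        scheduler.start()
        await asyncio.sleep(0.1)  # keep the loop alive briefly for the example
        scheduler.shutdown(wait=False)

    if __name__ == "__main__":
        asyncio.run(main())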
||||||
@staticmethod
|
@staticmethod
|
||||||
def _mount_metrics_endpoint(premium_user: bool):
|
def _mount_metrics_endpoint(premium_user: bool):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -110,5 +110,8 @@ def get_litellm_params(
|
||||||
"azure_password": kwargs.get("azure_password"),
|
"azure_password": kwargs.get("azure_password"),
|
||||||
"max_retries": max_retries,
|
"max_retries": max_retries,
|
||||||
"timeout": kwargs.get("timeout"),
|
"timeout": kwargs.get("timeout"),
|
||||||
|
"bucket_name": kwargs.get("bucket_name"),
|
||||||
|
"vertex_credentials": kwargs.get("vertex_credentials"),
|
||||||
|
"vertex_project": kwargs.get("vertex_project"),
|
||||||
}
|
}
|
||||||
return litellm_params
|
return litellm_params
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import LlmProviders
|
|
||||||
from litellm.exceptions import BadRequestError
|
from litellm.exceptions import BadRequestError
|
||||||
|
from litellm.types.utils import LlmProviders, LlmProvidersSet
|
||||||
|
|
||||||
|
|
||||||
def get_supported_openai_params( # noqa: PLR0915
|
def get_supported_openai_params( # noqa: PLR0915
|
||||||
|
@ -30,6 +30,20 @@ def get_supported_openai_params( # noqa: PLR0915
|
||||||
except BadRequestError:
|
except BadRequestError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if custom_llm_provider in LlmProvidersSet:
|
||||||
|
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
|
||||||
|
model=model, provider=LlmProviders(custom_llm_provider)
|
||||||
|
)
|
||||||
|
elif custom_llm_provider.split("/")[0] in LlmProvidersSet:
|
||||||
|
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
|
||||||
|
model=model, provider=LlmProviders(custom_llm_provider.split("/")[0])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider_config = None
|
||||||
|
|
||||||
|
if provider_config and request_type == "chat_completion":
|
||||||
|
return provider_config.get_supported_openai_params(model=model)
|
||||||
|
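The new block above resolves a provider chat config either from the raw `custom_llm_provider` value or from its prefix before the first "/", so values like "openrouter/some-lab" still map to a known provider. A small illustrative helper showing the same lookup rule (a stand-in, not the litellm implementation; the provider set is truncated for the example):

    from typing import Optional, Set

    KNOWN_PROVIDERS: Set[str] = {"openai", "anthropic", "bedrock", "openrouter"}

    def resolve_provider(custom_llm_provider: str, known: Set[str] = KNOWN_PROVIDERS) -> Optional[str]:
        # exact match first, then fall back to the prefix before the first "/"
        if custom_llm_provider in known:
            return custom_llm_provider
        prefix = custom_llm_provider.split("/")[0]
        if prefix in known:
            return prefix
        return None

    assert resolve_provider("anthropic") == "anthropic"
    assert resolve_provider("openrouter/deepseek") == "openrouter"
    assert resolve_provider("some-unknown-provider") is None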
|
||||||
if custom_llm_provider == "bedrock":
|
if custom_llm_provider == "bedrock":
|
||||||
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
|
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
|
||||||
elif custom_llm_provider == "ollama":
|
elif custom_llm_provider == "ollama":
|
||||||
|
@ -226,7 +240,8 @@ def get_supported_openai_params( # noqa: PLR0915
|
||||||
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
|
provider_config = litellm.ProviderConfigManager.get_provider_chat_config(
|
||||||
model=model, provider=LlmProviders.CUSTOM
|
model=model, provider=LlmProviders.CUSTOM
|
||||||
)
|
)
|
||||||
return provider_config.get_supported_openai_params(model=model)
|
if provider_config:
|
||||||
|
return provider_config.get_supported_openai_params(model=model)
|
||||||
elif request_type == "embeddings":
|
elif request_type == "embeddings":
|
||||||
return None
|
return None
|
||||||
elif request_type == "transcription":
|
elif request_type == "transcription":
|
||||||
|
|
|
@ -32,7 +32,10 @@ from litellm.constants import (
|
||||||
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
|
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
|
||||||
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
|
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
|
||||||
)
|
)
|
||||||
from litellm.cost_calculator import _select_model_name_for_cost_calc
|
from litellm.cost_calculator import (
|
||||||
|
RealtimeAPITokenUsageProcessor,
|
||||||
|
_select_model_name_for_cost_calc,
|
||||||
|
)
|
||||||
from litellm.integrations.arize.arize import ArizeLogger
|
from litellm.integrations.arize.arize import ArizeLogger
|
||||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
@ -64,6 +67,7 @@ from litellm.types.utils import (
|
||||||
ImageResponse,
|
ImageResponse,
|
||||||
LiteLLMBatch,
|
LiteLLMBatch,
|
||||||
LiteLLMLoggingBaseClass,
|
LiteLLMLoggingBaseClass,
|
||||||
|
LiteLLMRealtimeStreamLoggingObject,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ModelResponseStream,
|
ModelResponseStream,
|
||||||
RawRequestTypedDict,
|
RawRequestTypedDict,
|
||||||
|
@ -618,7 +622,6 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
] = RawRequestTypedDict(
|
] = RawRequestTypedDict(
|
||||||
error=str(e),
|
error=str(e),
|
||||||
)
|
)
|
||||||
traceback.print_exc()
|
|
||||||
_metadata[
|
_metadata[
|
||||||
"raw_request"
|
"raw_request"
|
||||||
] = "Unable to Log \
|
] = "Unable to Log \
|
||||||
|
@ -899,9 +902,11 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
FineTuningJob,
|
FineTuningJob,
|
||||||
ResponsesAPIResponse,
|
ResponsesAPIResponse,
|
||||||
ResponseCompletedEvent,
|
ResponseCompletedEvent,
|
||||||
|
LiteLLMRealtimeStreamLoggingObject,
|
||||||
],
|
],
|
||||||
cache_hit: Optional[bool] = None,
|
cache_hit: Optional[bool] = None,
|
||||||
litellm_model_name: Optional[str] = None,
|
litellm_model_name: Optional[str] = None,
|
||||||
|
router_model_id: Optional[str] = None,
|
||||||
) -> Optional[float]:
|
) -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
Calculate response cost using result + logging object variables.
|
Calculate response cost using result + logging object variables.
|
||||||
|
@ -940,6 +945,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
"custom_pricing": custom_pricing,
|
"custom_pricing": custom_pricing,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"standard_built_in_tools_params": self.standard_built_in_tools_params,
|
"standard_built_in_tools_params": self.standard_built_in_tools_params,
|
||||||
|
"router_model_id": router_model_id,
|
||||||
}
|
}
|
||||||
except Exception as e: # error creating kwargs for cost calculation
|
except Exception as e: # error creating kwargs for cost calculation
|
||||||
debug_info = StandardLoggingModelCostFailureDebugInformation(
|
debug_info = StandardLoggingModelCostFailureDebugInformation(
|
||||||
|
@ -1049,26 +1055,50 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
result = self._handle_anthropic_messages_response_logging(result=result)
|
result = self._handle_anthropic_messages_response_logging(result=result)
|
||||||
## if model in model cost map - log the response cost
|
## if model in model cost map - log the response cost
|
||||||
## else set cost to None
|
## else set cost to None
|
||||||
|
|
||||||
|
logging_result = result
|
||||||
|
|
||||||
|
if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
|
||||||
|
combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
|
||||||
|
results=result
|
||||||
|
)
|
||||||
|
logging_result = (
|
||||||
|
RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
|
||||||
|
usage=combined_usage_object,
|
||||||
|
results=result,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.model_call_details[
|
||||||
|
# "response_cost"
|
||||||
|
# ] = handle_realtime_stream_cost_calculation(
|
||||||
|
# results=result,
|
||||||
|
# combined_usage_object=combined_usage_object,
|
||||||
|
# custom_llm_provider=self.custom_llm_provider,
|
||||||
|
# litellm_model_name=self.model,
|
||||||
|
# )
|
||||||
|
# self.model_call_details["combined_usage_object"] = combined_usage_object
|
||||||
if (
|
if (
|
||||||
standard_logging_object is None
|
standard_logging_object is None
|
||||||
and result is not None
|
and result is not None
|
||||||
and self.stream is not True
|
and self.stream is not True
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
isinstance(result, ModelResponse)
|
isinstance(logging_result, ModelResponse)
|
||||||
or isinstance(result, ModelResponseStream)
|
or isinstance(logging_result, ModelResponseStream)
|
||||||
or isinstance(result, EmbeddingResponse)
|
or isinstance(logging_result, EmbeddingResponse)
|
||||||
or isinstance(result, ImageResponse)
|
or isinstance(logging_result, ImageResponse)
|
||||||
or isinstance(result, TranscriptionResponse)
|
or isinstance(logging_result, TranscriptionResponse)
|
||||||
or isinstance(result, TextCompletionResponse)
|
or isinstance(logging_result, TextCompletionResponse)
|
||||||
or isinstance(result, HttpxBinaryResponseContent) # tts
|
or isinstance(logging_result, HttpxBinaryResponseContent) # tts
|
||||||
or isinstance(result, RerankResponse)
|
or isinstance(logging_result, RerankResponse)
|
||||||
or isinstance(result, FineTuningJob)
|
or isinstance(logging_result, FineTuningJob)
|
||||||
or isinstance(result, LiteLLMBatch)
|
or isinstance(logging_result, LiteLLMBatch)
|
||||||
or isinstance(result, ResponsesAPIResponse)
|
or isinstance(logging_result, ResponsesAPIResponse)
|
||||||
|
or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject)
|
||||||
):
|
):
|
||||||
## HIDDEN PARAMS ##
|
## HIDDEN PARAMS ##
|
||||||
hidden_params = getattr(result, "_hidden_params", {})
|
hidden_params = getattr(logging_result, "_hidden_params", {})
|
||||||
if hidden_params:
|
if hidden_params:
|
||||||
# add to metadata for logging
|
# add to metadata for logging
|
||||||
if self.model_call_details.get("litellm_params") is not None:
|
if self.model_call_details.get("litellm_params") is not None:
|
||||||
|
@ -1086,7 +1116,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
self.model_call_details["litellm_params"]["metadata"][ # type: ignore
|
self.model_call_details["litellm_params"]["metadata"][ # type: ignore
|
||||||
"hidden_params"
|
"hidden_params"
|
||||||
] = getattr(
|
] = getattr(
|
||||||
result, "_hidden_params", {}
|
logging_result, "_hidden_params", {}
|
||||||
)
|
)
|
||||||
## RESPONSE COST - Only calculate if not in hidden_params ##
|
## RESPONSE COST - Only calculate if not in hidden_params ##
|
||||||
if "response_cost" in hidden_params:
|
if "response_cost" in hidden_params:
|
||||||
|
@ -1096,14 +1126,14 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
else:
|
else:
|
||||||
self.model_call_details[
|
self.model_call_details[
|
||||||
"response_cost"
|
"response_cost"
|
||||||
] = self._response_cost_calculator(result=result)
|
] = self._response_cost_calculator(result=logging_result)
|
||||||
## STANDARDIZED LOGGING PAYLOAD
|
## STANDARDIZED LOGGING PAYLOAD
|
||||||
|
|
||||||
self.model_call_details[
|
self.model_call_details[
|
||||||
"standard_logging_object"
|
"standard_logging_object"
|
||||||
] = get_standard_logging_object_payload(
|
] = get_standard_logging_object_payload(
|
||||||
kwargs=self.model_call_details,
|
kwargs=self.model_call_details,
|
||||||
init_response_obj=result,
|
init_response_obj=logging_result,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
logging_obj=self,
|
logging_obj=self,
|
||||||
|
@ -3122,6 +3152,7 @@ class StandardLoggingPayloadSetup:
|
||||||
prompt_integration: Optional[str] = None,
|
prompt_integration: Optional[str] = None,
|
||||||
applied_guardrails: Optional[List[str]] = None,
|
applied_guardrails: Optional[List[str]] = None,
|
||||||
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
|
mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
|
||||||
|
usage_object: Optional[dict] = None,
|
||||||
) -> StandardLoggingMetadata:
|
) -> StandardLoggingMetadata:
|
||||||
"""
|
"""
|
||||||
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
||||||
|
@ -3169,6 +3200,7 @@ class StandardLoggingPayloadSetup:
|
||||||
prompt_management_metadata=prompt_management_metadata,
|
prompt_management_metadata=prompt_management_metadata,
|
||||||
applied_guardrails=applied_guardrails,
|
applied_guardrails=applied_guardrails,
|
||||||
mcp_tool_call_metadata=mcp_tool_call_metadata,
|
mcp_tool_call_metadata=mcp_tool_call_metadata,
|
||||||
|
usage_object=usage_object,
|
||||||
)
|
)
|
||||||
if isinstance(metadata, dict):
|
if isinstance(metadata, dict):
|
||||||
# Filter the metadata dictionary to include only the specified keys
|
# Filter the metadata dictionary to include only the specified keys
|
||||||
|
@ -3194,8 +3226,12 @@ class StandardLoggingPayloadSetup:
|
||||||
return clean_metadata
|
return clean_metadata
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_usage_from_response_obj(response_obj: Optional[dict]) -> Usage:
|
def get_usage_from_response_obj(
|
||||||
|
response_obj: Optional[dict], combined_usage_object: Optional[Usage] = None
|
||||||
|
) -> Usage:
|
||||||
## BASE CASE ##
|
## BASE CASE ##
|
||||||
|
if combined_usage_object is not None:
|
||||||
|
return combined_usage_object
|
||||||
if response_obj is None:
|
if response_obj is None:
|
||||||
return Usage(
|
return Usage(
|
||||||
prompt_tokens=0,
|
prompt_tokens=0,
|
||||||
|
@ -3324,6 +3360,7 @@ class StandardLoggingPayloadSetup:
|
||||||
litellm_overhead_time_ms=None,
|
litellm_overhead_time_ms=None,
|
||||||
batch_models=None,
|
batch_models=None,
|
||||||
litellm_model_name=None,
|
litellm_model_name=None,
|
||||||
|
usage_object=None,
|
||||||
)
|
)
|
||||||
if hidden_params is not None:
|
if hidden_params is not None:
|
||||||
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
||||||
|
@ -3440,6 +3477,7 @@ def get_standard_logging_object_payload(
|
||||||
litellm_overhead_time_ms=None,
|
litellm_overhead_time_ms=None,
|
||||||
batch_models=None,
|
batch_models=None,
|
||||||
litellm_model_name=None,
|
litellm_model_name=None,
|
||||||
|
usage_object=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3456,8 +3494,12 @@ def get_standard_logging_object_payload(
|
||||||
call_type = kwargs.get("call_type")
|
call_type = kwargs.get("call_type")
|
||||||
cache_hit = kwargs.get("cache_hit", False)
|
cache_hit = kwargs.get("cache_hit", False)
|
||||||
usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(
|
usage = StandardLoggingPayloadSetup.get_usage_from_response_obj(
|
||||||
response_obj=response_obj
|
response_obj=response_obj,
|
||||||
|
combined_usage_object=cast(
|
||||||
|
Optional[Usage], kwargs.get("combined_usage_object")
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||||
|
|
||||||
_model_id = metadata.get("model_info", {}).get("id", "")
|
_model_id = metadata.get("model_info", {}).get("id", "")
|
||||||
|
@ -3496,6 +3538,7 @@ def get_standard_logging_object_payload(
|
||||||
prompt_integration=kwargs.get("prompt_integration", None),
|
prompt_integration=kwargs.get("prompt_integration", None),
|
||||||
applied_guardrails=kwargs.get("applied_guardrails", None),
|
applied_guardrails=kwargs.get("applied_guardrails", None),
|
||||||
mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None),
|
mcp_tool_call_metadata=kwargs.get("mcp_tool_call_metadata", None),
|
||||||
|
usage_object=usage.model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
_request_body = proxy_server_request.get("body", {})
|
_request_body = proxy_server_request.get("body", {})
|
||||||
|
@ -3636,6 +3679,7 @@ def get_standard_logging_metadata(
|
||||||
prompt_management_metadata=None,
|
prompt_management_metadata=None,
|
||||||
applied_guardrails=None,
|
applied_guardrails=None,
|
||||||
mcp_tool_call_metadata=None,
|
mcp_tool_call_metadata=None,
|
||||||
|
usage_object=None,
|
||||||
)
|
)
|
||||||
if isinstance(metadata, dict):
|
if isinstance(metadata, dict):
|
||||||
# Filter the metadata dictionary to include only the specified keys
|
# Filter the metadata dictionary to include only the specified keys
|
||||||
|
@ -3730,6 +3774,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
|
||||||
litellm_overhead_time_ms=None,
|
litellm_overhead_time_ms=None,
|
||||||
batch_models=None,
|
batch_models=None,
|
||||||
litellm_model_name=None,
|
litellm_model_name=None,
|
||||||
|
usage_object=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert numeric values to appropriate types
|
# Convert numeric values to appropriate types
|
||||||
|
|
|
@ -90,35 +90,45 @@ def _generic_cost_per_character(
|
||||||
return prompt_cost, completion_cost
|
return prompt_cost, completion_cost
|
||||||
|
|
||||||
|
|
||||||
def _get_prompt_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
|
def _get_token_base_cost(model_info: ModelInfo, usage: Usage) -> Tuple[float, float]:
|
||||||
"""
|
"""
|
||||||
Return prompt cost for a given model and usage.
|
Return prompt and completion token base costs for a given model and usage.
|
||||||
|
|
||||||
If input_tokens > 128k and `input_cost_per_token_above_128k_tokens` is set, then we use the `input_cost_per_token_above_128k_tokens` field.
|
If input_tokens > threshold and `input_cost_per_token_above_[x]k_tokens` or `input_cost_per_token_above_[x]_tokens` is set,
|
||||||
|
then we use the corresponding threshold cost.
|
||||||
"""
|
"""
|
||||||
input_cost_per_token_above_128k_tokens = model_info.get(
|
prompt_base_cost = model_info["input_cost_per_token"]
|
||||||
"input_cost_per_token_above_128k_tokens"
|
completion_base_cost = model_info["output_cost_per_token"]
|
||||||
)
|
|
||||||
if _is_above_128k(usage.prompt_tokens) and input_cost_per_token_above_128k_tokens:
|
|
||||||
return input_cost_per_token_above_128k_tokens
|
|
||||||
return model_info["input_cost_per_token"]
|
|
||||||
|
|
||||||
|
## CHECK IF ABOVE THRESHOLD
|
||||||
|
threshold: Optional[float] = None
|
||||||
|
for key, value in sorted(model_info.items(), reverse=True):
|
||||||
|
if key.startswith("input_cost_per_token_above_") and value is not None:
|
||||||
|
try:
|
||||||
|
# Handle both formats: _above_128k_tokens and _above_128_tokens
|
||||||
|
threshold_str = key.split("_above_")[1].split("_tokens")[0]
|
||||||
|
threshold = float(threshold_str.replace("k", "")) * (
|
||||||
|
1000 if "k" in threshold_str else 1
|
||||||
|
)
|
||||||
|
if usage.prompt_tokens > threshold:
|
||||||
|
prompt_base_cost = cast(
|
||||||
|
float,
|
||||||
|
model_info.get(key, prompt_base_cost),
|
||||||
|
)
|
||||||
|
completion_base_cost = cast(
|
||||||
|
float,
|
||||||
|
model_info.get(
|
||||||
|
f"output_cost_per_token_above_{threshold_str}_tokens",
|
||||||
|
completion_base_cost,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except (IndexError, ValueError):
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
def _get_completion_token_base_cost(model_info: ModelInfo, usage: Usage) -> float:
|
return prompt_base_cost, completion_base_cost
|
||||||
"""
|
|
||||||
Return prompt cost for a given model and usage.
|
|
||||||
|
|
||||||
If input_tokens > 128k and `input_cost_per_token_above_128k_tokens` is set, then we use the `input_cost_per_token_above_128k_tokens` field.
|
|
||||||
"""
|
|
||||||
output_cost_per_token_above_128k_tokens = model_info.get(
|
|
||||||
"output_cost_per_token_above_128k_tokens"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
_is_above_128k(usage.completion_tokens)
|
|
||||||
and output_cost_per_token_above_128k_tokens
|
|
||||||
):
|
|
||||||
return output_cost_per_token_above_128k_tokens
|
|
||||||
return model_info["output_cost_per_token"]
|
|
||||||
|
|
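`_get_token_base_cost` above generalizes the old 128k-specific check: it scans the model-info keys for `input_cost_per_token_above_<N>k_tokens` (or `..._above_<N>_tokens`), parses the threshold out of the key name, and switches both the prompt and completion base rates once the prompt token count crosses it. A simplified, self-contained version of that threshold parsing, using a plain dict for model info and made-up prices:

    from typing import Dict, Tuple

    def token_base_costs(model_info: Dict[str, float], prompt_tokens: int) -> Tuple[float, float]:
        prompt_cost = model_info["input_cost_per_token"]
        completion_cost = model_info["output_cost_per_token"]
        for key, value in sorted(model_info.items(), reverse=True):
            if not key.startswith("input_cost_per_token_above_") or value is None:
                continue
            # "input_cost_per_token_above_128k_tokens" -> "128k" -> 128000.0
            threshold_str = key.split("_above_")[1].split("_tokens")[0]
            threshold = float(threshold_str.replace("k", "")) * (1000 if "k" in threshold_str else 1)
            if prompt_tokens > threshold:
                prompt_cost = value
                completion_cost = model_info.get(
                    f"output_cost_per_token_above_{threshold_str}_tokens", completion_cost
                )
                break
        return prompt_cost, completion_cost

    # e.g. a price bump past 128k prompt tokens (numbers are made up)
    info = {
        "input_cost_per_token": 1e-6,
        "output_cost_per_token": 2e-6,
        "input_cost_per_token_above_128k_tokens": 2e-6,
        "output_cost_per_token_above_128k_tokens": 4e-6,
    }
    assert token_base_costs(info, 10_000) == (1e-6, 2e-6)
    assert token_base_costs(info, 200_000) == (2e-6, 4e-6)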
||||||
|
|
||||||
def calculate_cost_component(
|
def calculate_cost_component(
|
||||||
|
@ -215,7 +225,9 @@ def generic_cost_per_token(
|
||||||
if text_tokens == 0:
|
if text_tokens == 0:
|
||||||
text_tokens = usage.prompt_tokens - cache_hit_tokens - audio_tokens
|
text_tokens = usage.prompt_tokens - cache_hit_tokens - audio_tokens
|
||||||
|
|
||||||
prompt_base_cost = _get_prompt_token_base_cost(model_info=model_info, usage=usage)
|
prompt_base_cost, completion_base_cost = _get_token_base_cost(
|
||||||
|
model_info=model_info, usage=usage
|
||||||
|
)
|
||||||
|
|
||||||
prompt_cost = float(text_tokens) * prompt_base_cost
|
prompt_cost = float(text_tokens) * prompt_base_cost
|
||||||
|
|
||||||
|
@ -253,9 +265,6 @@ def generic_cost_per_token(
|
||||||
)
|
)
|
||||||
|
|
||||||
## CALCULATE OUTPUT COST
|
## CALCULATE OUTPUT COST
|
||||||
completion_base_cost = _get_completion_token_base_cost(
|
|
||||||
model_info=model_info, usage=usage
|
|
||||||
)
|
|
||||||
text_tokens = usage.completion_tokens
|
text_tokens = usage.completion_tokens
|
||||||
audio_tokens = 0
|
audio_tokens = 0
|
||||||
if usage.completion_tokens_details is not None:
|
if usage.completion_tokens_details is not None:
|
||||||
|
|
|
@ -36,11 +36,16 @@ class ResponseMetadata:
|
||||||
self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
|
self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set hidden parameters on the response"""
|
"""Set hidden parameters on the response"""
|
||||||
|
|
||||||
|
## ADD OTHER HIDDEN PARAMS
|
||||||
|
model_id = kwargs.get("model_info", {}).get("id", None)
|
||||||
new_params = {
|
new_params = {
|
||||||
"litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
|
"litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
|
||||||
"model_id": kwargs.get("model_info", {}).get("id", None),
|
|
||||||
"api_base": get_api_base(model=model or "", optional_params=kwargs),
|
"api_base": get_api_base(model=model or "", optional_params=kwargs),
|
||||||
"response_cost": logging_obj._response_cost_calculator(result=self.result),
|
"model_id": model_id,
|
||||||
|
"response_cost": logging_obj._response_cost_calculator(
|
||||||
|
result=self.result, litellm_model_name=model, router_model_id=model_id
|
||||||
|
),
|
||||||
"additional_headers": process_response_headers(
|
"additional_headers": process_response_headers(
|
||||||
self._get_value_from_hidden_params("additional_headers") or {}
|
self._get_value_from_hidden_params("additional_headers") or {}
|
||||||
),
|
),
|
||||||
|
|
|
@ -2,7 +2,10 @@
|
||||||
Common utility functions used for translating messages across providers
|
Common utility functions used for translating messages across providers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, List, Literal, Optional, Union, cast
|
import io
|
||||||
|
import mimetypes
|
||||||
|
from os import PathLike
|
||||||
|
from typing import Dict, List, Literal, Mapping, Optional, Union, cast
|
||||||
|
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
AllMessageValues,
|
AllMessageValues,
|
||||||
|
@ -10,7 +13,13 @@ from litellm.types.llms.openai import (
|
||||||
ChatCompletionFileObject,
|
ChatCompletionFileObject,
|
||||||
ChatCompletionUserMessage,
|
ChatCompletionUserMessage,
|
||||||
)
|
)
|
||||||
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
|
from litellm.types.utils import (
|
||||||
|
Choices,
|
||||||
|
ExtractedFileData,
|
||||||
|
FileTypes,
|
||||||
|
ModelResponse,
|
||||||
|
StreamingChoices,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
|
DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
|
||||||
content="Please continue.", role="user"
|
content="Please continue.", role="user"
|
||||||
|
@ -348,3 +357,99 @@ def update_messages_with_model_file_ids(
|
||||||
)
|
)
|
||||||
file_object_file_field["file_id"] = provider_file_id
|
file_object_file_field["file_id"] = provider_file_id
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def extract_file_data(file_data: FileTypes) -> ExtractedFileData:
|
||||||
|
"""
|
||||||
|
Extracts and processes file data from various input formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExtractedFileData containing:
|
||||||
|
- filename: Name of the file if provided
|
||||||
|
- content: The file content in bytes
|
||||||
|
- content_type: MIME type of the file
|
||||||
|
- headers: Any additional headers
|
||||||
|
"""
|
||||||
|
# Parse the file_data based on its type
|
||||||
|
filename = None
|
||||||
|
file_content = None
|
||||||
|
content_type = None
|
||||||
|
file_headers: Mapping[str, str] = {}
|
||||||
|
|
||||||
|
if isinstance(file_data, tuple):
|
||||||
|
if len(file_data) == 2:
|
||||||
|
filename, file_content = file_data
|
||||||
|
elif len(file_data) == 3:
|
||||||
|
filename, file_content, content_type = file_data
|
||||||
|
elif len(file_data) == 4:
|
||||||
|
filename, file_content, content_type, file_headers = file_data
|
||||||
|
else:
|
||||||
|
file_content = file_data
|
||||||
|
# Convert content to bytes
|
||||||
|
if isinstance(file_content, (str, PathLike)):
|
||||||
|
# If it's a path, open and read the file
|
||||||
|
with open(file_content, "rb") as f:
|
||||||
|
content = f.read()
|
||||||
|
elif isinstance(file_content, io.IOBase):
|
||||||
|
# If it's a file-like object
|
||||||
|
content = file_content.read()
|
||||||
|
|
||||||
|
if isinstance(content, str):
|
||||||
|
content = content.encode("utf-8")
|
||||||
|
# Reset file pointer to beginning
|
||||||
|
file_content.seek(0)
|
||||||
|
elif isinstance(file_content, bytes):
|
||||||
|
content = file_content
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported file content type: {type(file_content)}")
|
||||||
|
|
||||||
|
# Use provided content type or guess based on filename
|
||||||
|
if not content_type:
|
||||||
|
content_type = (
|
||||||
|
mimetypes.guess_type(filename)[0]
|
||||||
|
if filename
|
||||||
|
else "application/octet-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExtractedFileData(
|
||||||
|
filename=filename,
|
||||||
|
content=content,
|
||||||
|
content_type=content_type,
|
||||||
|
headers=file_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def unpack_defs(schema, defs):
|
||||||
|
properties = schema.get("properties", None)
|
||||||
|
if properties is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for name, value in properties.items():
|
||||||
|
ref_key = value.get("$ref", None)
|
||||||
|
if ref_key is not None:
|
||||||
|
ref = defs[ref_key.split("defs/")[-1]]
|
||||||
|
unpack_defs(ref, defs)
|
||||||
|
properties[name] = ref
|
||||||
|
continue
|
||||||
|
|
||||||
|
anyof = value.get("anyOf", None)
|
||||||
|
if anyof is not None:
|
||||||
|
for i, atype in enumerate(anyof):
|
||||||
|
ref_key = atype.get("$ref", None)
|
||||||
|
if ref_key is not None:
|
||||||
|
ref = defs[ref_key.split("defs/")[-1]]
|
||||||
|
unpack_defs(ref, defs)
|
||||||
|
anyof[i] = ref
|
||||||
|
continue
|
||||||
|
|
||||||
|
items = value.get("items", None)
|
||||||
|
if items is not None:
|
||||||
|
ref_key = items.get("$ref", None)
|
||||||
|
if ref_key is not None:
|
||||||
|
ref = defs[ref_key.split("defs/")[-1]]
|
||||||
|
unpack_defs(ref, defs)
|
||||||
|
value["items"] = ref
|
||||||
|
continue
|
||||||
|
|
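`unpack_defs` above recursively inlines `$defs` references inside a JSON schema (top-level properties, anyOf members, and array items), which is what `_bedrock_tools_pt` relies on further down because the Bedrock Converse tool schema does not accept `$ref`. A small usage sketch on a hand-written schema; the input shape is assumed, and real schemas usually come from pydantic's `model_json_schema()`. The import path is the one added in the `_bedrock_tools_pt` hunk below.

    from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs

    # A schema with $defs references, similar to what pydantic emits for nested models.
    schema = {
        "properties": {
            "address": {"$ref": "#/$defs/Address"},
            "tags": {"type": "array", "items": {"$ref": "#/$defs/Tag"}},
        },
        "$defs": {
            "Address": {"properties": {"city": {"type": "string"}}},
            "Tag": {"properties": {"name": {"type": "string"}}},
        },
    }

    defs = schema.pop("$defs", {})
    unpack_defs(schema, defs)  # flattens the $ref entries in place

    assert schema["properties"]["address"] == {"properties": {"city": {"type": "string"}}}
    assert schema["properties"]["tags"]["items"] == {"properties": {"name": {"type": "string"}}}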
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import traceback
|
|
||||||
import uuid
|
import uuid
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -748,7 +747,6 @@ def convert_to_anthropic_image_obj(
|
||||||
data=base64_data,
|
data=base64_data,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
|
||||||
if "Error: Unable to fetch image from URL" in str(e):
|
if "Error: Unable to fetch image from URL" in str(e):
|
||||||
raise e
|
raise e
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -3442,6 +3440,8 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
|
from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs
|
||||||
|
|
||||||
tool_block_list: List[BedrockToolBlock] = []
|
tool_block_list: List[BedrockToolBlock] = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
parameters = tool.get("function", {}).get(
|
parameters = tool.get("function", {}).get(
|
||||||
|
@ -3455,6 +3455,13 @@ def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
|
||||||
description = tool.get("function", {}).get(
|
description = tool.get("function", {}).get(
|
||||||
"description", name
|
"description", name
|
||||||
) # converse api requires a description
|
) # converse api requires a description
|
||||||
|
|
||||||
|
defs = parameters.pop("$defs", {})
|
||||||
|
defs_copy = copy.deepcopy(defs)
|
||||||
|
# flatten the defs
|
||||||
|
for _, value in defs_copy.items():
|
||||||
|
unpack_defs(value, defs_copy)
|
||||||
|
unpack_defs(parameters, defs_copy)
|
||||||
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
|
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
|
||||||
tool_spec = BedrockToolSpecBlock(
|
tool_spec = BedrockToolSpecBlock(
|
||||||
inputSchema=tool_input_schema, name=name, description=description
|
inputSchema=tool_input_schema, name=name, description=description
|
||||||
|
|
|
@ -30,6 +30,11 @@ import json
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.types.llms.openai import (
|
||||||
|
OpenAIRealtimeStreamResponseBaseObject,
|
||||||
|
OpenAIRealtimeStreamSessionEvents,
|
||||||
|
)
|
||||||
|
|
||||||
from .litellm_logging import Logging as LiteLLMLogging
|
from .litellm_logging import Logging as LiteLLMLogging
|
||||||
|
|
||||||
|
@ -53,7 +58,12 @@ class RealTimeStreaming:
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
self.backend_ws = backend_ws
|
self.backend_ws = backend_ws
|
||||||
self.logging_obj = logging_obj
|
self.logging_obj = logging_obj
|
||||||
self.messages: List = []
|
self.messages: List[
|
||||||
|
Union[
|
||||||
|
OpenAIRealtimeStreamResponseBaseObject,
|
||||||
|
OpenAIRealtimeStreamSessionEvents,
|
||||||
|
]
|
||||||
|
] = []
|
||||||
self.input_message: Dict = {}
|
self.input_message: Dict = {}
|
||||||
|
|
||||||
_logged_real_time_event_types = litellm.logged_real_time_event_types
|
_logged_real_time_event_types = litellm.logged_real_time_event_types
|
||||||
|
@ -62,10 +72,14 @@ class RealTimeStreaming:
|
||||||
_logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
|
_logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
|
||||||
self.logged_real_time_event_types = _logged_real_time_event_types
|
self.logged_real_time_event_types = _logged_real_time_event_types
|
||||||
|
|
||||||
def _should_store_message(self, message: Union[str, bytes]) -> bool:
|
def _should_store_message(
|
||||||
if isinstance(message, bytes):
|
self,
|
||||||
message = message.decode("utf-8")
|
message_obj: Union[
|
||||||
message_obj = json.loads(message)
|
dict,
|
||||||
|
OpenAIRealtimeStreamSessionEvents,
|
||||||
|
OpenAIRealtimeStreamResponseBaseObject,
|
||||||
|
],
|
||||||
|
) -> bool:
|
||||||
_msg_type = message_obj["type"]
|
_msg_type = message_obj["type"]
|
||||||
if self.logged_real_time_event_types == "*":
|
if self.logged_real_time_event_types == "*":
|
||||||
return True
|
return True
|
||||||
|
@ -75,8 +89,22 @@ class RealTimeStreaming:
|
||||||
|
|
||||||
def store_message(self, message: Union[str, bytes]):
|
def store_message(self, message: Union[str, bytes]):
|
||||||
"""Store message in list"""
|
"""Store message in list"""
|
||||||
if self._should_store_message(message):
|
if isinstance(message, bytes):
|
||||||
self.messages.append(message)
|
message = message.decode("utf-8")
|
||||||
|
message_obj = json.loads(message)
|
||||||
|
try:
|
||||||
|
if (
|
||||||
|
message_obj.get("type") == "session.created"
|
||||||
|
or message_obj.get("type") == "session.updated"
|
||||||
|
):
|
||||||
|
message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj) # type: ignore
|
||||||
|
else:
|
||||||
|
message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.debug(f"Error parsing message for logging: {e}")
|
||||||
|
raise e
|
||||||
|
if self._should_store_message(message_obj):
|
||||||
|
self.messages.append(message_obj)
|
||||||
|
|
||||||
def store_input(self, message: dict):
|
def store_input(self, message: dict):
|
||||||
"""Store input message"""
|
"""Store input message"""
|
||||||
|
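`store_message` above now decodes each websocket frame once, parses it into a typed event (`OpenAIRealtimeStreamSessionEvents` for session.created / session.updated, `OpenAIRealtimeStreamResponseBaseObject` for everything else), and only then applies the should-store filter. A stripped-down sketch of that routing, with plain classes standing in for the litellm typed dicts:

    import json
    from typing import Union

    class SessionEvent(dict):
        """Stand-in for OpenAIRealtimeStreamSessionEvents."""

    class ResponseEvent(dict):
        """Stand-in for OpenAIRealtimeStreamResponseBaseObject."""

    def parse_realtime_message(message: Union[str, bytes]) -> Union[SessionEvent, ResponseEvent]:
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        obj = json.loads(message)
        # session lifecycle events get their own type; everything else is a response event
        if obj.get("type") in ("session.created", "session.updated"):
            return SessionEvent(obj)
        return ResponseEvent(obj)

    event = parse_realtime_message(b'{"type": "session.created", "session": {"model": "gpt-4o-realtime-preview"}}')
    assert isinstance(event, SessionEvent) and event["type"] == "session.created"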
|
|
@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import httpx # type: ignore
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
@ -301,12 +301,17 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
|
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = ProviderConfigManager.get_provider_chat_config(
|
config = ProviderConfigManager.get_provider_chat_config(
|
||||||
model=model,
|
model=model,
|
||||||
provider=LlmProviders(custom_llm_provider),
|
provider=LlmProviders(custom_llm_provider),
|
||||||
)
|
)
|
||||||
|
if config is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
|
||||||
|
)
|
||||||
|
|
||||||
data = config.transform_request(
|
data = config.transform_request(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -487,29 +492,10 @@ class ModelResponseIterator:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
|
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
|
||||||
usage_block = Usage(
|
return AnthropicConfig().calculate_usage(
|
||||||
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
|
usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
|
||||||
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
|
|
||||||
total_tokens=anthropic_usage_chunk.get("input_tokens", 0)
|
|
||||||
+ anthropic_usage_chunk.get("output_tokens", 0),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cache_creation_input_tokens = anthropic_usage_chunk.get(
|
|
||||||
"cache_creation_input_tokens"
|
|
||||||
)
|
|
||||||
if cache_creation_input_tokens is not None and isinstance(
|
|
||||||
cache_creation_input_tokens, int
|
|
||||||
):
|
|
||||||
usage_block["cache_creation_input_tokens"] = cache_creation_input_tokens
|
|
||||||
|
|
||||||
cache_read_input_tokens = anthropic_usage_chunk.get("cache_read_input_tokens")
|
|
||||||
if cache_read_input_tokens is not None and isinstance(
|
|
||||||
cache_read_input_tokens, int
|
|
||||||
):
|
|
||||||
usage_block["cache_read_input_tokens"] = cache_read_input_tokens
|
|
||||||
|
|
||||||
return usage_block
|
|
||||||
|
|
||||||
def _content_block_delta_helper(
|
def _content_block_delta_helper(
|
||||||
self, chunk: dict
|
self, chunk: dict
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
|
|
|
@ -682,6 +682,45 @@ class AnthropicConfig(BaseConfig):
|
||||||
reasoning_content += block["thinking"]
|
reasoning_content += block["thinking"]
|
||||||
return text_content, citations, thinking_blocks, reasoning_content, tool_calls
|
return text_content, citations, thinking_blocks, reasoning_content, tool_calls
|
||||||
|
|
||||||
|
def calculate_usage(
|
||||||
|
self, usage_object: dict, reasoning_content: Optional[str]
|
||||||
|
) -> Usage:
|
||||||
|
prompt_tokens = usage_object.get("input_tokens", 0)
|
||||||
|
completion_tokens = usage_object.get("output_tokens", 0)
|
||||||
|
_usage = usage_object
|
||||||
|
cache_creation_input_tokens: int = 0
|
||||||
|
cache_read_input_tokens: int = 0
|
||||||
|
|
||||||
|
if "cache_creation_input_tokens" in _usage:
|
||||||
|
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
|
||||||
|
if "cache_read_input_tokens" in _usage:
|
||||||
|
cache_read_input_tokens = _usage["cache_read_input_tokens"]
|
||||||
|
prompt_tokens += cache_read_input_tokens
|
||||||
|
|
||||||
|
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||||
|
cached_tokens=cache_read_input_tokens
|
||||||
|
)
|
||||||
|
completion_token_details = (
|
||||||
|
CompletionTokensDetailsWrapper(
|
||||||
|
reasoning_tokens=token_counter(
|
||||||
|
text=reasoning_content, count_response_tokens=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if reasoning_content
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
prompt_tokens_details=prompt_tokens_details,
|
||||||
|
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens=cache_read_input_tokens,
|
||||||
|
completion_tokens_details=completion_token_details,
|
||||||
|
)
|
||||||
|
return usage
|
||||||
|
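`calculate_usage` above centralizes what `transform_response` previously did inline: cache-read tokens are folded into the billed prompt count, cache stats are kept as separate fields, and reasoning tokens are counted from the reasoning text when present. A simplified dict-based version of the arithmetic; this is illustrative only, as the real method builds litellm's `Usage` object and uses `token_counter` for the reasoning tokens.

    from typing import Dict, Optional

    def anthropic_usage(usage_block: Dict[str, int], reasoning_tokens: Optional[int] = None) -> Dict[str, int]:
        prompt_tokens = usage_block.get("input_tokens", 0)
        completion_tokens = usage_block.get("output_tokens", 0)
        cache_creation = usage_block.get("cache_creation_input_tokens", 0)
        cache_read = usage_block.get("cache_read_input_tokens", 0)
        # cache reads count toward billed prompt tokens in this accounting
        prompt_tokens += cache_read
        out = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "cache_creation_input_tokens": cache_creation,
            "cache_read_input_tokens": cache_read,
        }
        if reasoning_tokens is not None:
            out["reasoning_tokens"] = reasoning_tokens
        return out

    usage = anthropic_usage({"input_tokens": 100, "output_tokens": 20, "cache_read_input_tokens": 50})
    assert usage["prompt_tokens"] == 150 and usage["total_tokens"] == 170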
|
||||||
def transform_response(
|
def transform_response(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@@ -772,45 +811,14 @@ class AnthropicConfig(BaseConfig):
             )

         ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        _usage = completion_response["usage"]
-        cache_creation_input_tokens: int = 0
-        cache_read_input_tokens: int = 0
+        usage = self.calculate_usage(
+            usage_object=completion_response["usage"],
+            reasoning_content=reasoning_content,
+        )
+        setattr(model_response, "usage", usage)  # type: ignore

         model_response.created = int(time.time())
         model_response.model = completion_response["model"]
-        if "cache_creation_input_tokens" in _usage:
-            cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
-            prompt_tokens += cache_creation_input_tokens
-        if "cache_read_input_tokens" in _usage:
-            cache_read_input_tokens = _usage["cache_read_input_tokens"]
-            prompt_tokens += cache_read_input_tokens
-
-        prompt_tokens_details = PromptTokensDetailsWrapper(
-            cached_tokens=cache_read_input_tokens
-        )
-        completion_token_details = (
-            CompletionTokensDetailsWrapper(
-                reasoning_tokens=token_counter(
-                    text=reasoning_content, count_response_tokens=True
-                )
-            )
-            if reasoning_content
-            else None
-        )
-        total_tokens = prompt_tokens + completion_tokens
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-            prompt_tokens_details=prompt_tokens_details,
-            cache_creation_input_tokens=cache_creation_input_tokens,
-            cache_read_input_tokens=cache_read_input_tokens,
-            completion_tokens_details=completion_token_details,
-        )
-
-        setattr(model_response, "usage", usage)  # type: ignore

         model_response._hidden_params = _hidden_params
         return model_response
@@ -868,6 +876,7 @@ class AnthropicConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> Dict:

@@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -262,6 +262,7 @@ class BaseConfig(ABC):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import httpx

@@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig):
     ) -> List[OpenAICreateFileRequestOptionalParams]:
         pass

-    def get_complete_url(
+    def get_complete_file_url(
         self,
         api_base: Optional[str],
         api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,
-        stream: Optional[bool] = None,
-    ) -> str:
-        """
-        OPTIONAL
-
-        Get the complete url for the request
-
-        Some providers need `model` in `api_base`
-        """
-        return api_base or ""
+        data: CreateFileRequest,
+    ):
+        return self.get_complete_url(
+            api_base=api_base,
+            api_key=api_key,
+            model=model,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+        )

     @abstractmethod
     def transform_create_file_request(

@@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig):
         create_file_data: CreateFileRequest,
         optional_params: dict,
         litellm_params: dict,
-    ) -> dict:
+    ) -> Union[dict, str, bytes]:
         pass

     @abstractmethod

@@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
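
Note: with this change a files provider can return a raw payload (str or bytes) from transform_create_file_request instead of the two-step dict, and get_complete_file_url defaults to the ordinary URL builder. A hypothetical subclass sketch (class and method bodies are illustrative, not taken from the diff):

    from typing import Optional, Union

    class SketchFilesConfig:
        # hypothetical provider config mirroring the BaseFilesConfig shape
        def get_complete_file_url(self, api_base: Optional[str], api_key: Optional[str],
                                  model: str, optional_params: dict,
                                  litellm_params: dict, data: dict) -> Optional[str]:
            # default behaviour simply reuses the chat-style URL
            return api_base

        def transform_create_file_request(self, model: str, create_file_data: dict,
                                          optional_params: dict,
                                          litellm_params: dict) -> Union[dict, str, bytes]:
            # returning raw bytes selects the single-request upload path
            return create_file_data.get("file", b"")
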
@@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
     ChatCompletionUserMessage,
+    OpenAIChatCompletionToolParam,
     OpenAIMessageContentListBlock,
 )
 from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage

@@ -211,13 +212,29 @@ class AmazonConverseConfig(BaseConfig):
             )
         return _tool

+    def _apply_tool_call_transformation(
+        self,
+        tools: List[OpenAIChatCompletionToolParam],
+        model: str,
+        non_default_params: dict,
+        optional_params: dict,
+    ):
+        optional_params = self._add_tools_to_optional_params(
+            optional_params=optional_params, tools=tools
+        )
+
+        if (
+            "meta.llama3-3-70b-instruct-v1:0" in model
+            and non_default_params.get("stream", False) is True
+        ):
+            optional_params["fake_stream"] = True
+
     def map_openai_params(
         self,
         non_default_params: dict,
         optional_params: dict,
         model: str,
         drop_params: bool,
-        messages: Optional[List[AllMessageValues]] = None,
     ) -> dict:
         is_thinking_enabled = self.is_thinking_enabled(non_default_params)

@@ -286,8 +303,11 @@ class AmazonConverseConfig(BaseConfig):
             if param == "top_p":
                 optional_params["topP"] = value
             if param == "tools" and isinstance(value, list):
-                optional_params = self._add_tools_to_optional_params(
-                    optional_params=optional_params, tools=value
+                self._apply_tool_call_transformation(
+                    tools=cast(List[OpenAIChatCompletionToolParam], value),
+                    model=model,
+                    non_default_params=non_default_params,
+                    optional_params=optional_params,
                 )
             if param == "tool_choice":
                 _tool_choice_value = self.map_tool_choice_values(

@@ -633,8 +653,10 @@ class AmazonConverseConfig(BaseConfig):
             cache_read_input_tokens = usage["cacheReadInputTokens"]
             input_tokens += cache_read_input_tokens
         if "cacheWriteInputTokens" in usage:
+            """
+            Do not increment prompt_tokens with cacheWriteInputTokens
+            """
             cache_creation_input_tokens = usage["cacheWriteInputTokens"]
-            input_tokens += cache_creation_input_tokens

         prompt_tokens_details = PromptTokensDetailsWrapper(
             cached_tokens=cache_read_input_tokens

@@ -811,6 +833,7 @@ class AmazonConverseConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
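
Note: the tools branch now routes through _apply_tool_call_transformation, which also forces a buffered ("fake") stream for Bedrock's meta.llama3-3-70b-instruct-v1:0 when streaming is requested. A standalone sketch of that decision, assuming the same model-id check (helper name and dict shapes are illustrative):

    def sketch_apply_tool_call_transformation(model: str, non_default_params: dict,
                                              optional_params: dict, tools: list) -> dict:
        # stand-in for _add_tools_to_optional_params
        optional_params["tools"] = tools
        # mirrors the diff: buffer the response for this model when stream=True
        if ("meta.llama3-3-70b-instruct-v1:0" in model
                and non_default_params.get("stream", False) is True):
            optional_params["fake_stream"] = True
        return optional_params
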
@@ -1,13 +1,13 @@
 import types
 from typing import List, Optional

-from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
     AmazonInvokeConfig,
 )
+from litellm.llms.cohere.chat.transformation import CohereChatConfig


-class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
+class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

@@ -19,7 +19,6 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
     """

     max_tokens: Optional[int] = None
-    temperature: Optional[float] = None
     return_likelihood: Optional[str] = None

     def __init__(

@@ -55,11 +54,10 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
         }

     def get_supported_openai_params(self, model: str) -> List[str]:
-        return [
-            "max_tokens",
-            "temperature",
-            "stream",
-        ]
+        supported_params = CohereChatConfig.get_supported_openai_params(
+            self, model=model
+        )
+        return supported_params

     def map_openai_params(
         self,

@@ -68,11 +66,10 @@ class AmazonCohereConfig(AmazonInvokeConfig, BaseConfig):
         model: str,
         drop_params: bool,
     ) -> dict:
-        for k, v in non_default_params.items():
-            if k == "stream":
-                optional_params["stream"] = v
-            if k == "temperature":
-                optional_params["temperature"] = v
-            if k == "max_tokens":
-                optional_params["max_tokens"] = v
-        return optional_params
+        return CohereChatConfig.map_openai_params(
+            self,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=drop_params,
+        )
@@ -6,14 +6,21 @@ Inherits from `AmazonConverseConfig`
 Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
 """

-from typing import List
+from typing import Any, List, Optional
+
+import httpx

 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
 from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..converse_transformation import AmazonConverseConfig
+from .base_invoke_transformation import AmazonInvokeConfig


-class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
+class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
     """
     Config for sending `nova` requests to `/bedrock/invoke/`
     """

@@ -21,6 +28,20 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

+    def get_supported_openai_params(self, model: str) -> list:
+        return AmazonConverseConfig.get_supported_openai_params(self, model)
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        return AmazonConverseConfig.map_openai_params(
+            self, non_default_params, optional_params, model, drop_params
+        )
+
     def transform_request(
         self,
         model: str,

@@ -29,7 +50,8 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-        _transformed_nova_request = super().transform_request(
+        _transformed_nova_request = AmazonConverseConfig.transform_request(
+            self,
             model=model,
             messages=messages,
             optional_params=optional_params,

@@ -45,6 +67,35 @@ class AmazonInvokeNovaConfig(litellm.AmazonConverseConfig):
         )
         return bedrock_invoke_nova_request

+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: Logging,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> litellm.ModelResponse:
+        return AmazonConverseConfig.transform_response(
+            self,
+            model,
+            raw_response,
+            model_response,
+            logging_obj,
+            request_data,
+            messages,
+            optional_params,
+            litellm_params,
+            encoding,
+            api_key,
+            json_mode,
+        )
+
     def _filter_allowed_fields(
         self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
     ) -> dict:
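
Note: AmazonInvokeNovaConfig now inherits from both AmazonInvokeConfig and AmazonConverseConfig and pins each overridden method to an explicit base class rather than relying on super()/MRO. A minimal, runnable sketch of that pattern (toy class names, not from the diff):

    class Invoke:
        def transform_request(self, model: str) -> dict:
            return {"api": "invoke", "model": model}

    class Converse:
        def transform_request(self, model: str) -> dict:
            return {"api": "converse", "model": model}

    class InvokeNova(Invoke, Converse):
        def transform_request(self, model: str) -> dict:
            # call the Converse implementation explicitly, as the diff does
            return Converse.transform_request(self, model)

    print(InvokeNova().transform_request("nova"))  # {'api': 'converse', 'model': 'nova'}
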
@@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
@@ -218,6 +218,10 @@ class BaseLLMAIOHTTPHandler:
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
         )
+        if provider_config is None:
+            raise ValueError(
+                f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
+            )
         # get config from model, custom llm provider
         headers = provider_config.validate_environment(
             api_key=api_key,

@@ -225,6 +229,7 @@ class BaseLLMAIOHTTPHandler:
             model=model,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             api_base=api_base,
         )

@@ -494,6 +499,7 @@ class BaseLLMAIOHTTPHandler:
             model=model,
             messages=[{"role": "user", "content": "test"}],
             optional_params=optional_params,
+            litellm_params=litellm_params,
             api_base=api_base,
         )

@@ -192,7 +192,7 @@ class AsyncHTTPHandler:
     async def post(
         self,
         url: str,
-        data: Optional[Union[dict, str]] = None,  # type: ignore
+        data: Optional[Union[dict, str, bytes]] = None,  # type: ignore
         json: Optional[dict] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,

@@ -427,7 +427,7 @@ class AsyncHTTPHandler:
         self,
         url: str,
         client: httpx.AsyncClient,
-        data: Optional[Union[dict, str]] = None,  # type: ignore
+        data: Optional[Union[dict, str, bytes]] = None,  # type: ignore
         json: Optional[dict] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,

@@ -527,7 +527,7 @@ class HTTPHandler:
     def post(
         self,
         url: str,
-        data: Optional[Union[dict, str]] = None,
+        data: Optional[Union[dict, str, bytes]] = None,
         json: Optional[Union[dict, str, List]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,

@@ -573,7 +573,6 @@ class HTTPHandler:
                 setattr(e, "text", error_text)
-
                 setattr(e, "status_code", e.response.status_code)

                 raise e
             except Exception as e:
                 raise e
@@ -234,6 +234,10 @@ class BaseLLMHTTPHandler:
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
         )
+        if provider_config is None:
+            raise ValueError(
+                f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
+            )

         # get config from model, custom llm provider
         headers = provider_config.validate_environment(

@@ -243,6 +247,7 @@ class BaseLLMHTTPHandler:
             messages=messages,
             optional_params=optional_params,
             api_base=api_base,
+            litellm_params=litellm_params,
         )

         api_base = provider_config.get_complete_url(

@@ -621,6 +626,7 @@ class BaseLLMHTTPHandler:
             model=model,
             messages=[],
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )

         api_base = provider_config.get_complete_url(

@@ -892,6 +898,7 @@ class BaseLLMHTTPHandler:
             model=model,
             messages=[],
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )

         if client is None or not isinstance(client, HTTPHandler):
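
Note: both HTTP handlers now fail fast when no provider config is registered for the requested model/provider pair. A standalone sketch of the guard (function name is illustrative):

    from typing import Optional

    def require_provider_config(provider_config: Optional[object], model: str,
                                custom_llm_provider: str) -> object:
        if provider_config is None:
            raise ValueError(
                f"Provider config not found for model: {model} "
                f"and provider: {custom_llm_provider}"
            )
        return provider_config
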
@@ -1224,15 +1231,19 @@ class BaseLLMHTTPHandler:
             model="",
             messages=[],
             optional_params={},
+            litellm_params=litellm_params,
         )

-        api_base = provider_config.get_complete_url(
+        api_base = provider_config.get_complete_file_url(
             api_base=api_base,
             api_key=api_key,
             model="",
             optional_params={},
             litellm_params=litellm_params,
+            data=create_file_data,
         )
+        if api_base is None:
+            raise ValueError("api_base is required for create_file")

         # Get the transformed request data for both steps
         transformed_request = provider_config.transform_create_file_request(

@@ -1259,48 +1270,57 @@ class BaseLLMHTTPHandler:
         else:
             sync_httpx_client = client

-        try:
-            # Step 1: Initial request to get upload URL
-            initial_response = sync_httpx_client.post(
-                url=api_base,
-                headers={
-                    **headers,
-                    **transformed_request["initial_request"]["headers"],
-                },
-                data=json.dumps(transformed_request["initial_request"]["data"]),
-                timeout=timeout,
-            )
-
-            # Extract upload URL from response headers
-            upload_url = initial_response.headers.get("X-Goog-Upload-URL")
-
-            if not upload_url:
-                raise ValueError("Failed to get upload URL from initial request")
-
-            # Step 2: Upload the actual file
-            upload_response = sync_httpx_client.post(
-                url=upload_url,
-                headers=transformed_request["upload_request"]["headers"],
-                data=transformed_request["upload_request"]["data"],
-                timeout=timeout,
-            )
-
-            return provider_config.transform_create_file_response(
-                model=None,
-                raw_response=upload_response,
-                logging_obj=logging_obj,
-                litellm_params=litellm_params,
-            )
-        except Exception as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
+        if isinstance(transformed_request, str) or isinstance(
+            transformed_request, bytes
+        ):
+            upload_response = sync_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=transformed_request,
+                timeout=timeout,
+            )
+        else:
+            try:
+                # Step 1: Initial request to get upload URL
+                initial_response = sync_httpx_client.post(
+                    url=api_base,
+                    headers={
+                        **headers,
+                        **transformed_request["initial_request"]["headers"],
+                    },
+                    data=json.dumps(transformed_request["initial_request"]["data"]),
+                    timeout=timeout,
+                )
+
+                # Extract upload URL from response headers
+                upload_url = initial_response.headers.get("X-Goog-Upload-URL")
+
+                if not upload_url:
+                    raise ValueError("Failed to get upload URL from initial request")
+
+                # Step 2: Upload the actual file
+                upload_response = sync_httpx_client.post(
+                    url=upload_url,
+                    headers=transformed_request["upload_request"]["headers"],
+                    data=transformed_request["upload_request"]["data"],
+                    timeout=timeout,
+                )
+            except Exception as e:
+                raise self._handle_error(
+                    e=e,
+                    provider_config=provider_config,
+                )
+
+        return provider_config.transform_create_file_response(
+            model=None,
+            raw_response=upload_response,
+            logging_obj=logging_obj,
+            litellm_params=litellm_params,
+        )
     async def async_create_file(
         self,
-        transformed_request: dict,
+        transformed_request: Union[bytes, str, dict],
         litellm_params: dict,
         provider_config: BaseFilesConfig,
         headers: dict,

@@ -1319,45 +1339,54 @@ class BaseLLMHTTPHandler:
         else:
             async_httpx_client = client

-        try:
-            # Step 1: Initial request to get upload URL
-            initial_response = await async_httpx_client.post(
-                url=api_base,
-                headers={
-                    **headers,
-                    **transformed_request["initial_request"]["headers"],
-                },
-                data=json.dumps(transformed_request["initial_request"]["data"]),
-                timeout=timeout,
-            )
-
-            # Extract upload URL from response headers
-            upload_url = initial_response.headers.get("X-Goog-Upload-URL")
-
-            if not upload_url:
-                raise ValueError("Failed to get upload URL from initial request")
-
-            # Step 2: Upload the actual file
-            upload_response = await async_httpx_client.post(
-                url=upload_url,
-                headers=transformed_request["upload_request"]["headers"],
-                data=transformed_request["upload_request"]["data"],
-                timeout=timeout,
-            )
-
-            return provider_config.transform_create_file_response(
-                model=None,
-                raw_response=upload_response,
-                logging_obj=logging_obj,
-                litellm_params=litellm_params,
-            )
-        except Exception as e:
-            verbose_logger.exception(f"Error creating file: {e}")
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
+        if isinstance(transformed_request, str) or isinstance(
+            transformed_request, bytes
+        ):
+            upload_response = await async_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=transformed_request,
+                timeout=timeout,
+            )
+        else:
+            try:
+                # Step 1: Initial request to get upload URL
+                initial_response = await async_httpx_client.post(
+                    url=api_base,
+                    headers={
+                        **headers,
+                        **transformed_request["initial_request"]["headers"],
+                    },
+                    data=json.dumps(transformed_request["initial_request"]["data"]),
+                    timeout=timeout,
+                )
+
+                # Extract upload URL from response headers
+                upload_url = initial_response.headers.get("X-Goog-Upload-URL")
+
+                if not upload_url:
+                    raise ValueError("Failed to get upload URL from initial request")
+
+                # Step 2: Upload the actual file
+                upload_response = await async_httpx_client.post(
+                    url=upload_url,
+                    headers=transformed_request["upload_request"]["headers"],
+                    data=transformed_request["upload_request"]["data"],
+                    timeout=timeout,
+                )
+            except Exception as e:
+                verbose_logger.exception(f"Error creating file: {e}")
+                raise self._handle_error(
+                    e=e,
+                    provider_config=provider_config,
+                )
+
+        return provider_config.transform_create_file_response(
+            model=None,
+            raw_response=upload_response,
+            logging_obj=logging_obj,
+            litellm_params=litellm_params,
+        )

     def list_files(self):
         """
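
Note: create_file (sync and async) now branches on the type of the transformed request: a raw str/bytes payload is posted directly to the file URL, while the dict form keeps the two-step resumable upload. A simplified sketch of the branch (the client argument and URLs are stand-ins):

    import json
    from typing import Union

    def sketch_upload(client, api_base: str, headers: dict,
                      transformed_request: Union[bytes, str, dict], timeout: float):
        if isinstance(transformed_request, (str, bytes)):
            # single-shot upload of the raw payload
            return client.post(url=api_base, headers=headers,
                               data=transformed_request, timeout=timeout)
        # two-step resumable upload (e.g. Google AI Studio files)
        initial = client.post(
            url=api_base,
            headers={**headers, **transformed_request["initial_request"]["headers"]},
            data=json.dumps(transformed_request["initial_request"]["data"]),
            timeout=timeout,
        )
        upload_url = initial.headers.get("X-Goog-Upload-URL")
        if not upload_url:
            raise ValueError("Failed to get upload URL from initial request")
        return client.post(url=upload_url,
                           headers=transformed_request["upload_request"]["headers"],
                           data=transformed_request["upload_request"]["data"],
                           timeout=timeout)
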
@@ -27,7 +27,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     strip_name_from_messages,
 )
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.types.llms.anthropic import AnthropicMessagesTool
+from litellm.types.llms.anthropic import AllAnthropicToolsValues
 from litellm.types.llms.databricks import (
     AllDatabricksContentValues,
     DatabricksChoice,

@@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -160,7 +161,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
         ]

     def convert_anthropic_tool_to_databricks_tool(
-        self, tool: Optional[AnthropicMessagesTool]
+        self, tool: Optional[AllAnthropicToolsValues]
     ) -> Optional[DatabricksTool]:
         if tool is None:
             return None

@@ -173,6 +174,19 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
             ),
         )

+    def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]:
+        # if not claude, send as is
+        if "claude" not in model:
+            return tools
+
+        # if claude, convert to anthropic tool and then to databricks tool
+        anthropic_tools = self._map_tools(tools=tools)
+        databricks_tools = [
+            cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool))
+            for tool in anthropic_tools
+        ]
+        return databricks_tools
+
     def map_response_format_to_databricks_tool(
         self,
         model: str,

@@ -202,6 +216,10 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
         mapped_params = super().map_openai_params(
             non_default_params, optional_params, model, drop_params
         )
+        if "tools" in mapped_params:
+            mapped_params["tools"] = self._map_openai_to_dbrx_tool(
+                model=model, tools=mapped_params["tools"]
+            )
         if (
             "max_completion_tokens" in non_default_params
             and replace_max_completion_tokens_with_max_tokens

@@ -240,6 +258,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
             optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
                 non_default_params.get("reasoning_effort")
             )
+        optional_params.pop("reasoning_effort", None)
         ## handle thinking tokens
         self.update_optional_params_with_thinking_tokens(
             non_default_params=non_default_params, optional_params=mapped_params

@@ -498,7 +517,10 @@ class DatabricksChatResponseIterator(BaseModelResponseIterator):
                     message.content = ""
                     choice["delta"]["content"] = message.content
                     choice["delta"]["tool_calls"] = None
+                elif tool_calls:
+                    for _tc in tool_calls:
+                        if _tc.get("function", {}).get("arguments") == "{}":
+                            _tc["function"]["arguments"] = ""  # avoid invalid json
                 # extract the content str
                 content_str = DatabricksConfig.extract_content_str(
                     choice["delta"].get("content")
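
Note: the Databricks streaming iterator now rewrites literal "{}" tool-call argument chunks to an empty string so that later concatenation of streamed argument fragments stays valid JSON. A standalone sketch:

    def sketch_fix_tool_call_chunks(tool_calls: list) -> list:
        for _tc in tool_calls:
            if _tc.get("function", {}).get("arguments") == "{}":
                _tc["function"]["arguments"] = ""  # avoid invalid json when chunks are joined
        return tool_calls

    print(sketch_fix_tool_call_chunks(
        [{"function": {"name": "get_weather", "arguments": "{}"}}]
    ))
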
@@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -2,7 +2,11 @@ from typing import List, Literal, Optional, Tuple, Union, cast

 import litellm
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    ChatCompletionImageObject,
+    OpenAIChatCompletionToolParam,
+)
 from litellm.types.utils import ProviderSpecificModelInfo

 from ...openai.chat.gpt_transformation import OpenAIGPTConfig

@@ -150,6 +154,14 @@ class FireworksAIConfig(OpenAIGPTConfig):
                     ] = f"{content['image_url']['url']}#transform=inline"
         return content

+    def _transform_tools(
+        self, tools: List[OpenAIChatCompletionToolParam]
+    ) -> List[OpenAIChatCompletionToolParam]:
+        for tool in tools:
+            if tool.get("type") == "function":
+                tool["function"].pop("strict", None)
+        return tools
+
     def _transform_messages_helper(
         self, messages: List[AllMessageValues], model: str, litellm_params: dict
     ) -> List[AllMessageValues]:

@@ -196,6 +208,9 @@ class FireworksAIConfig(OpenAIGPTConfig):
         messages = self._transform_messages_helper(
             messages=messages, model=model, litellm_params=litellm_params
         )
+        if "tools" in optional_params and optional_params["tools"] is not None:
+            tools = self._transform_tools(tools=optional_params["tools"])
+            optional_params["tools"] = tools
         return super().transform_request(
             model=model,
             messages=messages,

@@ -41,6 +41,7 @@ class FireworksAIMixin:
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
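
Note: Fireworks AI requests now drop the OpenAI-style "strict" flag from function-tool definitions before sending. A standalone sketch of that transform (the removal mirrors the diff; the reason the flag is dropped is not stated there):

    def sketch_transform_tools(tools: list) -> list:
        for tool in tools:
            if tool.get("type") == "function":
                tool["function"].pop("strict", None)  # drop "strict", as in the diff
        return tools

    print(sketch_transform_tools(
        [{"type": "function", "function": {"name": "f", "strict": True, "parameters": {}}}]
    ))
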
@@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API.
 For vertex ai, check out the vertex_ai/files/handler.py file.
 """
 import time
-from typing import List, Mapping, Optional
+from typing import List, Optional

 import httpx

 from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
 from litellm.llms.base_llm.files.transformation import (
     BaseFilesConfig,
     LiteLLMLoggingObj,

@@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
         if file_data is None:
             raise ValueError("File data is required")

-        # Parse the file_data based on its type
-        filename = None
-        file_content = None
-        content_type = None
-        file_headers: Mapping[str, str] = {}
-
-        if isinstance(file_data, tuple):
-            if len(file_data) == 2:
-                filename, file_content = file_data
-            elif len(file_data) == 3:
-                filename, file_content, content_type = file_data
-            elif len(file_data) == 4:
-                filename, file_content, content_type, file_headers = file_data
-        else:
-            file_content = file_data
-
-        # Handle the file content based on its type
-        import io
-        from os import PathLike
-
-        # Convert content to bytes
-        if isinstance(file_content, (str, PathLike)):
-            # If it's a path, open and read the file
-            with open(file_content, "rb") as f:
-                content = f.read()
-        elif isinstance(file_content, io.IOBase):
-            # If it's a file-like object
-            content = file_content.read()
-            if isinstance(content, str):
-                content = content.encode("utf-8")
-        elif isinstance(file_content, bytes):
-            content = file_content
-        else:
-            raise ValueError(f"Unsupported file content type: {type(file_content)}")
+        # Use the common utility function to extract file data
+        extracted_data = extract_file_data(file_data)

         # Get file size
-        file_size = len(content)
-
-        # Use provided content type or guess based on filename
-        if not content_type:
-            import mimetypes
-
-            content_type = (
-                mimetypes.guess_type(filename)[0]
-                if filename
-                else "application/octet-stream"
-            )
+        file_size = len(extracted_data["content"])

         # Step 1: Initial resumable upload request
         headers = {
             "X-Goog-Upload-Protocol": "resumable",
             "X-Goog-Upload-Command": "start",
             "X-Goog-Upload-Header-Content-Length": str(file_size),
-            "X-Goog-Upload-Header-Content-Type": content_type,
+            "X-Goog-Upload-Header-Content-Type": extracted_data["content_type"],
             "Content-Type": "application/json",
         }
-        headers.update(file_headers)  # Add any custom headers
+        headers.update(extracted_data["headers"])  # Add any custom headers

         # Initial metadata request body
-        initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
+        initial_data = {
+            "file": {
+                "display_name": extracted_data["filename"] or str(int(time.time()))
+            }
+        }

         # Step 2: Actual file upload data
         upload_headers = {

@@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):

         return {
             "initial_request": {"headers": headers, "data": initial_data},
-            "upload_request": {"headers": upload_headers, "data": content},
+            "upload_request": {
+                "headers": upload_headers,
+                "data": extracted_data["content"],
+            },
         }

     def transform_create_file_response(
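
Note: the Google AI Studio files handler now delegates file parsing to a shared extract_file_data helper instead of the inline tuple/IO handling it replaces. A sketch of the return shape that helper is expected to produce (keys inferred from how the handler reads it; this is not the actual litellm implementation):

    import mimetypes
    from typing import Optional, Union

    def sketch_extract_file_data(file_data: Union[bytes, tuple]) -> dict:
        filename: Optional[str] = None
        headers: dict = {}
        content_type: Optional[str] = None
        if isinstance(file_data, tuple):
            filename, content = file_data[0], file_data[1]
            if len(file_data) > 2:
                content_type = file_data[2]
            if len(file_data) > 3:
                headers = dict(file_data[3])
        else:
            content = file_data
        if isinstance(content, str):
            content = content.encode("utf-8")
        if not content_type:
            content_type = (mimetypes.guess_type(filename)[0] if filename
                            else "application/octet-stream")
        return {"filename": filename, "content": content,
                "content_type": content_type, "headers": headers}
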
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import httpx

@@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
 from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
-

 logger = logging.getLogger(__name__)

 BASE_URL = "https://router.huggingface.co"

@@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
         headers: dict,
         model: str,
         messages: List[AllMessageValues],
-        optional_params: dict,
+        optional_params: Dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:
-        return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
+        return HuggingFaceError(
+            status_code=status_code, message=error_message, headers=headers
+        )

     def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
         """

@@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
         if api_base is not None:
             complete_url = api_base
         elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
-            complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
+            complete_url = str(os.getenv("HF_API_BASE")) or str(
+                os.getenv("HUGGINGFACE_API_BASE")
+            )
         elif model.startswith(("http://", "https://")):
             complete_url = model
         # 4. Default construction with provider

@@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
         )
         mapped_model = provider_mapping["providerId"]
         messages = self._transform_messages(messages=messages, model=mapped_model)
-        return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
+        return dict(
+            ChatCompletionRequest(
+                model=mapped_model, messages=messages, **optional_params
+            )
+        )
@@ -1,15 +1,6 @@
 import json
 import os
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Union,
-    get_args,
-)
+from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args

 import httpx

@@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[  # pipeline tags + hf tei endpoints - https://hug
 ]


-def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+def get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
     if task_type is not None:
         if task_type in get_args(hf_tasks_embeddings):
             return task_type

@@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
     return pipeline_tag


-async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+async def async_get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
     if task_type is not None:
         if task_type in get_args(hf_tasks_embeddings):
             return task_type

@@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
         input: List,
         optional_params: dict,
     ) -> dict:
-        hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+        hf_task = await async_get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=HF_HUB_URL
+        )

         data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)

@@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
         task_type = optional_params.pop("input_type", None)

         if call_type == "sync":
-            hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+            hf_task = get_hf_task_embedding_for_model(
+                model=model, task_type=task_type, api_base=HF_HUB_URL
+            )
         elif call_type == "async":
             return self._async_transform_input(
                 model=model, task_type=task_type, embed_url=embed_url, input=input

@@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
         input: list,
         model_response: EmbeddingResponse,
         optional_params: dict,
+        litellm_params: dict,
         logging_obj: LiteLLMLoggingObj,
         encoding: Callable,
         api_key: Optional[str] = None,

@@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
             model=model,
             optional_params=optional_params,
             messages=[],
+            litellm_params=litellm_params,
         )
         task_type = optional_params.pop("input_type", None)
-        task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+        task = get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=HF_HUB_URL
+        )
         # print_verbose(f"{model}, {task}")
         embed_url = ""
         if "https" in model:

@@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
         elif "HUGGINGFACE_API_BASE" in os.environ:
             embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
         else:
-            embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+            embed_url = (
+                f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+            )

         ## ROUTING ##
         if aembedding is True:
@@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: Dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> Dict:

@@ -10,6 +10,11 @@ from ...openai.chat.gpt_transformation import OpenAIGPTConfig


 class LiteLLMProxyChatConfig(OpenAIGPTConfig):
+    def get_supported_openai_params(self, model: str) -> List:
+        list = super().get_supported_openai_params(model)
+        list.append("thinking")
+        return list
+
     def _get_openai_compatible_provider_info(
         self, api_base: Optional[str], api_key: Optional[str]
     ) -> Tuple[Optional[str], Optional[str]]:
@@ -36,6 +36,7 @@ def completion(
         model=model,
         messages=messages,
         optional_params=optional_params,
+        litellm_params=litellm_params,
     )

     ## Load Config

@@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -32,6 +32,7 @@ def completion(
         model=model,
         messages=messages,
         optional_params=optional_params,
+        litellm_params=litellm_params,
     )
     if "https" in model:
         completion_url = model

@@ -123,6 +124,7 @@ def embedding(
         model=model,
         messages=[],
         optional_params=optional_params,
+        litellm_params={},
     )
     response = litellm.module_level_client.post(
         embeddings_url, headers=headers, json=data

@@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -228,10 +228,10 @@ class PredibaseChatCompletion:
         api_key: str,
         logging_obj,
         optional_params: dict,
+        litellm_params: dict,
         tenant_id: str,
         timeout: Union[float, httpx.Timeout],
         acompletion=None,
-        litellm_params=None,
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:

@@ -241,6 +241,7 @@ class PredibaseChatCompletion:
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model=model,
|
model=model,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
completion_url = ""
|
completion_url = ""
|
||||||
input_text = ""
|
input_text = ""
|
||||||
|
|
|
@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -141,6 +141,7 @@ def completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
# Start a prediction and get the prediction URL
|
# Start a prediction and get the prediction URL
|
||||||
version_id = replicate_config.model_to_version_id(model)
|
version_id = replicate_config.model_to_version_id(model)
|
||||||
|
|
|
@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
model: str,
|
model: str,
|
||||||
data: dict,
|
data: dict,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
|
litellm_params: dict,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
aws_region_name: str,
|
aws_region_name: str,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
|
@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
request = AWSRequest(
|
request = AWSRequest(
|
||||||
method="POST", url=api_base, data=encoded_data, headers=headers
|
method="POST", url=api_base, data=encoded_data, headers=headers
|
||||||
|
@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
data=data,
|
data=data,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
)
|
)
|
||||||
|
@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
"model": model,
|
"model": model,
|
||||||
"data": _data,
|
"data": _data,
|
||||||
"optional_params": optional_params,
|
"optional_params": optional_params,
|
||||||
|
"litellm_params": litellm_params,
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
"aws_region_name": aws_region_name,
|
"aws_region_name": aws_region_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
"model": model,
|
"model": model,
|
||||||
"data": data,
|
"data": data,
|
||||||
"optional_params": optional_params,
|
"optional_params": optional_params,
|
||||||
|
"litellm_params": litellm_params,
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
"aws_region_name": aws_region_name,
|
"aws_region_name": aws_region_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM):
|
||||||
"model": model,
|
"model": model,
|
||||||
"data": data,
|
"data": data,
|
||||||
"optional_params": optional_params,
|
"optional_params": optional_params,
|
||||||
|
"litellm_params": litellm_params,
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
"aws_region_name": aws_region_name,
|
"aws_region_name": aws_region_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
|
|
@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: Dict,
|
optional_params: Dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
|
|
|
@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
from litellm import supports_response_schema, supports_system_messages, verbose_logger
|
||||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||||
|
from litellm.litellm_core_utils.prompt_templates.common_utils import unpack_defs
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
from litellm.types.llms.vertex_ai import PartType
|
from litellm.types.llms.vertex_ai import PartType, Schema
|
||||||
|
|
||||||
|
|
||||||
class VertexAIError(BaseLLMException):
|
class VertexAIError(BaseLLMException):
|
||||||
|
@ -168,6 +169,9 @@ def _build_vertex_schema(parameters: dict):
|
||||||
"""
|
"""
|
||||||
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
|
This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
|
||||||
"""
|
"""
|
||||||
|
# Get valid fields from Schema TypedDict
|
||||||
|
valid_schema_fields = set(get_type_hints(Schema).keys())
|
||||||
|
|
||||||
defs = parameters.pop("$defs", {})
|
defs = parameters.pop("$defs", {})
|
||||||
# flatten the defs
|
# flatten the defs
|
||||||
for name, value in defs.items():
|
for name, value in defs.items():
|
||||||
|
@ -181,52 +185,49 @@ def _build_vertex_schema(parameters: dict):
|
||||||
convert_anyof_null_to_nullable(parameters)
|
convert_anyof_null_to_nullable(parameters)
|
||||||
add_object_type(parameters)
|
add_object_type(parameters)
|
||||||
# Postprocessing
|
# Postprocessing
|
||||||
# 4. Suppress unnecessary title generation:
|
# Filter out fields that don't exist in Schema
|
||||||
# * https://github.com/pydantic/pydantic/issues/1051
|
filtered_parameters = filter_schema_fields(parameters, valid_schema_fields)
|
||||||
# * http://cl/586221780
|
return filtered_parameters
|
||||||
strip_field(parameters, field_name="title")
|
|
||||||
|
|
||||||
strip_field(
|
|
||||||
parameters, field_name="$schema"
|
|
||||||
) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
|
|
||||||
strip_field(
|
|
||||||
parameters, field_name="$id"
|
|
||||||
) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors.
|
|
||||||
|
|
||||||
return parameters
|
|
||||||
|
|
||||||
|
|
||||||
def unpack_defs(schema, defs):
|
def filter_schema_fields(
|
||||||
properties = schema.get("properties", None)
|
schema_dict: Dict[str, Any], valid_fields: Set[str], processed=None
|
||||||
if properties is None:
|
) -> Dict[str, Any]:
|
||||||
return
|
"""
|
||||||
|
Recursively filter a schema dictionary to keep only valid fields.
|
||||||
|
"""
|
||||||
|
if processed is None:
|
||||||
|
processed = set()
|
||||||
|
|
||||||
for name, value in properties.items():
|
# Handle circular references
|
||||||
ref_key = value.get("$ref", None)
|
schema_id = id(schema_dict)
|
||||||
if ref_key is not None:
|
if schema_id in processed:
|
||||||
ref = defs[ref_key.split("defs/")[-1]]
|
return schema_dict
|
||||||
unpack_defs(ref, defs)
|
processed.add(schema_id)
|
||||||
properties[name] = ref
|
|
||||||
|
if not isinstance(schema_dict, dict):
|
||||||
|
return schema_dict
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for key, value in schema_dict.items():
|
||||||
|
if key not in valid_fields:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
anyof = value.get("anyOf", None)
|
if key == "properties" and isinstance(value, dict):
|
||||||
if anyof is not None:
|
result[key] = {
|
||||||
for i, atype in enumerate(anyof):
|
k: filter_schema_fields(v, valid_fields, processed)
|
||||||
ref_key = atype.get("$ref", None)
|
for k, v in value.items()
|
||||||
if ref_key is not None:
|
}
|
||||||
ref = defs[ref_key.split("defs/")[-1]]
|
elif key == "items" and isinstance(value, dict):
|
||||||
unpack_defs(ref, defs)
|
result[key] = filter_schema_fields(value, valid_fields, processed)
|
||||||
anyof[i] = ref
|
elif key == "anyOf" and isinstance(value, list):
|
||||||
continue
|
result[key] = [
|
||||||
|
filter_schema_fields(item, valid_fields, processed) for item in value # type: ignore
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
|
||||||
items = value.get("items", None)
|
return result
|
||||||
if items is not None:
|
|
||||||
ref_key = items.get("$ref", None)
|
|
||||||
if ref_key is not None:
|
|
||||||
ref = defs[ref_key.split("defs/")[-1]]
|
|
||||||
unpack_defs(ref, defs)
|
|
||||||
value["items"] = ref
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
def convert_anyof_null_to_nullable(schema, depth=0):
|
def convert_anyof_null_to_nullable(schema, depth=0):
|
||||||
|
|
|
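The rewrite above replaces field-by-field stripping (title, $schema, $id) with a whitelist built from the Schema TypedDict: any key not in the whitelist is dropped, recursively through properties, items, and anyOf, with id()-based tracking to survive circular references. A minimal sketch of that filtering idea; the whitelist below is made up and circular-reference handling is omitted:

    from typing import Any, Dict, Set

    def filter_fields(schema: Dict[str, Any], valid: Set[str]) -> Dict[str, Any]:
        # Keep only whitelisted keys; recurse into nested schema containers.
        out: Dict[str, Any] = {}
        for key, value in schema.items():
            if key not in valid:
                continue  # drops e.g. "title", "$schema", "$id" in one pass
            if key == "properties" and isinstance(value, dict):
                out[key] = {k: filter_fields(v, valid) for k, v in value.items()}
            elif key == "items" and isinstance(value, dict):
                out[key] = filter_fields(value, valid)
            elif key == "anyOf" and isinstance(value, list):
                out[key] = [filter_fields(v, valid) for v in value]
            else:
                out[key] = value
        return out

    valid = {"type", "properties", "items", "anyOf", "nullable"}  # stand-in for get_type_hints(Schema).keys()
    cleaned = filter_fields(
        {"type": "object", "title": "Person",
         "properties": {"name": {"type": "string", "$schema": "draft-07"}}},
        valid,
    )
    # cleaned == {"type": "object", "properties": {"name": {"type": "string"}}}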
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
from typing import Any, Coroutine, Optional, Union
|
from typing import Any, Coroutine, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||||
from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
|
from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
|
||||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
|
|
||||||
from .transformation import VertexAIFilesTransformation
|
from .transformation import VertexAIJsonlFilesTransformation
|
||||||
|
|
||||||
vertex_ai_files_transformation = VertexAIFilesTransformation()
|
vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
|
||||||
|
|
||||||
|
|
||||||
class VertexAIFilesHandler(GCSBucketBase):
|
class VertexAIFilesHandler(GCSBucketBase):
|
||||||
|
@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
return None # type: ignore
|
return asyncio.run(
|
||||||
|
self.async_create_file(
|
||||||
|
create_file_data=create_file_data,
|
||||||
|
api_base=api_base,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
|
vertex_project=vertex_project,
|
||||||
|
vertex_location=vertex_location,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
|
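The handler hunk above swaps the old `return None` fallback for a synchronous path that drives the async upload to completion with asyncio.run. A generic sketch of the pattern, not tied to the Vertex handler (the coroutine body is a placeholder):

    import asyncio

    async def async_create_file(name: str) -> dict:
        await asyncio.sleep(0)  # placeholder for the real async GCS upload
        return {"id": name, "status": "uploaded"}

    def create_file(name: str) -> dict:
        # Synchronous callers now block on the async implementation
        # instead of receiving None.
        return asyncio.run(async_create_file(name))

    print(create_file("litellm-vertex-files/example"))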
@ -1,7 +1,17 @@
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from httpx import Headers, Response
|
||||||
|
|
||||||
|
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||||
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
from litellm.llms.base_llm.files.transformation import (
|
||||||
|
BaseFilesConfig,
|
||||||
|
LiteLLMLoggingObj,
|
||||||
|
)
|
||||||
from litellm.llms.vertex_ai.common_utils import (
|
from litellm.llms.vertex_ai.common_utils import (
|
||||||
_convert_vertex_datetime_to_openai_datetime,
|
_convert_vertex_datetime_to_openai_datetime,
|
||||||
)
|
)
|
||||||
|
@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||||
VertexGeminiConfig,
|
VertexGeminiConfig,
|
||||||
)
|
)
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
|
AllMessageValues,
|
||||||
CreateFileRequest,
|
CreateFileRequest,
|
||||||
FileTypes,
|
FileTypes,
|
||||||
|
OpenAICreateFileRequestOptionalParams,
|
||||||
OpenAIFileObject,
|
OpenAIFileObject,
|
||||||
PathLike,
|
PathLike,
|
||||||
)
|
)
|
||||||
|
from litellm.types.llms.vertex_ai import GcsBucketResponse
|
||||||
|
from litellm.types.utils import ExtractedFileData, LlmProviders
|
||||||
|
|
||||||
|
from ..common_utils import VertexAIError
|
||||||
|
from ..vertex_llm_base import VertexBase
|
||||||
|
|
||||||
|
|
||||||
class VertexAIFilesTransformation(VertexGeminiConfig):
|
class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
|
||||||
|
"""
|
||||||
|
Config for VertexAI Files
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.jsonl_transformation = VertexAIJsonlFilesTransformation()
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def custom_llm_provider(self) -> LlmProviders:
|
||||||
|
return LlmProviders.VERTEX_AI
|
||||||
|
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
headers: dict,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
if not api_key:
|
||||||
|
api_key, _ = self.get_access_token(
|
||||||
|
credentials=litellm_params.get("vertex_credentials"),
|
||||||
|
project_id=litellm_params.get("vertex_project"),
|
||||||
|
)
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("api_key is required")
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
|
||||||
|
"""
|
||||||
|
Helper to extract content from various OpenAI file types and return as string.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Direct content (str, bytes, IO[bytes])
|
||||||
|
- Tuple formats: (filename, content, [content_type], [headers])
|
||||||
|
- PathLike objects
|
||||||
|
"""
|
||||||
|
content: Union[str, bytes] = b""
|
||||||
|
# Extract file content from tuple if necessary
|
||||||
|
if isinstance(openai_file_content, tuple):
|
||||||
|
# Take the second element which is always the file content
|
||||||
|
file_content = openai_file_content[1]
|
||||||
|
else:
|
||||||
|
file_content = openai_file_content
|
||||||
|
|
||||||
|
# Handle different file content types
|
||||||
|
if isinstance(file_content, str):
|
||||||
|
# String content can be used directly
|
||||||
|
content = file_content
|
||||||
|
elif isinstance(file_content, bytes):
|
||||||
|
# Bytes content can be decoded
|
||||||
|
content = file_content
|
||||||
|
elif isinstance(file_content, PathLike): # PathLike
|
||||||
|
with open(str(file_content), "rb") as f:
|
||||||
|
content = f.read()
|
||||||
|
elif hasattr(file_content, "read"): # IO[bytes]
|
||||||
|
# File-like objects need to be read
|
||||||
|
content = file_content.read()
|
||||||
|
|
||||||
|
# Ensure content is string
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
content = content.decode("utf-8")
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _get_gcs_object_name_from_batch_jsonl(
|
||||||
|
self,
|
||||||
|
openai_jsonl_content: List[Dict[str, Any]],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Gets a unique GCS object name for the VertexAI batch prediction job
|
||||||
|
|
||||||
|
named as: litellm-vertex-{model}-{uuid}
|
||||||
|
"""
|
||||||
|
_model = openai_jsonl_content[0].get("body", {}).get("model", "")
|
||||||
|
if "publishers/google/models" not in _model:
|
||||||
|
_model = f"publishers/google/models/{_model}"
|
||||||
|
object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
|
||||||
|
return object_name
|
||||||
|
|
||||||
|
def get_object_name(
|
||||||
|
self, extracted_file_data: ExtractedFileData, purpose: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get the object name for the request
|
||||||
|
"""
|
||||||
|
extracted_file_data_content = extracted_file_data.get("content")
|
||||||
|
|
||||||
|
if extracted_file_data_content is None:
|
||||||
|
raise ValueError("file content is required")
|
||||||
|
|
||||||
|
if purpose == "batch":
|
||||||
|
## 1. If jsonl, check if there's a model name
|
||||||
|
file_content = self._get_content_from_openai_file(
|
||||||
|
extracted_file_data_content
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split into lines and parse each line as JSON
|
||||||
|
openai_jsonl_content = [
|
||||||
|
json.loads(line) for line in file_content.splitlines() if line.strip()
|
||||||
|
]
|
||||||
|
if len(openai_jsonl_content) > 0:
|
||||||
|
return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)
|
||||||
|
|
||||||
|
## 2. If not jsonl, return the filename
|
||||||
|
filename = extracted_file_data.get("filename")
|
||||||
|
if filename:
|
||||||
|
return filename
|
||||||
|
## 3. If no file name, return timestamp
|
||||||
|
return str(int(time.time()))
|
||||||
|
|
||||||
|
def get_complete_file_url(
|
||||||
|
self,
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_key: Optional[str],
|
||||||
|
model: str,
|
||||||
|
optional_params: Dict,
|
||||||
|
litellm_params: Dict,
|
||||||
|
data: CreateFileRequest,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get the complete url for the request
|
||||||
|
"""
|
||||||
|
bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME")
|
||||||
|
if not bucket_name:
|
||||||
|
raise ValueError("GCS bucket_name is required")
|
||||||
|
file_data = data.get("file")
|
||||||
|
purpose = data.get("purpose")
|
||||||
|
if file_data is None:
|
||||||
|
raise ValueError("file is required")
|
||||||
|
if purpose is None:
|
||||||
|
raise ValueError("purpose is required")
|
||||||
|
extracted_file_data = extract_file_data(file_data)
|
||||||
|
object_name = self.get_object_name(extracted_file_data, purpose)
|
||||||
|
endpoint = (
|
||||||
|
f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
|
||||||
|
)
|
||||||
|
api_base = api_base or "https://storage.googleapis.com"
|
||||||
|
if not api_base:
|
||||||
|
raise ValueError("api_base is required")
|
||||||
|
|
||||||
|
return f"{api_base}/{endpoint}"
|
||||||
|
|
||||||
|
def get_supported_openai_params(
|
||||||
|
self, model: str
|
||||||
|
) -> List[OpenAICreateFileRequestOptionalParams]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
def _map_openai_to_vertex_params(
|
||||||
|
self,
|
||||||
|
openai_request_body: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
wrapper to call VertexGeminiConfig.map_openai_params
|
||||||
|
"""
|
||||||
|
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||||
|
VertexGeminiConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = VertexGeminiConfig()
|
||||||
|
_model = openai_request_body.get("model", "")
|
||||||
|
vertex_params = config.map_openai_params(
|
||||||
|
model=_model,
|
||||||
|
non_default_params=openai_request_body,
|
||||||
|
optional_params={},
|
||||||
|
drop_params=False,
|
||||||
|
)
|
||||||
|
return vertex_params
|
||||||
|
|
||||||
|
def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
|
||||||
|
self, openai_jsonl_content: List[Dict[str, Any]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Transforms OpenAI JSONL content to VertexAI JSONL content
|
||||||
|
|
||||||
|
jsonl body for vertex is {"request": <request_body>}
|
||||||
|
Example Vertex jsonl
|
||||||
|
{"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
|
||||||
|
{"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
vertex_jsonl_content = []
|
||||||
|
for _openai_jsonl_content in openai_jsonl_content:
|
||||||
|
openai_request_body = _openai_jsonl_content.get("body") or {}
|
||||||
|
vertex_request_body = _transform_request_body(
|
||||||
|
messages=openai_request_body.get("messages", []),
|
||||||
|
model=openai_request_body.get("model", ""),
|
||||||
|
optional_params=self._map_openai_to_vertex_params(openai_request_body),
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
|
litellm_params={},
|
||||||
|
cached_content=None,
|
||||||
|
)
|
||||||
|
vertex_jsonl_content.append({"request": vertex_request_body})
|
||||||
|
return vertex_jsonl_content
|
||||||
|
|
||||||
|
def transform_create_file_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
create_file_data: CreateFileRequest,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
) -> Union[bytes, str, dict]:
|
||||||
|
"""
|
||||||
|
2 Cases:
|
||||||
|
1. Handle basic file upload
|
||||||
|
2. Handle batch file upload (.jsonl)
|
||||||
|
"""
|
||||||
|
file_data = create_file_data.get("file")
|
||||||
|
if file_data is None:
|
||||||
|
raise ValueError("file is required")
|
||||||
|
extracted_file_data = extract_file_data(file_data)
|
||||||
|
extracted_file_data_content = extracted_file_data.get("content")
|
||||||
|
if (
|
||||||
|
create_file_data.get("purpose") == "batch"
|
||||||
|
and extracted_file_data.get("content_type") == "application/jsonl"
|
||||||
|
and extracted_file_data_content is not None
|
||||||
|
):
|
||||||
|
## 1. If jsonl, check if there's a model name
|
||||||
|
file_content = self._get_content_from_openai_file(
|
||||||
|
extracted_file_data_content
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split into lines and parse each line as JSON
|
||||||
|
openai_jsonl_content = [
|
||||||
|
json.loads(line) for line in file_content.splitlines() if line.strip()
|
||||||
|
]
|
||||||
|
vertex_jsonl_content = (
|
||||||
|
self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
|
||||||
|
openai_jsonl_content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return json.dumps(vertex_jsonl_content)
|
||||||
|
elif isinstance(extracted_file_data_content, bytes):
|
||||||
|
return extracted_file_data_content
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported file content type")
|
||||||
|
|
||||||
|
def transform_create_file_response(
|
||||||
|
self,
|
||||||
|
model: Optional[str],
|
||||||
|
raw_response: Response,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
litellm_params: dict,
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
"""
|
||||||
|
Transform VertexAI File upload response into OpenAI-style FileObject
|
||||||
|
"""
|
||||||
|
response_json = raw_response.json()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response_object = GcsBucketResponse(**response_json) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise VertexAIError(
|
||||||
|
status_code=raw_response.status_code,
|
||||||
|
message=f"Error reading GCS bucket response: {e}",
|
||||||
|
headers=raw_response.headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
gcs_id = response_object.get("id", "")
|
||||||
|
# Remove the last numeric ID from the path
|
||||||
|
gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
|
||||||
|
|
||||||
|
return OpenAIFileObject(
|
||||||
|
purpose=response_object.get("purpose", "batch"),
|
||||||
|
id=f"gs://{gcs_id}",
|
||||||
|
filename=response_object.get("name", ""),
|
||||||
|
created_at=_convert_vertex_datetime_to_openai_datetime(
|
||||||
|
vertex_datetime=response_object.get("timeCreated", "")
|
||||||
|
),
|
||||||
|
status="uploaded",
|
||||||
|
bytes=int(response_object.get("size", 0)),
|
||||||
|
object="file",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_error_class(
|
||||||
|
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
|
||||||
|
) -> BaseLLMException:
|
||||||
|
return VertexAIError(
|
||||||
|
status_code=status_code, message=error_message, headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
|
||||||
"""
|
"""
|
||||||
Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
|
Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
|
||||||
"""
|
"""
|
||||||
|
|
|
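Taken as a whole, the new VertexAIFilesConfig above turns an OpenAI batch .jsonl upload into Vertex-style JSONL: each line's request body is transformed and wrapped as {"request": ...}, matching the example in the docstring. A simplified sketch of that wrapping step; the inline message-to-contents conversion stands in for the real _transform_request_body:

    import json

    def to_vertex_jsonl(openai_jsonl: str) -> str:
        vertex_rows = []
        for line in openai_jsonl.splitlines():
            if not line.strip():
                continue
            body = json.loads(line).get("body", {})
            vertex_request = {
                "contents": [
                    {"role": m["role"], "parts": [{"text": m["content"]}]}
                    for m in body.get("messages", [])
                ]
            }
            vertex_rows.append({"request": vertex_request})
        return json.dumps(vertex_rows)

    example = '{"custom_id": "1", "body": {"model": "gemini-2.0-flash", "messages": [{"role": "user", "content": "hi"}]}}'
    print(to_vertex_jsonl(example))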
@ -208,25 +208,24 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
|
||||||
elif element["type"] == "input_audio":
|
elif element["type"] == "input_audio":
|
||||||
audio_element = cast(ChatCompletionAudioObject, element)
|
audio_element = cast(ChatCompletionAudioObject, element)
|
||||||
if audio_element["input_audio"].get("data") is not None:
|
if audio_element["input_audio"].get("data") is not None:
|
||||||
_part = PartType(
|
_part = _process_gemini_image(
|
||||||
inline_data=BlobType(
|
image_url=audio_element["input_audio"]["data"],
|
||||||
data=audio_element["input_audio"]["data"],
|
format=audio_element["input_audio"].get("format"),
|
||||||
mime_type="audio/{}".format(
|
|
||||||
audio_element["input_audio"]["format"]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
_parts.append(_part)
|
_parts.append(_part)
|
||||||
elif element["type"] == "file":
|
elif element["type"] == "file":
|
||||||
file_element = cast(ChatCompletionFileObject, element)
|
file_element = cast(ChatCompletionFileObject, element)
|
||||||
file_id = file_element["file"].get("file_id")
|
file_id = file_element["file"].get("file_id")
|
||||||
format = file_element["file"].get("format")
|
format = file_element["file"].get("format")
|
||||||
|
file_data = file_element["file"].get("file_data")
|
||||||
if not file_id:
|
passed_file = file_id or file_data
|
||||||
continue
|
if passed_file is None:
|
||||||
|
raise Exception(
|
||||||
|
"Unknown file type. Please pass in a file_id or file_data"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
_part = _process_gemini_image(
|
_part = _process_gemini_image(
|
||||||
image_url=file_id, format=format
|
image_url=passed_file, format=format
|
||||||
)
|
)
|
||||||
_parts.append(_part)
|
_parts.append(_part)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
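After the conversion change above, inline audio is routed through the same helper as images, and a file part may carry either a file_id or inline file_data; passing neither now raises instead of being silently skipped. A sketch of the content shapes that are accepted (the URIs and base64 strings are placeholders):

    audio_element = {
        "type": "input_audio",
        "input_audio": {"data": "data:audio/wav;base64,UklGRg==", "format": "wav"},
    }
    file_element_by_id = {
        "type": "file",
        "file": {"file_id": "gs://example-bucket/report.pdf", "format": "application/pdf"},
    }
    file_element_inline = {
        "type": "file",
        "file": {"file_data": "data:application/pdf;base64,JVBERg=="},
    }
    # A "file" element with neither file_id nor file_data now raises an error.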
@ -240,6 +240,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||||
gtool_func_declarations = []
|
gtool_func_declarations = []
|
||||||
googleSearch: Optional[dict] = None
|
googleSearch: Optional[dict] = None
|
||||||
googleSearchRetrieval: Optional[dict] = None
|
googleSearchRetrieval: Optional[dict] = None
|
||||||
|
enterpriseWebSearch: Optional[dict] = None
|
||||||
code_execution: Optional[dict] = None
|
code_execution: Optional[dict] = None
|
||||||
# remove 'additionalProperties' from tools
|
# remove 'additionalProperties' from tools
|
||||||
value = _remove_additional_properties(value)
|
value = _remove_additional_properties(value)
|
||||||
|
@ -273,6 +274,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||||
googleSearch = tool["googleSearch"]
|
googleSearch = tool["googleSearch"]
|
||||||
elif tool.get("googleSearchRetrieval", None) is not None:
|
elif tool.get("googleSearchRetrieval", None) is not None:
|
||||||
googleSearchRetrieval = tool["googleSearchRetrieval"]
|
googleSearchRetrieval = tool["googleSearchRetrieval"]
|
||||||
|
elif tool.get("enterpriseWebSearch", None) is not None:
|
||||||
|
enterpriseWebSearch = tool["enterpriseWebSearch"]
|
||||||
elif tool.get("code_execution", None) is not None:
|
elif tool.get("code_execution", None) is not None:
|
||||||
code_execution = tool["code_execution"]
|
code_execution = tool["code_execution"]
|
||||||
elif openai_function_object is not None:
|
elif openai_function_object is not None:
|
||||||
|
@ -299,6 +302,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||||
_tools["googleSearch"] = googleSearch
|
_tools["googleSearch"] = googleSearch
|
||||||
if googleSearchRetrieval is not None:
|
if googleSearchRetrieval is not None:
|
||||||
_tools["googleSearchRetrieval"] = googleSearchRetrieval
|
_tools["googleSearchRetrieval"] = googleSearchRetrieval
|
||||||
|
if enterpriseWebSearch is not None:
|
||||||
|
_tools["enterpriseWebSearch"] = enterpriseWebSearch
|
||||||
if code_execution is not None:
|
if code_execution is not None:
|
||||||
_tools["code_execution"] = code_execution
|
_tools["code_execution"] = code_execution
|
||||||
return [_tools]
|
return [_tools]
|
||||||
|
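With the two hunks above, an enterpriseWebSearch tool entry is detected and forwarded to Vertex alongside googleSearch, googleSearchRetrieval, and code_execution. A hedged usage sketch; the model name is illustrative, and passing an empty dict to enable the tool mirrors the existing googleSearch convention rather than anything shown in this diff:

    import litellm

    response = litellm.completion(
        model="vertex_ai/gemini-2.0-flash",
        messages=[{"role": "user", "content": "Summarize our latest internal report."}],
        tools=[{"enterpriseWebSearch": {}}],  # assumed shape, by analogy with {"googleSearch": {}}
    )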
@ -374,7 +379,11 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||||
optional_params["responseLogprobs"] = value
|
optional_params["responseLogprobs"] = value
|
||||||
elif param == "top_logprobs":
|
elif param == "top_logprobs":
|
||||||
optional_params["logprobs"] = value
|
optional_params["logprobs"] = value
|
||||||
elif (param == "tools" or param == "functions") and isinstance(value, list):
|
elif (
|
||||||
|
(param == "tools" or param == "functions")
|
||||||
|
and isinstance(value, list)
|
||||||
|
and value
|
||||||
|
):
|
||||||
optional_params["tools"] = self._map_function(value=value)
|
optional_params["tools"] = self._map_function(value=value)
|
||||||
optional_params["litellm_param_is_function_call"] = (
|
optional_params["litellm_param_is_function_call"] = (
|
||||||
True if param == "functions" else False
|
True if param == "functions" else False
|
||||||
|
@ -739,9 +748,6 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||||
chat_completion_logprobs = self._transform_logprobs(
|
chat_completion_logprobs = self._transform_logprobs(
|
||||||
logprobs_result=candidate["logprobsResult"]
|
logprobs_result=candidate["logprobsResult"]
|
||||||
)
|
)
|
||||||
# Handle avgLogprobs for Gemini Flash 2.0
|
|
||||||
elif "avgLogprobs" in candidate:
|
|
||||||
chat_completion_logprobs = candidate["avgLogprobs"]
|
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
chat_completion_message["tool_calls"] = tools
|
chat_completion_message["tool_calls"] = tools
|
||||||
|
@ -896,6 +902,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: Dict,
|
optional_params: Dict,
|
||||||
|
litellm_params: Dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
|
@ -1013,7 +1020,7 @@ class VertexLLM(VertexBase):
|
||||||
logging_obj,
|
logging_obj,
|
||||||
stream,
|
stream,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
litellm_params=None,
|
litellm_params: dict,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
@ -1054,6 +1061,7 @@ class VertexLLM(VertexBase):
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -1140,6 +1148,7 @@ class VertexLLM(VertexBase):
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
request_body = await async_transform_request_body(**data) # type: ignore
|
request_body = await async_transform_request_body(**data) # type: ignore
|
||||||
|
@ -1313,6 +1322,7 @@ class VertexLLM(VertexBase):
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
## TRANSFORMATION ##
|
## TRANSFORMATION ##
|
||||||
|
|
|
@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
api_key=auth_header,
|
api_key=auth_header,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
|
|
@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -2,9 +2,10 @@ import types
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||||
|
|
||||||
|
|
||||||
class VertexAIAi21Config:
|
class VertexAIAi21Config(OpenAIGPTConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
|
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
|
||||||
|
|
||||||
|
@ -40,9 +41,6 @@ class VertexAIAi21Config:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_supported_openai_params(self):
|
|
||||||
return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
|
|
||||||
|
|
||||||
def map_openai_params(
|
def map_openai_params(
|
||||||
self,
|
self,
|
||||||
non_default_params: dict,
|
non_default_params: dict,
|
||||||
|
|
|
@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
|
||||||
|
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
from litellm.litellm_core_utils.asyncify import asyncify
|
from litellm.litellm_core_utils.asyncify import asyncify
|
||||||
from litellm.llms.base import BaseLLM
|
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
|
|
||||||
|
@ -22,7 +21,7 @@ else:
|
||||||
GoogleCredentialsObject = Any
|
GoogleCredentialsObject = Any
|
||||||
|
|
||||||
|
|
||||||
class VertexBase(BaseLLM):
|
class VertexBase:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.access_token: Optional[str] = None
|
self.access_token: Optional[str] = None
|
||||||
|
|
|
@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
## UPDATE PAYLOAD (optional params)
|
## UPDATE PAYLOAD (optional params)
|
||||||
|
|
|
@ -165,6 +165,7 @@ class IBMWatsonXMixin:
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: Dict,
|
optional_params: Dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
|
|
|
@ -3616,6 +3616,7 @@ def embedding( # noqa: PLR0915
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
client=client,
|
client=client,
|
||||||
aembedding=aembedding,
|
aembedding=aembedding,
|
||||||
|
litellm_params=litellm_params_dict,
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "bedrock":
|
elif custom_llm_provider == "bedrock":
|
||||||
if isinstance(input, str):
|
if isinstance(input, str):
|
||||||
|
|
|
@ -380,6 +380,7 @@
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
"supports_native_streaming": false,
|
"supports_native_streaming": false,
|
||||||
"supported_modalities": ["text", "image"],
|
"supported_modalities": ["text", "image"],
|
||||||
|
"supported_output_modalities": ["text"],
|
||||||
"supported_endpoints": ["/v1/responses", "/v1/batch"]
|
"supported_endpoints": ["/v1/responses", "/v1/batch"]
|
||||||
},
|
},
|
||||||
"o1-pro-2025-03-19": {
|
"o1-pro-2025-03-19": {
|
||||||
|
@ -401,6 +402,7 @@
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
"supports_native_streaming": false,
|
"supports_native_streaming": false,
|
||||||
"supported_modalities": ["text", "image"],
|
"supported_modalities": ["text", "image"],
|
||||||
|
"supported_output_modalities": ["text"],
|
||||||
"supported_endpoints": ["/v1/responses", "/v1/batch"]
|
"supported_endpoints": ["/v1/responses", "/v1/batch"]
|
||||||
},
|
},
|
||||||
"o1": {
|
"o1": {
|
||||||
|
@ -1286,6 +1288,68 @@
|
||||||
"supports_system_messages": true,
|
"supports_system_messages": true,
|
||||||
"supports_tool_choice": true
|
"supports_tool_choice": true
|
||||||
},
|
},
|
||||||
|
"azure/gpt-4o-realtime-preview-2024-12-17": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.000005,
|
||||||
|
"input_cost_per_audio_token": 0.00004,
|
||||||
|
"cache_read_input_token_cost": 0.0000025,
|
||||||
|
"output_cost_per_token": 0.00002,
|
||||||
|
"output_cost_per_audio_token": 0.00008,
|
||||||
|
"litellm_provider": "azure",
|
||||||
|
"mode": "chat",
|
||||||
|
"supported_modalities": ["text", "audio"],
|
||||||
|
"supported_output_modalities": ["text", "audio"],
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"supports_parallel_function_calling": true,
|
||||||
|
"supports_audio_input": true,
|
||||||
|
"supports_audio_output": true,
|
||||||
|
"supports_system_messages": true,
|
||||||
|
"supports_tool_choice": true
|
||||||
|
},
|
||||||
|
"azure/us/gpt-4o-realtime-preview-2024-12-17": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 5.5e-6,
|
||||||
|
"input_cost_per_audio_token": 44e-6,
|
||||||
|
"cache_read_input_token_cost": 2.75e-6,
|
||||||
|
"cache_read_input_audio_token_cost": 2.5e-6,
|
||||||
|
"output_cost_per_token": 22e-6,
|
||||||
|
"output_cost_per_audio_token": 80e-6,
|
||||||
|
"litellm_provider": "azure",
|
||||||
|
"mode": "chat",
|
||||||
|
"supported_modalities": ["text", "audio"],
|
||||||
|
"supported_output_modalities": ["text", "audio"],
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"supports_parallel_function_calling": true,
|
||||||
|
"supports_audio_input": true,
|
||||||
|
"supports_audio_output": true,
|
||||||
|
"supports_system_messages": true,
|
||||||
|
"supports_tool_choice": true
|
||||||
|
},
|
||||||
|
"azure/eu/gpt-4o-realtime-preview-2024-12-17": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 5.5e-6,
|
||||||
|
"input_cost_per_audio_token": 44e-6,
|
||||||
|
"cache_read_input_token_cost": 2.75e-6,
|
||||||
|
"cache_read_input_audio_token_cost": 2.5e-6,
|
||||||
|
"output_cost_per_token": 22e-6,
|
||||||
|
"output_cost_per_audio_token": 80e-6,
|
||||||
|
"litellm_provider": "azure",
|
||||||
|
"mode": "chat",
|
||||||
|
"supported_modalities": ["text", "audio"],
|
||||||
|
"supported_output_modalities": ["text", "audio"],
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"supports_parallel_function_calling": true,
|
||||||
|
"supports_audio_input": true,
|
||||||
|
"supports_audio_output": true,
|
||||||
|
"supports_system_messages": true,
|
||||||
|
"supports_tool_choice": true
|
||||||
|
},
|
||||||
"azure/gpt-4o-realtime-preview-2024-10-01": {
|
"azure/gpt-4o-realtime-preview-2024-10-01": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 128000,
|
"max_input_tokens": 128000,
|
||||||
|
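Read per million tokens, the new realtime prices above are: 5.5e-6 input = $5.50/1M text tokens, 44e-6 = $44/1M audio input tokens, 22e-6 = $22/1M text output tokens, and 80e-6 = $80/1M audio output tokens. A small sketch of pricing one request from such an entry; the helper function and token counts are illustrative:

    entry = {
        "input_cost_per_token": 5.5e-6,
        "input_cost_per_audio_token": 44e-6,
        "output_cost_per_token": 22e-6,
        "output_cost_per_audio_token": 80e-6,
    }

    def request_cost(text_in: int, audio_in: int, text_out: int, audio_out: int) -> float:
        return (
            text_in * entry["input_cost_per_token"]
            + audio_in * entry["input_cost_per_audio_token"]
            + text_out * entry["output_cost_per_token"]
            + audio_out * entry["output_cost_per_audio_token"]
        )

    # 1,000 text-in + 500 audio-in + 200 text-out + 300 audio-out tokens:
    print(request_cost(1000, 500, 200, 300))  # about $0.0559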
@ -2300,6 +2364,18 @@
|
||||||
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview",
|
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview",
|
||||||
"supports_tool_choice": true
|
"supports_tool_choice": true
|
||||||
},
|
},
|
||||||
|
"azure_ai/mistral-large-latest": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"max_input_tokens": 128000,
|
||||||
|
"max_output_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.000002,
|
||||||
|
"output_cost_per_token": 0.000006,
|
||||||
|
"litellm_provider": "azure_ai",
|
||||||
|
"supports_function_calling": true,
|
||||||
|
"mode": "chat",
|
||||||
|
"source": "https://azuremarketplace.microsoft.com/en/marketplace/apps/000-000.mistral-ai-large-2407-offer?tab=Overview",
|
||||||
|
"supports_tool_choice": true
|
||||||
|
},
|
||||||
"azure_ai/ministral-3b": {
|
"azure_ai/ministral-3b": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 128000,
|
"max_input_tokens": 128000,
|
||||||
|
@ -2397,25 +2473,26 @@
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 131072,
|
"max_input_tokens": 131072,
|
||||||
"max_output_tokens": 4096,
|
"max_output_tokens": 4096,
|
||||||
"input_cost_per_token": 0,
|
"input_cost_per_token": 0.000000075,
|
||||||
"output_cost_per_token": 0,
|
"output_cost_per_token": 0.0000003,
|
||||||
"litellm_provider": "azure_ai",
|
"litellm_provider": "azure_ai",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
|
"source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
|
||||||
},
|
},
|
||||||
"azure_ai/Phi-4-multimodal-instruct": {
|
"azure_ai/Phi-4-multimodal-instruct": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 131072,
|
"max_input_tokens": 131072,
|
||||||
"max_output_tokens": 4096,
|
"max_output_tokens": 4096,
|
||||||
"input_cost_per_token": 0,
|
"input_cost_per_token": 0.00000008,
|
||||||
"output_cost_per_token": 0,
|
"input_cost_per_audio_token": 0.000004,
|
||||||
|
"output_cost_per_token": 0.00032,
|
||||||
"litellm_provider": "azure_ai",
|
"litellm_provider": "azure_ai",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_audio_input": true,
|
"supports_audio_input": true,
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_vision": true,
|
"supports_vision": true,
|
||||||
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
|
"source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
|
||||||
},
|
},
|
||||||
"azure_ai/Phi-4": {
|
"azure_ai/Phi-4": {
|
||||||
"max_tokens": 16384,
|
"max_tokens": 16384,
|
||||||
|
@ -3455,7 +3532,7 @@
|
||||||
"input_cost_per_token": 0.0000008,
|
"input_cost_per_token": 0.0000008,
|
||||||
"output_cost_per_token": 0.000004,
|
"output_cost_per_token": 0.000004,
|
||||||
"cache_creation_input_token_cost": 0.000001,
|
"cache_creation_input_token_cost": 0.000001,
|
||||||
"cache_read_input_token_cost": 0.0000008,
|
"cache_read_input_token_cost": 0.00000008,
|
||||||
"litellm_provider": "anthropic",
|
"litellm_provider": "anthropic",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
|
@ -4499,20 +4576,10 @@
|
||||||
"max_audio_length_hours": 8.4,
|
"max_audio_length_hours": 8.4,
|
||||||
"max_audio_per_prompt": 1,
|
"max_audio_per_prompt": 1,
|
||||||
"max_pdf_size_mb": 30,
|
"max_pdf_size_mb": 30,
|
||||||
"input_cost_per_image": 0,
|
"input_cost_per_token": 0.00000125,
|
||||||
"input_cost_per_video_per_second": 0,
|
"input_cost_per_token_above_200k_tokens": 0.0000025,
|
||||||
"input_cost_per_audio_per_second": 0,
|
"output_cost_per_token": 0.00001,
|
||||||
"input_cost_per_token": 0,
|
"output_cost_per_token_above_200k_tokens": 0.000015,
|
||||||
"input_cost_per_character": 0,
|
|
||||||
"input_cost_per_token_above_128k_tokens": 0,
|
|
||||||
"input_cost_per_character_above_128k_tokens": 0,
|
|
||||||
"input_cost_per_image_above_128k_tokens": 0,
|
|
||||||
"input_cost_per_video_per_second_above_128k_tokens": 0,
|
|
||||||
"input_cost_per_audio_per_second_above_128k_tokens": 0,
|
|
||||||
"output_cost_per_token": 0,
|
|
||||||
"output_cost_per_character": 0,
|
|
||||||
"output_cost_per_token_above_128k_tokens": 0,
|
|
||||||
"output_cost_per_character_above_128k_tokens": 0,
|
|
||||||
"litellm_provider": "vertex_ai-language-models",
|
"litellm_provider": "vertex_ai-language-models",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_system_messages": true,
|
"supports_system_messages": true,
|
||||||
|
@ -4523,6 +4590,9 @@
|
||||||
"supports_pdf_input": true,
|
"supports_pdf_input": true,
|
||||||
"supports_response_schema": true,
|
"supports_response_schema": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supported_endpoints": ["/v1/chat/completions", "/v1/completions"],
|
||||||
|
"supported_modalities": ["text", "image", "audio", "video"],
|
||||||
|
"supported_output_modalities": ["text"],
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
},
|
},
|
||||||
"gemini-2.0-pro-exp-02-05": {
|
"gemini-2.0-pro-exp-02-05": {
|
||||||
|
@ -4535,20 +4605,10 @@
|
||||||
"max_audio_length_hours": 8.4,
|
"max_audio_length_hours": 8.4,
|
||||||
"max_audio_per_prompt": 1,
|
"max_audio_per_prompt": 1,
|
||||||
"max_pdf_size_mb": 30,
|
"max_pdf_size_mb": 30,
|
||||||
"input_cost_per_image": 0,
|
"input_cost_per_token": 0.00000125,
|
||||||
"input_cost_per_video_per_second": 0,
|
"input_cost_per_token_above_200k_tokens": 0.0000025,
|
||||||
"input_cost_per_audio_per_second": 0,
|
"output_cost_per_token": 0.00001,
|
||||||
"input_cost_per_token": 0,
|
"output_cost_per_token_above_200k_tokens": 0.000015,
|
||||||
"input_cost_per_character": 0,
|
|
||||||
"input_cost_per_token_above_128k_tokens": 0,
|
|
||||||
"input_cost_per_character_above_128k_tokens": 0,
|
|
||||||
"input_cost_per_image_above_128k_tokens": 0,
|
|
||||||
"input_cost_per_video_per_second_above_128k_tokens": 0,
|
|
||||||
"input_cost_per_audio_per_second_above_128k_tokens": 0,
|
|
||||||
"output_cost_per_token": 0,
|
|
||||||
"output_cost_per_character": 0,
|
|
||||||
"output_cost_per_token_above_128k_tokens": 0,
|
|
||||||
"output_cost_per_character_above_128k_tokens": 0,
|
|
||||||
"litellm_provider": "vertex_ai-language-models",
|
"litellm_provider": "vertex_ai-language-models",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_system_messages": true,
|
"supports_system_messages": true,
|
||||||
|
@ -4559,6 +4619,9 @@
|
||||||
"supports_pdf_input": true,
|
"supports_pdf_input": true,
|
||||||
"supports_response_schema": true,
|
"supports_response_schema": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supported_endpoints": ["/v1/chat/completions", "/v1/completions"],
|
||||||
|
"supported_modalities": ["text", "image", "audio", "video"],
|
||||||
|
"supported_output_modalities": ["text"],
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
},
|
},
|
||||||
"gemini-2.0-flash-exp": {
|
"gemini-2.0-flash-exp": {
|
||||||
|
@ -4592,6 +4655,8 @@
|
||||||
"supports_vision": true,
|
"supports_vision": true,
|
||||||
"supports_response_schema": true,
|
"supports_response_schema": true,
|
||||||
"supports_audio_output": true,
|
"supports_audio_output": true,
|
||||||
|
"supported_modalities": ["text", "image", "audio", "video"],
|
||||||
|
"supported_output_modalities": ["text", "image"],
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing",
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing",
|
||||||
"supports_tool_choice": true
|
"supports_tool_choice": true
|
||||||
},
|
},
|
||||||
|
@ -4616,6 +4681,8 @@
|
||||||
"supports_response_schema": true,
|
"supports_response_schema": true,
|
||||||
"supports_audio_output": true,
|
"supports_audio_output": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supported_modalities": ["text", "image", "audio", "video"],
|
||||||
|
"supported_output_modalities": ["text", "image"],
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
},
|
},
|
||||||
"gemini-2.0-flash-thinking-exp": {
|
"gemini-2.0-flash-thinking-exp": {
|
||||||
|
@@ -4649,6 +4716,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@@ -4683,6 +4752,8 @@
"supports_vision": true,
"supports_response_schema": false,
"supports_audio_output": false,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@@ -4708,6 +4779,7 @@
"supports_audio_output": true,
"supports_audio_input": true,
"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text", "image"],
"supports_tool_choice": true,
"source": "https://ai.google.dev/pricing#2_0flash"
},
@@ -4730,6 +4802,32 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text"],
+"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
+"supports_tool_choice": true
+},
+"gemini-2.0-flash-lite-001": {
+"max_input_tokens": 1048576,
+"max_output_tokens": 8192,
+"max_images_per_prompt": 3000,
+"max_videos_per_prompt": 10,
+"max_video_length": 1,
+"max_audio_length_hours": 8.4,
+"max_audio_per_prompt": 1,
+"max_pdf_size_mb": 50,
+"input_cost_per_audio_token": 0.000000075,
+"input_cost_per_token": 0.000000075,
+"output_cost_per_token": 0.0000003,
+"litellm_provider": "vertex_ai-language-models",
+"mode": "chat",
+"supports_system_messages": true,
+"supports_function_calling": true,
+"supports_vision": true,
+"supports_response_schema": true,
+"supports_audio_output": true,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@@ -4795,6 +4893,7 @@
"supports_audio_output": true,
"supports_audio_input": true,
"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text", "image"],
"supports_tool_choice": true,
"source": "https://ai.google.dev/pricing#2_0flash"
},
@@ -4820,6 +4919,8 @@
"supports_response_schema": true,
"supports_audio_output": true,
"supports_tool_choice": true,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text"],
"source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.0-flash-lite"
},
"gemini/gemini-2.0-flash-001": {
@@ -4845,6 +4946,8 @@
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text", "image"],
"source": "https://ai.google.dev/pricing#2_0flash"
},
"gemini/gemini-2.5-pro-preview-03-25": {
@@ -4859,9 +4962,9 @@
"max_pdf_size_mb": 30,
"input_cost_per_audio_token": 0.0000007,
"input_cost_per_token": 0.00000125,
-"input_cost_per_token_above_128k_tokens": 0.0000025,
+"input_cost_per_token_above_200k_tokens": 0.0000025,
"output_cost_per_token": 0.0000010,
-"output_cost_per_token_above_128k_tokens": 0.000015,
+"output_cost_per_token_above_200k_tokens": 0.000015,
"litellm_provider": "gemini",
"mode": "chat",
"rpm": 10000,
@@ -4872,6 +4975,8 @@
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text"],
"source": "https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro-preview"
},
"gemini/gemini-2.0-flash-exp": {
@@ -4907,6 +5012,8 @@
"supports_audio_output": true,
"tpm": 4000000,
"rpm": 10,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@@ -4933,6 +5040,8 @@
"supports_response_schema": true,
"supports_audio_output": false,
"supports_tool_choice": true,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash-lite"
},
"gemini/gemini-2.0-flash-thinking-exp": {
@@ -4968,6 +5077,8 @@
"supports_audio_output": true,
"tpm": 4000000,
"rpm": 10,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@@ -5004,6 +5115,8 @@
"supports_audio_output": true,
"tpm": 4000000,
"rpm": 10,
+"supported_modalities": ["text", "image", "audio", "video"],
+"supported_output_modalities": ["text", "image"],
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-2.0-flash",
"supports_tool_choice": true
},
@@ -8444,7 +8557,8 @@
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.0000020,
"litellm_provider": "bedrock",
-"mode": "chat"
+"mode": "chat",
+"supports_tool_choice": true
},
"bedrock/*/1-month-commitment/cohere.command-text-v14": {
"max_tokens": 4096,
@@ -8453,7 +8567,8 @@
"input_cost_per_second": 0.011,
"output_cost_per_second": 0.011,
"litellm_provider": "bedrock",
-"mode": "chat"
+"mode": "chat",
+"supports_tool_choice": true
},
"bedrock/*/6-month-commitment/cohere.command-text-v14": {
"max_tokens": 4096,
@@ -8462,7 +8577,8 @@
"input_cost_per_second": 0.0066027,
"output_cost_per_second": 0.0066027,
"litellm_provider": "bedrock",
-"mode": "chat"
+"mode": "chat",
+"supports_tool_choice": true
},
"cohere.command-light-text-v14": {
"max_tokens": 4096,
@@ -8471,7 +8587,8 @@
"input_cost_per_token": 0.0000003,
"output_cost_per_token": 0.0000006,
"litellm_provider": "bedrock",
-"mode": "chat"
+"mode": "chat",
+"supports_tool_choice": true
},
"bedrock/*/1-month-commitment/cohere.command-light-text-v14": {
"max_tokens": 4096,
@@ -8480,7 +8597,8 @@
"input_cost_per_second": 0.001902,
"output_cost_per_second": 0.001902,
"litellm_provider": "bedrock",
-"mode": "chat"
+"mode": "chat",
+"supports_tool_choice": true
},
"bedrock/*/6-month-commitment/cohere.command-light-text-v14": {
"max_tokens": 4096,
@@ -8489,7 +8607,8 @@
"input_cost_per_second": 0.0011416,
"output_cost_per_second": 0.0011416,
"litellm_provider": "bedrock",
-"mode": "chat"
+"mode": "chat",
+"supports_tool_choice": true
},
"cohere.command-r-plus-v1:0": {
"max_tokens": 4096,
@@ -8498,7 +8617,8 @@
"input_cost_per_token": 0.0000030,
"output_cost_per_token": 0.000015,
"litellm_provider": "bedrock",
-"mode": "chat"
+"mode": "chat",
+"supports_tool_choice": true
},
"cohere.command-r-v1:0": {
"max_tokens": 4096,
@@ -8507,7 +8627,8 @@
"input_cost_per_token": 0.0000005,
"output_cost_per_token": 0.0000015,
"litellm_provider": "bedrock",
-"mode": "chat"
+"mode": "chat",
+"supports_tool_choice": true
},
"cohere.embed-english-v3": {
"max_tokens": 512,